diff --git a/build.sh b/build.sh index 1871bbb8..0afaa7fb 100644 --- a/build.sh +++ b/build.sh @@ -174,11 +174,9 @@ echo "---------------- GraphEngine output generated ----------------" # generate output package in tar form, including ut/st libraries/executables cd ${BASEPATH} mkdir -p output/plugin/nnengine/ge_config/ -mkdir -p output/plugin/opskernel/ find output/ -name graphengine_lib.tar -exec rm {} \; cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/ find output/ -maxdepth 1 -name libengine.so -exec mv -f {} output/plugin/nnengine/ \; -find output/ -maxdepth 1 -name libge_local_engine.so -exec mv -f {} output/plugin/opskernel/ \; tar -cf graphengine_lib.tar output/* mv -f graphengine_lib.tar output echo "---------------- GraphEngine package archive generated ----------------" diff --git a/inc/common/opskernel/ge_task_info.h b/inc/common/opskernel/ge_task_info.h index 8a55b7de..360f8a5d 100644 --- a/inc/common/opskernel/ge_task_info.h +++ b/inc/common/opskernel/ge_task_info.h @@ -52,16 +52,5 @@ struct GETaskInfo { std::vector kernelHcclInfo; }; - -struct HcomOpertion { - std::string hcclType; - void *inputPtr; - void *outputPtr; - uint64_t count; - int32_t dataType; - int32_t opType; - int32_t root; -}; - } // namespace ge #endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h index e350f9e5..6908fb75 100644 --- a/inc/common/util/compress/compress.h +++ b/inc/common/util/compress/compress.h @@ -28,7 +28,6 @@ struct CompressConfig { size_t channel; // channels of L2 or DDR. 
For load balance size_t fractalSize; // size of compressing block bool isTight; // whether compose compressed data tightly - size_t init_offset; }; CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, diff --git a/inc/common/util/compress/compress_weight.h b/inc/common/util/compress/compress_weight.h deleted file mode 100644 index 34ea47d1..00000000 --- a/inc/common/util/compress/compress_weight.h +++ /dev/null @@ -1,33 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef COMPRESS_WEIGHT_H -#define COMPRESS_WEIGHT_H - -#include "compress.h" - -const int SHAPE_SIZE_WEIGHT = 4; - -struct CompressOpConfig { - int64_t wShape[SHAPE_SIZE_WEIGHT]; - size_t compressTilingK; - size_t compressTilingN; - struct CompressConfig compressConfig; -}; - -extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer, - CompressOpConfig *const param); -#endif // COMPRESS_WEIGHT_H diff --git a/inc/common/util/platform_info.h b/inc/common/util/platform_info.h index 2a145d68..cd143fcc 100644 --- a/inc/common/util/platform_info.h +++ b/inc/common/util/platform_info.h @@ -27,6 +27,7 @@ using std::string; using std::vector; namespace fe { + class PlatformInfoManager { public: PlatformInfoManager(const PlatformInfoManager &) = delete; @@ -38,8 +39,6 @@ class PlatformInfoManager { uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); - uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); - void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); private: @@ -95,5 +94,6 @@ class PlatformInfoManager { map platformInfoMap_; OptionalInfo optiCompilationInfo_; }; + } // namespace fe #endif diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index cffb28bd..1632f11c 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -44,12 +44,8 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; -const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; -const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; -const char *const OPTION_EXEC_OP_DEBUG_LEVEL = "ge.exec.opDebugLevel"; const char *const 
OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; -const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; // profiling flag const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; @@ -223,10 +219,6 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; // Configure input fp16 nodes const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; -// Configure debug level, its value should be 0(default), 1 or 2. -// 0: close debug; 1: open TBE compiler; 2: open ccec compiler -const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; - // Graph run mode enum GraphRunMode { PREDICTION = 0, TRAIN }; diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h index a1245c9d..4cd9ba91 100644 --- a/inc/external/graph/types.h +++ b/inc/external/graph/types.h @@ -145,8 +145,7 @@ enum Format { FORMAT_FRACTAL_ZN_LSTM, FORMAT_FRACTAL_Z_G, FORMAT_RESERVED, - FORMAT_ALL, - FORMAT_NULL + FORMAT_ALL }; // for unknown shape op type diff --git a/inc/external/register/register.h b/inc/external/register/register.h index 9834d8a8..a8421511 100644 --- a/inc/external/register/register.h +++ b/inc/external/register/register.h @@ -98,8 +98,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); - OpRegistrationData &InputReorderVector(const vector &input_order); - domi::ImplyType GetImplyType() const; std::string GetOmOptype() const; std::set GetOriginOpTypeSet() const; diff --git a/inc/framework/common/debug/ge_log.h b/inc/framework/common/debug/ge_log.h index 6ac00037..e2023cb8 100644 --- a/inc/framework/common/debug/ge_log.h +++ b/inc/framework/common/debug/ge_log.h @@ -51,6 +51,30 @@ inline pid_t GetTid() { return tid; } +#define 
GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() + +#define GE_TIMESTAMP_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#define GE_TIMESTAMP_CALLNUM_START(stage) \ + uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ + uint64_t call_num_of##stage = 0; \ + uint64_t time_of##stage = 0 + +#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) + +#define GE_TIMESTAMP_ADD(stage) \ + time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ + call_num_of##stage++ + +#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ + call_num_of##stage) + #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) diff --git a/inc/framework/common/debug/log.h b/inc/framework/common/debug/log.h index f07a8fa0..28c6585e 100644 --- a/inc/framework/common/debug/log.h +++ b/inc/framework/common/debug/log.h @@ -19,12 +19,15 @@ #include -#include "runtime/rt.h" +#include "cce/cce_def.hpp" #include "common/string_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "ge/ge_api_error_codes.h" +using cce::CC_STATUS_SUCCESS; +using cce::ccStatus_t; + #if !defined(__ANDROID__) && !defined(ANDROID) #define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) #else @@ -99,13 +102,17 @@ } while (0); // If expr is not true, print the log and return the specified status -#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) 
\ - do { \ - bool b = (expr); \ - if (!b) { \ - GELOGE(_status, __VA_ARGS__); \ - return _status; \ - } \ +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + std::string msg; \ + (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ + (void)msg.append( \ + ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ + DOMI_LOGE("%s", msg.c_str()); \ + return _status; \ + } \ } while (0); // If expr is not true, print the log and return the specified status @@ -125,7 +132,7 @@ DOMI_LOGE(__VA_ARGS__); \ exec_expr; \ } \ - } + }; // If expr is not true, print the log and execute a custom statement #define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \ @@ -135,7 +142,7 @@ GELOGW(__VA_ARGS__); \ exec_expr; \ } \ - } + }; // If expr is not true, print the log and execute a custom statement #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ { \ @@ -144,7 +151,7 @@ GELOGI(__VA_ARGS__); \ exec_expr; \ } \ - } + }; // If expr is not true, print the log and execute a custom statement #define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \ @@ -154,7 +161,7 @@ GELOGI(__VA_ARGS__); \ exec_expr; \ } \ - } + }; // If expr is true, print logs and execute custom statements #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ @@ -164,7 +171,7 @@ DOMI_LOGE(__VA_ARGS__); \ exec_expr; \ } \ - } + }; // If expr is true, print the Information log and execute a custom statement #define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \ { \ @@ -173,7 +180,7 @@ GELOGI(__VA_ARGS__); \ exec_expr; \ } \ - } + }; // If expr is not SUCCESS, print the log and execute the expression + return #define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \ @@ -184,7 +191,7 @@ exec_expr; \ return; \ } \ - } + }; // If expr is not SUCCESS, print the log and execute the expression + return _status #define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) 
\ @@ -195,7 +202,7 @@ exec_expr; \ return _status; \ } \ - } + }; // If expr is not true, execute a custom statement #define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ @@ -204,7 +211,7 @@ if (!b) { \ exec_expr; \ } \ - } + }; // -----------------runtime related macro definitions------------------------------- // If expr is not RT_ERROR_NONE, print the log @@ -224,7 +231,7 @@ DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ exec_expr; \ } \ - } + }; // If expr is not RT_ERROR_NONE, print the log and return #define GE_CHK_RT_RET(expr) \ @@ -236,13 +243,23 @@ } \ } while (0); +// ------------------------cce related macro definitions---------------------------- +// If expr is not CC_STATUS_SUCCESS, print the log +#define GE_CHK_CCE(expr) \ + do { \ + ccStatus_t _cc_ret = (expr); \ + if (_cc_ret != CC_STATUS_SUCCESS) { \ + DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ + } \ + } while (0); + // If expr is true, execute exec_expr without printing logs #define GE_IF_BOOL_EXEC(expr, exec_expr) \ { \ if (expr) { \ exec_expr; \ } \ - } + }; // If make_shared is abnormal, print the log and execute the statement #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h index 00bfa301..27ae28ee 100644 --- a/inc/framework/common/ge_types.h +++ b/inc/framework/common/ge_types.h @@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; struct DataBuffer { public: void *data; // Data address - uint64_t length; // Data length + uint32_t length; // Data length bool isDataSupportMemShare = false; - DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare) + DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare) : data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {} DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {} @@ -106,7 +106,7 @@ struct ShapeDescription { // Definition of input and output description information struct 
InputOutputDescInfo { std::string name; - uint64_t size; + uint32_t size; uint32_t data_type; ShapeDescription shape_info; }; @@ -231,7 +231,6 @@ struct Options { // Profiling info of task struct TaskDescInfo { - std::string model_name; std::string op_name; uint32_t block_dim; uint32_t task_id; @@ -240,7 +239,6 @@ struct TaskDescInfo { // Profiling info of graph struct ComputeGraphDescInfo { - std::string model_name; std::string op_name; std::string op_type; std::vector input_format; diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index 3671f970..3c9de891 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -44,6 +44,8 @@ class ModelHelper { void SetSaveMode(bool val) { is_offline_ = val; } bool GetSaveMode(void) const { return is_offline_; } + static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); + static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name); Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name); diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index 50e41755..e3844a61 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -48,9 +48,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const 
std::string OP_DEBUG_ALL; // Supported public properties name FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time @@ -338,7 +335,6 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") -REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); // ANN dedicated operator REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); @@ -635,9 +631,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG; -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG; - // convolution node type FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION; // adds a convolutional node name for the hard AIPP diff --git a/inc/framework/executor/ge_executor.h b/inc/framework/executor/ge_executor.h index 2b7335ef..91b50311 100644 --- a/inc/framework/executor/ge_executor.h +++ b/inc/framework/executor/ge_executor.h @@ -21,12 +21,12 @@ #include #include -#include "common/dynamic_aipp.h" #include "common/ge_inner_error_codes.h" #include "common/ge_types.h" #include "common/types.h" #include "graph/tensor.h" #include "runtime/base.h" +#include "common/dynamic_aipp.h" namespace ge { class ModelListenerAdapter; diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index d3f472e9..f0707c67 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -27,7 +27,6 @@ #include "graph/ge_tensor.h" #include "graph/graph.h" #include "graph/op_desc.h" -#include "graph/detail/attributes_holder.h" namespace ge { class GeGenerator { diff --git 
a/inc/framework/omg/omg.h b/inc/framework/omg/omg.h index 45a8896d..c7dbdd5b 100644 --- a/inc/framework/omg/omg.h +++ b/inc/framework/omg/omg.h @@ -106,6 +106,7 @@ void GetOutputNodesNameAndIndex(std::vector> &ou void UpdateOmgCtxWithParserCtx(); void UpdateParserCtxWithOmgCtx(); + } // namespace ge namespace domi { diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h index 1cb65a6c..4f865f12 100644 --- a/inc/graph/compute_graph.h +++ b/inc/graph/compute_graph.h @@ -74,9 +74,6 @@ class ComputeGraph : public std::enable_shared_from_this, public A size_t GetAllNodesSize() const; Vistor GetAllNodes() const; - // is_unknown_shape: false, same with GetAllNodes func - // is_unknown_shape: true, same with GetDirectNodes func - Vistor GetNodes(bool is_unknown_shape) const; size_t GetDirectNodesSize() const; Vistor GetDirectNode() const; Vistor GetInputNodes() const; @@ -177,10 +174,6 @@ class ComputeGraph : public std::enable_shared_from_this, public A void SetInputSize(uint32_t size) { input_size_ = size; } uint32_t GetInputSize() const { return input_size_; } - // false: known shape true: unknow shape - bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } - void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } - /// /// Set is need train iteration. 
/// If set true, it means this graph need to be run iteration some @@ -289,8 +282,7 @@ class ComputeGraph : public std::enable_shared_from_this, public A std::map op_name_map_; uint64_t session_id_ = 0; ge::Format data_format_ = ge::FORMAT_ND; - // unknown graph indicator, default is false, mean known shape - bool is_unknown_shape_graph_ = false; }; } // namespace ge + #endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index ff015be1..5db047c0 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -778,10 +778,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; - -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; @@ -1000,7 +996,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; -// used for lX fusion +// used for l1 fusion and other fusion in future GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; @@ -1014,17 +1010,9 @@ 
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; - -// op overflow dump -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; // functional ops attr GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; @@ -1070,13 +1058,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR // for gradient group GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; - -// dynamic shape attrs -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; - 
-// for fusion op plugin -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; } // namespace ge #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ diff --git a/inc/graph/detail/attributes_holder.h b/inc/graph/detail/attributes_holder.h index a82ecca8..bb26dec5 100644 --- a/inc/graph/detail/attributes_holder.h +++ b/inc/graph/detail/attributes_holder.h @@ -149,4 +149,5 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { AnyMap extAttrs_; }; } // namespace ge + #endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h index af6b35bc..b1ccd5b9 100644 --- a/inc/graph/ge_context.h +++ b/inc/graph/ge_context.h @@ -28,7 +28,6 @@ class GEContext { uint32_t DeviceId(); uint64_t TraceId(); void Init(); - void SetSessionId(uint64_t session_id); void SetCtxDeviceId(uint32_t device_id); private: diff --git a/inc/graph/ge_tensor.h b/inc/graph/ge_tensor.h index 834dca0b..29a315d6 100644 --- a/inc/graph/ge_tensor.h +++ b/inc/graph/ge_tensor.h @@ -25,7 +25,6 @@ #include "graph/buffer.h" #include "graph/ge_error_codes.h" #include "graph/types.h" - namespace ge { class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { public: @@ -109,11 +108,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH DataType GetDataType() const; void SetDataType(DataType dt); - DataType GetOriginDataType() const; void SetOriginDataType(DataType originDataType); - - std::vector GetRefPortIndex() const; - void SetRefPortByIndex(const std::vector &index); + DataType GetOriginDataType() const; GeTensorDesc Clone() const; GeTensorDesc &operator=(const GeTensorDesc &desc); @@ -190,4 +186,5 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { GeTensorDesc &DescReference() const; }; } // namespace ge + #endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/inc/graph/model_serialize.h b/inc/graph/model_serialize.h index 16529512..3f7d65a9 100644 --- 
a/inc/graph/model_serialize.h +++ b/inc/graph/model_serialize.h @@ -49,4 +49,5 @@ class ModelSerialize { friend class GraphDebugImp; }; } // namespace ge + #endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h index 1bba7340..faca2d99 100644 --- a/inc/graph/op_desc.h +++ b/inc/graph/op_desc.h @@ -105,8 +105,6 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { GeTensorDescPtr MutableInputDesc(uint32_t index) const; - GeTensorDescPtr MutableInputDesc(const string &name) const; - Vistor GetAllInputsDesc() const; Vistor GetAllInputsDescPtr() const; @@ -129,8 +127,6 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { GeTensorDescPtr MutableOutputDesc(uint32_t index) const; - GeTensorDescPtr MutableOutputDesc(const string &name) const; - uint32_t GetAllOutputsDescSize() const; Vistor GetAllOutputsDesc() const; diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index 8d20caf2..894eaf1e 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -60,7 +60,6 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" - "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" @@ -95,6 +94,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" + "graph/load/output/output.cc" "graph/manager/*.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/debug.cc" @@ -159,11 +159,8 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "hybrid/node_executor/aicpu/aicpu_ext_info.cc" "hybrid/node_executor/aicpu/aicpu_node_executor.cc" 
"hybrid/node_executor/compiledsubgraph/known_node_executor.cc" - "hybrid/node_executor/controlop/control_op_executor.cc" - "hybrid/node_executor/hccl/hccl_node_executor.cc" "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" "hybrid/node_executor/node_executor.cc" - "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" "hybrid/node_executor/task_context.cc" "init/gelib.cc" "model/ge_model.cc" @@ -207,7 +204,6 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" - "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" @@ -240,6 +236,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" + "graph/load/output/output.cc" "graph/manager/*.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/debug.cc" diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index 120c144a..ae6a9892 100644 --- a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -28,7 +28,6 @@ #include "graph/opsproto_manager.h" #include "graph/utils/type_utils.h" #include "graph/manager/util/rt_context_util.h" -#include "graph/common/ge_call_wrapper.h" #include "register/op_registry.h" #include "common/ge/tbe_plugin_manager.h" @@ -42,8 +41,8 @@ namespace { const int32_t kMaxStrLen = 128; } -static bool g_ge_initialized = false; -static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use +static bool kGeInitialized = false; +static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use namespace ge { void GetOpsProtoPath(std::string &opsproto_path) { @@ -62,6 +61,31 @@ void GetOpsProtoPath(std::string &opsproto_path) { opsproto_path = 
(path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } +Status CheckDumpAndReuseMemory(const std::map &options) { + const int kDecimal = 10; + auto dump_op_env = std::getenv("DUMP_OP"); + int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; + auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); + if (disableReuseMemoryIter != options.end()) { + if (disableReuseMemoryIter->second == "0") { + GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); + if (dump_op_flag) { + GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); + } + } else if (disableReuseMemoryIter->second == "1") { + GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); + } else { + GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); + return FAILED; + } + } else { + if (dump_op_flag) { + GELOGW("Will dump incorrect op data with default reuse memory"); + } + } + return SUCCESS; +} + Status CheckOptionsValid(const std::map &options) { // check job_id is valid auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); @@ -72,6 +96,11 @@ Status CheckOptionsValid(const std::map &options) { } } + // Check ge.exec.disableReuseMemory and env DUMP_OP + if (CheckDumpAndReuseMemory(options) != SUCCESS) { + return FAILED; + } + return SUCCESS; } @@ -79,7 +108,7 @@ Status CheckOptionsValid(const std::map &options) { Status GEInitialize(const std::map &options) { GELOGT(TRACE_INIT, "GEInitialize start"); // 0.check init status - if (g_ge_initialized) { + if (kGeInitialized) { GELOGW("GEInitialize is called more than once"); return SUCCESS; } @@ -118,9 +147,9 @@ Status GEInitialize(const std::map &options) { } // 7.check return status, return - if (!g_ge_initialized) { + if (!kGeInitialized) { // Initialize success, first time calling initialize - g_ge_initialized = true; + kGeInitialized = true; } GELOGT(TRACE_STOP, "GEInitialize finished"); @@ 
-131,12 +160,12 @@ Status GEInitialize(const std::map &options) { Status GEFinalize() { GELOGT(TRACE_INIT, "GEFinalize start"); // check init status - if (!g_ge_initialized) { + if (!kGeInitialized) { GELOGW("GEFinalize is called before GEInitialize"); return SUCCESS; } - std::lock_guard lock(g_ge_release_mutex); + std::lock_guard lock(kGeReleaseMutex); // call Finalize Status ret = SUCCESS; Status middle_ret; @@ -158,10 +187,10 @@ Status GEFinalize() { ret = middle_ret; } - if (g_ge_initialized && ret == SUCCESS) { + if (kGeInitialized && ret == SUCCESS) { // Unified destruct rt_context - RtContextUtil::GetInstance().DestroyAllRtContexts(); - g_ge_initialized = false; + RtContextUtil::GetInstance().DestroyrtContexts(); + kGeInitialized = false; } GELOGT(TRACE_STOP, "GEFinalize finished"); @@ -173,7 +202,7 @@ Session::Session(const std::map &options) { GELOGT(TRACE_INIT, "Session Constructor start"); // check init status sessionId_ = 0; - if (!g_ge_initialized) { + if (!kGeInitialized) { GELOGE(GE_CLI_GE_NOT_INITIALIZED); return; } @@ -203,13 +232,13 @@ Session::Session(const std::map &options) { Session::~Session() { GELOGT(TRACE_INIT, "Session Destructor start"); // 0.check init status - if (!g_ge_initialized) { + if (!kGeInitialized) { GELOGW("GE is not yet initialized or is finalized."); return; } Status ret = FAILED; - std::lock_guard lock(g_ge_release_mutex); + std::lock_guard lock(kGeReleaseMutex); try { uint64_t session_id = sessionId_; // call DestroySession diff --git a/src/ge/engine_manager/dnnengine_manager.cc b/src/ge/engine_manager/dnnengine_manager.cc index 9afb207f..c8843c09 100644 --- a/src/ge/engine_manager/dnnengine_manager.cc +++ b/src/ge/engine_manager/dnnengine_manager.cc @@ -24,7 +24,6 @@ #include "common/debug/log.h" #include "common/ge/ge_util.h" -#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "init/gelib.h" @@ -162,10 +161,6 @@ bool 
DNNEngineManager::IsEngineRegistered(const std::string &name) { return false; } -void DNNEngineManager::InitPerformanceStaistic() { checksupport_cost_.clear(); } - -const map &DNNEngineManager::GetCheckSupportCost() const { return checksupport_cost_; } - std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); return ""); @@ -199,20 +194,15 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { if (kernel_info_store != kernel_map.end()) { std::string unsupported_reason; // It will be replaced by engine' checksupport - uint64_t start_time = GetCurrentTimestap(); if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { - checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), it.engine.c_str(), op_desc->GetName().c_str()); return it.engine; } else { - checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; bool is_custom_op = false; if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) { - ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"}, - {kernel_name, op_desc->GetType(), op_desc->GetName()}); GELOGE(FAILED, "The custom operator registered by the user does not support the logic function delivered by this " "network. 
Check support failed, kernel_name is %s, op type is %s, op name is %s", @@ -231,13 +221,9 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { } } for (const auto &it : unsupported_reasons) { - ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"}, - {op_desc->GetType(), it.first, it.second}); GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s", op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); } - ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"}, - {op_desc->GetName(), op_desc->GetType()}); GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s", op_desc->GetName().c_str(), op_desc->GetType().c_str()); return ""; @@ -398,13 +384,7 @@ Status DNNEngineManager::ReadJsonFile(const std::string &file_path, JsonHandle h return FAILED; } - try { - ifs >> *json_file; - } catch (const json::exception &e) { - GELOGE(FAILED, "Read json file failed"); - ifs.close(); - return FAILED; - } + ifs >> *json_file; ifs.close(); GELOGI("Read json file success"); return SUCCESS; diff --git a/src/ge/engine_manager/dnnengine_manager.h b/src/ge/engine_manager/dnnengine_manager.h index 15628ecf..ab813398 100644 --- a/src/ge/engine_manager/dnnengine_manager.h +++ b/src/ge/engine_manager/dnnengine_manager.h @@ -63,8 +63,6 @@ class DNNEngineManager { // If can't find appropriate engine name, return "", report error string GetDNNEngineName(const OpDescPtr &op_desc); const map &GetSchedulers() const; - const map &GetCheckSupportCost() const; - void InitPerformanceStaistic(); private: DNNEngineManager(); @@ -80,7 +78,6 @@ class DNNEngineManager { std::map engines_map_; std::map engines_attrs_map_; std::map schedulers_; - std::map checksupport_cost_; bool init_flag_; }; } // namespace ge diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index 
0cdb00e2..cddf25b7 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -26,7 +26,6 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "ge_executor.cc" - "../common/ge/op_tiling_manager.cc" "../common/ge/plugin_manager.cc" "../common/profiling/profiling_manager.cc" "../graph/execute/graph_execute.cc" @@ -60,6 +59,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/task_info/task_info.cc" "../graph/load/new_model_manager/tbe_handle_store.cc" "../graph/load/new_model_manager/zero_copy_task.cc" + "../graph/load/output/output.cc" "../graph/manager/graph_caching_allocator.cc" "../graph/manager/graph_manager_utils.cc" "../graph/manager/graph_mem_allocator.cc" diff --git a/src/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc index 098c57b6..b5a3b3cf 100644 --- a/src/ge/executor/ge_executor.cc +++ b/src/ge/executor/ge_executor.cc @@ -854,4 +854,5 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, GELOGI("GetAllAippInputOutputDims succ."); return SUCCESS; } + } // namespace ge diff --git a/src/ge/executor/module.mk b/src/ge/executor/module.mk index 0eb87822..efed8854 100644 --- a/src/ge/executor/module.mk +++ b/src/ge/executor/module.mk @@ -4,7 +4,6 @@ local_ge_executor_src_files := \ ge_executor.cc \ ../common/profiling/profiling_manager.cc \ ../common/ge/plugin_manager.cc \ - ../common/ge/op_tiling_manager.cc \ ../graph/load/graph_loader.cc \ ../graph/execute/graph_execute.cc \ ../omm/csa_interact.cc \ @@ -45,6 +44,7 @@ local_ge_executor_src_files := \ ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ + ../graph/load/output/output.cc \ ../single_op/single_op_manager.cc \ ../single_op/single_op_model.cc \ ../single_op/single_op.cc \ 
@@ -53,7 +53,6 @@ local_ge_executor_src_files := \ ../single_op/task/build_task_utils.cc \ ../single_op/task/tbe_task_builder.cc \ ../single_op/task/aicpu_task_builder.cc \ - ../single_op/task/aicpu_kernel_task_builder.cc \ ../hybrid/hybrid_davinci_model_stub.cc\ local_ge_executor_c_include := \ diff --git a/src/ge/ge_inference.mk b/src/ge/ge_inference.mk index f18f733a..2b26b214 100644 --- a/src/ge/ge_inference.mk +++ b/src/ge/ge_inference.mk @@ -32,7 +32,6 @@ COMMON_LOCAL_SRC_FILES := \ GRAPH_MANAGER_LOCAL_SRC_FILES := \ common/ge/plugin_manager.cc\ - common/ge/op_tiling_manager.cc\ init/gelib.cc \ session/inner_session.cc \ session/session_manager.cc \ @@ -92,7 +91,6 @@ OMG_HOST_SRC_FILES := \ graph/passes/no_use_reshape_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ graph/passes/atomic_addr_clean_pass.cc \ - graph/passes/mark_same_addr_pass.cc \ graph/common/omg_util.cc \ graph/common/bcast.cc \ graph/passes/dimension_compute_pass.cc \ @@ -147,7 +145,6 @@ OMG_HOST_SRC_FILES := \ graph/passes/stop_gradient_pass.cc \ graph/passes/prevent_gradient_pass.cc \ graph/passes/identity_pass.cc \ - graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/placeholder_with_default_pass.cc \ graph/passes/snapshot_pass.cc \ graph/passes/guarantee_const_pass.cc \ @@ -156,9 +153,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/folding_pass.cc \ graph/passes/cast_translate_pass.cc \ graph/passes/prune_pass.cc \ - graph/passes/merge_to_stream_merge_pass.cc \ - graph/passes/switch_to_stream_switch_pass.cc \ - graph/passes/attach_stream_label_pass.cc \ + graph/passes/switch_op_pass.cc \ graph/passes/multi_batch_pass.cc \ graph/passes/next_iteration_pass.cc \ graph/passes/control_trigger_pass.cc \ @@ -178,6 +173,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/variable_op_pass.cc \ graph/passes/cast_remove_pass.cc \ graph/passes/transpose_transdata_pass.cc \ + graph/passes/identify_reference_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ graph/passes/flow_ctrl_pass.cc \ 
graph/passes/link_gen_mask_nodes_pass.cc \ @@ -203,6 +199,7 @@ OME_HOST_SRC_FILES := \ graph/load/new_model_manager/tbe_handle_store.cc \ graph/load/new_model_manager/cpu_queue_schedule.cc \ graph/load/new_model_manager/zero_copy_task.cc \ + graph/load/output/output.cc \ graph/load/new_model_manager/data_dumper.cc \ graph/load/new_model_manager/task_info/task_info.cc \ graph/load/new_model_manager/task_info/event_record_task_info.cc \ @@ -227,7 +224,6 @@ OME_HOST_SRC_FILES := \ single_op/task/build_task_utils.cc \ single_op/task/tbe_task_builder.cc \ single_op/task/aicpu_task_builder.cc \ - single_op/task/aicpu_kernel_task_builder.cc \ single_op/single_op.cc \ single_op/single_op_model.cc \ single_op/stream_resource.cc \ @@ -372,7 +368,7 @@ endif LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_ir_build.cc +LOCAL_SRC_FILES := ../../out/atc/lib64/stub/ge_ir_build.cc LOCAL_SHARED_LIBRARIES := diff --git a/src/ge/ge_runner.mk b/src/ge/ge_runner.mk index fe19de02..a9cfdd82 100644 --- a/src/ge/ge_runner.mk +++ b/src/ge/ge_runner.mk @@ -23,7 +23,6 @@ LIBGE_LOCAL_SRC_FILES := \ common/formats/utils/formats_trans_utils.cc \ common/fp16_t.cc \ common/ge/plugin_manager.cc\ - common/ge/op_tiling_manager.cc\ common/helper/model_cache_helper.cc \ common/profiling/profiling_manager.cc \ engine_manager/dnnengine_manager.cc \ @@ -78,6 +77,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/load/new_model_manager/task_info/task_info.cc \ graph/load/new_model_manager/tbe_handle_store.cc \ graph/load/new_model_manager/zero_copy_task.cc \ + graph/load/output/output.cc \ graph/manager/graph_context.cc \ graph/manager/graph_manager.cc \ graph/manager/graph_manager_utils.cc \ @@ -99,7 +99,6 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/aicpu_constant_folding_pass.cc \ graph/passes/assert_pass.cc \ graph/passes/atomic_addr_clean_pass.cc \ - graph/passes/mark_same_addr_pass.cc \ graph/partition/dynamic_shape_partition.cc \ graph/passes/base_pass.cc \ 
graph/passes/cast_remove_pass.cc \ @@ -159,8 +158,8 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/get_original_format_pass.cc \ graph/passes/guarantee_const_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ + graph/passes/identify_reference_pass.cc \ graph/passes/identity_pass.cc \ - graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/infershape_pass.cc \ graph/passes/isolated_op_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ @@ -192,9 +191,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/data_pass.cc \ graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_logic_remove_pass.cc \ - graph/passes/merge_to_stream_merge_pass.cc \ - graph/passes/switch_to_stream_switch_pass.cc \ - graph/passes/attach_stream_label_pass.cc \ + graph/passes/switch_op_pass.cc \ graph/passes/switch_dead_branch_elimination.cc \ graph/passes/replace_transshape_pass.cc \ graph/passes/transop_breadth_fusion_pass.cc \ @@ -233,7 +230,6 @@ LIBGE_LOCAL_SRC_FILES := \ single_op/task/op_task.cc \ single_op/task/tbe_task_builder.cc \ single_op/task/aicpu_task_builder.cc \ - single_op/task/aicpu_kernel_task_builder.cc \ hybrid/common/tensor_value.cc \ hybrid/common/npu_memory_allocator.cc \ hybrid/executor/rt_callback_manager.cc \ @@ -243,15 +239,12 @@ LIBGE_LOCAL_SRC_FILES := \ hybrid/executor/hybrid_model_executor.cc \ hybrid/executor/hybrid_model_async_executor.cc \ hybrid/executor/hybrid_execution_context.cc \ - hybrid/executor/subgraph_context.cc \ - hybrid/executor/subgraph_executor.cc \ hybrid/executor/worker/task_compile_engine.cc \ hybrid/executor/worker/shape_inference_engine.cc \ hybrid/executor/worker/execution_engine.cc \ hybrid/model/hybrid_model.cc \ hybrid/model/hybrid_model_builder.cc \ hybrid/model/node_item.cc \ - hybrid/model/graph_item.cc \ hybrid/node_executor/aicore/aicore_node_executor.cc \ hybrid/node_executor/aicore/aicore_op_task.cc \ hybrid/node_executor/aicore/aicore_task_builder.cc \ @@ -260,9 +253,6 @@ LIBGE_LOCAL_SRC_FILES := \ 
hybrid/node_executor/aicpu/aicpu_node_executor.cc \ hybrid/node_executor/compiledsubgraph/known_node_executor.cc \ hybrid/node_executor/hostcpu/ge_local_node_executor.cc \ - hybrid/node_executor/controlop/control_op_executor.cc \ - hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \ - hybrid/node_executor/hccl/hccl_node_executor.cc \ hybrid/node_executor/node_executor.cc \ hybrid/node_executor/task_context.cc \ hybrid/hybrid_davinci_model.cc \ @@ -348,28 +338,6 @@ LOCAL_SHARED_LIBRARIES += \ include $(BUILD_HOST_SHARED_LIBRARY) -#compiler for GeRunner -include $(CLEAR_VARS) - -LOCAL_MODULE := stub/libge_runner - -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 -LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD -ifeq ($(DEBUG), 1) -LOCAL_CFLAGS += -g -O0 -endif - - -LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) - -LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc - - -LOCAL_SHARED_LIBRARIES := - -LOCAL_LDFLAGS := -lrt -ldl - -include $(BUILD_HOST_SHARED_LIBRARY) # add engine_conf.json to host include $(CLEAR_VARS) @@ -439,7 +407,6 @@ LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD LOCAL_CFLAGS += -g -O0 LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) - LOCAL_SRC_FILES := $(LIBGE_LOCAL_SRC_FILES) LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES) diff --git a/src/ge/ge_train.mk b/src/ge/ge_train.mk new file mode 100644 index 00000000..767ce86b --- /dev/null +++ b/src/ge/ge_train.mk @@ -0,0 +1,333 @@ +LOCAL_PATH := $(call my-dir) + +COMMON_LOCAL_SRC_FILES := \ + proto/fusion_model.proto \ + proto/optimizer_priority.proto \ + session/inner_session.cc \ + session/session_manager.cc \ + common/ge/plugin_manager.cc\ + common/fp16_t.cc \ + common/formats/utils/formats_trans_utils.cc \ + common/formats/format_transfers/datatype_transfer.cc \ + common/formats/format_transfers/format_transfer_transpose.cc \ + common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc 
\ + common/formats/format_transfers/format_transfer_fractal_z.cc \ + common/formats/format_transfers/format_transfer_fractal_nz.cc \ + common/formats/format_transfers/format_transfer_fractal_zz.cc \ + common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \ + common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc \ + common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc \ + common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc \ + common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc \ + common/formats/format_transfers/format_transfer_fracz_nchw.cc \ + common/formats/format_transfers/format_transfer_fracz_nhwc.cc \ + common/formats/format_transfers/format_transfer_fracz_hwcn.cc \ + common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ + common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ + common/formats/formats.cc \ + init/gelib.cc \ + engine_manager/dnnengine_manager.cc \ + opskernel_manager/ops_kernel_manager.cc \ + graph/manager/graph_manager.cc \ + graph/manager/graph_manager_utils.cc \ + graph/manager/graph_context.cc \ + graph/preprocess/graph_preprocess.cc \ + graph/preprocess/multi_batch_copy_graph.cc \ + graph/execute/graph_execute.cc \ + graph/load/graph_loader.cc \ + graph/optimize/graph_optimize.cc \ + graph/passes/folding_pass.cc \ + graph/optimize/summary_optimize.cc \ + graph/build/graph_builder.cc \ + graph/partition/engine_place.cc \ + graph/partition/graph_partition.cc \ + graph/partition/dynamic_shape_partition.cc \ + generator/ge_generator.cc \ + generator/generator_api.cc \ + common/profiling/profiling_manager.cc \ + ge_local_engine/engine/host_cpu_engine.cc \ + common/helper/model_cache_helper.cc \ + +OMG_HOST_SRC_FILES := \ + model/ge_model.cc \ + model/ge_root_model.cc \ + graph/common/transop_util.cc \ + graph/manager/graph_var_manager.cc \ + graph/manager/trans_var_data_utils.cc \ + omm/csa_interact.cc \ + graph/passes/pass_manager.cc \ + 
graph/passes/pass_utils.cc \ + graph/passes/base_pass.cc \ + graph/passes/resource_pair_add_control_pass.cc \ + graph/passes/resource_pair_remove_control_pass.cc \ + graph/passes/constant_folding_pass.cc \ + graph/passes/aicpu_constant_folding_pass.cc \ + graph/passes/reshape_remove_pass.cc \ + graph/passes/reshape_recovery_pass.cc \ + graph/passes/transop_breadth_fusion_pass.cc \ + graph/passes/transop_depth_fusion_pass.cc \ + graph/passes/same_transdata_breadth_fusion_pass.cc \ + graph/passes/transop_without_reshape_fusion_pass.cc \ + graph/passes/compile_nodes_pass.cc \ + graph/passes/transop_nearby_allreduce_fusion_pass.cc \ + graph/passes/variable_prepare_op_pass.cc \ + graph/passes/variable_ref_delete_op_pass.cc \ + graph/passes/variable_ref_useless_control_out_delete_pass.cc \ + graph/passes/variable_op_pass.cc \ + graph/passes/cast_remove_pass.cc \ + graph/passes/replace_transshape_pass.cc \ + graph/passes/transpose_transdata_pass.cc \ + graph/passes/identify_reference_pass.cc \ + graph/passes/variable_format_pass.cc \ + graph/passes/subgraph_pass.cc \ + graph/passes/data_pass.cc \ + graph/passes/net_output_pass.cc \ + graph/passes/constant_fuse_same_pass.cc \ + graph/passes/print_op_pass.cc \ + graph/passes/no_use_reshape_remove_pass.cc \ + graph/passes/iterator_op_pass.cc \ + graph/passes/atomic_addr_clean_pass.cc \ + graph/optimize/optimizer/allreduce_fusion_pass.cc \ + graph/common/omg_util.cc \ + graph/common/bcast.cc \ + graph/passes/dimension_compute_pass.cc \ + graph/passes/dimension_adjust_pass.cc \ + graph/passes/get_original_format_pass.cc \ + graph/passes/shape_operate_op_remove_pass.cc \ + graph/passes/unused_op_remove_pass.cc \ + graph/passes/assert_pass.cc \ + graph/passes/dropout_pass.cc \ + graph/passes/infershape_pass.cc \ + graph/passes/unused_const_pass.cc \ + graph/passes/isolated_op_remove_pass.cc \ + graph/passes/permute_pass.cc \ + graph/passes/ctrl_edge_transfer_pass.cc \ + host_kernels/broadcast_gradient_args_kernel.cc \ + 
host_kernels/greater_kernel.cc \ + host_kernels/gather_v2_kernel.cc \ + host_kernels/maximum_kernel.cc \ + host_kernels/floormod_kernel.cc \ + host_kernels/floordiv_kernel.cc \ + host_kernels/range_kernel.cc \ + host_kernels/shape_kernel.cc \ + host_kernels/size_kernel.cc \ + host_kernels/shape_n_kernel.cc \ + host_kernels/rank_kernel.cc \ + host_kernels/broadcast_args_kernel.cc \ + host_kernels/fill_kernel.cc \ + host_kernels/empty_kernel.cc \ + host_kernels/expanddims_kernel.cc \ + host_kernels/reshape_kernel.cc \ + host_kernels/squeeze_kernel.cc \ + host_kernels/kernel_utils.cc \ + host_kernels/cast_kernel.cc \ + host_kernels/transdata_kernel.cc \ + host_kernels/transpose_kernel.cc \ + host_kernels/permute_kernel.cc \ + host_kernels/pack_kernel.cc \ + host_kernels/concat_v2_kernel.cc \ + host_kernels/concat_offset_kernel.cc \ + host_kernels/strided_slice_kernel.cc \ + host_kernels/ssd_prior_box_kernel.cc \ + host_kernels/add_kernel.cc \ + host_kernels/unpack_kernel.cc \ + host_kernels/sub_kernel.cc \ + host_kernels/mul_kernel.cc \ + host_kernels/reduce_prod_kernel.cc \ + host_kernels/rsqrt_kernel.cc \ + host_kernels/slice_kernel.cc \ + host_kernels/slice_d_kernel.cc \ + host_kernels/dynamic_stitch_kernel.cc \ + graph/passes/stop_gradient_pass.cc \ + graph/passes/prevent_gradient_pass.cc \ + graph/passes/identity_pass.cc \ + graph/passes/placeholder_with_default_pass.cc \ + graph/passes/snapshot_pass.cc \ + graph/passes/guarantee_const_pass.cc \ + graph/passes/var_is_initialized_op_pass.cc \ + graph/passes/parallel_concat_start_op_pass.cc \ + graph/passes/cast_translate_pass.cc \ + graph/passes/addn_pass.cc \ + graph/passes/common_subexpression_elimination_pass.cc \ + graph/passes/transop_symmetry_elimination_pass.cc \ + graph/passes/save_pass.cc \ + graph/passes/switch_dead_branch_elimination.cc \ + graph/passes/merge_pass.cc \ + graph/passes/prune_pass.cc \ + graph/passes/flow_ctrl_pass.cc \ + graph/passes/control_trigger_pass.cc \ + 
graph/passes/switch_data_edges_bypass.cc \ + graph/passes/switch_op_pass.cc \ + graph/passes/multi_batch_pass.cc \ + graph/passes/switch_logic_remove_pass.cc \ + graph/passes/next_iteration_pass.cc \ + graph/passes/cond_pass.cc \ + graph/passes/cond_remove_pass.cc \ + graph/passes/for_pass.cc \ + graph/passes/enter_pass.cc \ + graph/passes/hccl_memcpy_pass.cc \ + graph/passes/link_gen_mask_nodes_pass.cc \ + graph/passes/replace_with_empty_const_pass.cc \ + graph/passes/hccl_group_pass.cc \ + +OME_SRC_FILES := \ + graph/manager/graph_mem_allocator.cc \ + graph/manager/graph_caching_allocator.cc \ + graph/manager/model_manager/event_manager.cc \ + graph/manager/util/debug.cc \ + graph/manager/util/rt_context_util.cc \ + graph/manager/util/variable_accelerate_ctrl.cc \ + graph/manager/util/hcom_util.cc \ + graph/load/new_model_manager/model_manager.cc \ + graph/load/new_model_manager/data_inputer.cc \ + graph/load/new_model_manager/davinci_model.cc \ + graph/load/new_model_manager/davinci_model_parser.cc \ + graph/load/new_model_manager/model_utils.cc \ + graph/load/new_model_manager/tbe_handle_store.cc \ + graph/load/new_model_manager/cpu_queue_schedule.cc \ + graph/load/new_model_manager/zero_copy_task.cc \ + graph/load/output/output.cc \ + graph/load/new_model_manager/data_dumper.cc \ + graph/load/new_model_manager/task_info/task_info.cc \ + graph/load/new_model_manager/task_info/event_record_task_info.cc \ + graph/load/new_model_manager/task_info/event_wait_task_info.cc \ + graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ + graph/load/new_model_manager/task_info/fusion_stop_task_info.cc \ + graph/load/new_model_manager/task_info/hccl_task_info.cc \ + graph/load/new_model_manager/task_info/kernel_ex_task_info.cc \ + graph/load/new_model_manager/task_info/kernel_task_info.cc \ + graph/load/new_model_manager/task_info/label_set_task_info.cc \ + graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc \ + 
graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc \ + graph/load/new_model_manager/task_info/memcpy_async_task_info.cc \ + graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc \ + graph/load/new_model_manager/task_info/profiler_trace_task_info.cc \ + graph/load/new_model_manager/task_info/stream_active_task_info.cc \ + graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ + graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ + graph/load/new_model_manager/task_info/end_graph_task_info.cc \ + graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ + graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ + single_op/task/op_task.cc \ + single_op/task/build_task_utils.cc \ + single_op/task/tbe_task_builder.cc \ + single_op/task/aicpu_task_builder.cc \ + single_op/single_op.cc \ + single_op/single_op_model.cc \ + single_op/stream_resource.cc \ + single_op/single_op_manager.cc \ + hybrid/hybrid_davinci_model_stub.cc \ + + +COMMON_LOCAL_C_INCLUDES := \ + proto/om.proto \ + proto/task.proto \ + proto/insert_op.proto \ + proto/ge_ir.proto \ + proto/fwk_adapter.proto \ + proto/op_mapping_info.proto \ + proto/tensorflow/attr_value.proto \ + proto/tensorflow/function.proto \ + proto/tensorflow/graph.proto \ + proto/tensorflow/node_def.proto \ + proto/tensorflow/op_def.proto \ + proto/tensorflow/resource_handle.proto \ + proto/tensorflow/tensor.proto \ + proto/tensorflow/tensor_shape.proto \ + proto/tensorflow/types.proto \ + proto/tensorflow/versions.proto \ + $(LOCAL_PATH) ./ \ + $(TOPDIR)inc \ + $(TOPDIR)inc/external \ + $(TOPDIR)inc/external/graph \ + $(TOPDIR)inc/framework \ + $(TOPDIR)inc/framework/common \ + $(TOPDIR)inc/runtime \ + $(TOPDIR)libc_sec/include \ + $(TOPDIR)ops/built-in/op_proto/inc \ + third_party/json/include \ + third_party/protobuf/include \ + third_party/opencv/include \ + +NEW_OMG_HOST_SRC_FILES := \ + graph/preprocess/insert_op/util_insert_aipp_op.cc 
\ + graph/preprocess/insert_op/ge_aipp_op.cc \ + graph/build/model_builder.cc \ + graph/build/task_generator.cc \ + graph/build/stream_allocator.cc \ + graph/build/logical_stream_allocator.cc \ + graph/build/stream_graph_optimizer.cc \ + graph/build/run_context.cc \ + graph/build/label_allocator.cc \ + graph/label/label_maker.cc \ + graph/label/if_label_maker.cc \ + graph/label/case_label_maker.cc \ + graph/label/while_label_maker.cc \ + graph/label/partitioned_call_label_maker.cc \ + + + +#compiler for host train +include $(CLEAR_VARS) + +LOCAL_MODULE := libge_train + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 +LOCAL_CFLAGS += -DDAVINCI_CLOUD -DDAVINCI_TRAIN -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING +LOCAL_CFLAGS += -DFMK_SUPPORT_DEBUG +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) + +LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) +LOCAL_SRC_FILES += $(OMG_HOST_SRC_FILES) +LOCAL_SRC_FILES += $(OME_SRC_FILES) +LOCAL_SRC_FILES += $(NEW_OMG_HOST_SRC_FILES) + +LOCAL_STATIC_LIBRARIES := libge_memory \ + +LOCAL_SHARED_LIBRARIES := \ + libc_sec \ + libprotobuf \ + libslog \ + libmmpa \ + libgraph \ + libregister \ + libge_common \ + libhccl \ + libmsprof \ + + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_SHARED_LIBRARIES += \ + libruntime \ + libresource \ + +include $(BUILD_HOST_SHARED_LIBRARY) + +# add engine_conf.json to host +include $(CLEAR_VARS) + +LOCAL_MODULE := engine_conf.json + +LOCAL_SRC_FILES := engine_manager/engine_conf.json + +LOCAL_MODULE_CLASS := ETC + +LOCAL_INSTALLED_PATH := $(HOST_OUT_ROOT)/engine_conf.json +include $(BUILD_HOST_PREBUILT) + +# add optimizer_priority.pbtxt to host +include $(CLEAR_VARS) + +LOCAL_MODULE := optimizer_priority.pbtxt + +LOCAL_SRC_FILES := opskernel_manager/optimizer_priority.pbtxt + +LOCAL_MODULE_CLASS := ETC + +LOCAL_INSTALLED_PATH := $(HOST_OUT_ROOT)/optimizer_priority.pbtxt +include $(BUILD_HOST_PREBUILT) diff --git 
a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index 4869eb40..b01f7591 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -207,13 +207,6 @@ class GeGenerator::Impl { GraphManager graph_manager_; SaveParam save_param_; bool is_offline_ = true; - - private: - static std::string Trim(const std::string &str); - bool ParseVersion(const std::string &line, std::string &version); - bool GetVersionFromPath(const std::string &file_path, std::string &version); - bool SetAtcVersionInfo(AttrHolder &obj); - bool SetOppVersionInfo(AttrHolder &obj); }; Status GeGenerator::Initialize(const map &options) { @@ -295,124 +288,6 @@ Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) { return SUCCESS; } -// Remove the space and tab before and after the string -std::string GeGenerator::Impl::Trim(const std::string &str) { - if (str.empty()) { - return str; - } - - std::string::size_type start = str.find_first_not_of(" \t\r\n"); - if (start == std::string::npos) { - return str; - } - - std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1; - return str.substr(start, end); -} - -// Parsing the command line -bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) { - std::string flag = "Version="; - std::string temp = Trim(line); - - if (temp.empty()) { - GELOGW("line is empty."); - return false; - } - - std::string::size_type pos = temp.find(flag); - if (pos == std::string::npos) { - GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str()); - return false; - } - - if (temp.size() == flag.size()) { - GELOGW("version information is empty. 
%s", line.c_str()); - return false; - } - - version = temp.substr(pos + flag.size()); - GELOGI("Version=%s", version.c_str()); - - return true; -} - -bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) { - // Normalize the path - string resolved_file_path = RealPath(file_path.c_str()); - if (resolved_file_path.empty()) { - GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); - return false; - } - std::ifstream fs(resolved_file_path, std::ifstream::in); - if (!fs.is_open()) { - GELOGW("Open %s failed.", file_path.c_str()); - return false; - } - - std::string line; - if (getline(fs, line)) { - if (!ParseVersion(line, version)) { - GELOGW("Parse version failed. content is [%s].", line.c_str()); - fs.close(); - return false; - } - } else { - GELOGW("No version information found in the file path:%s", file_path.c_str()); - fs.close(); - return false; - } - - fs.close(); // close the file - return true; -} - -// Set package version information in the model -bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) { - std::string path_base = ge::GELib::GetPath(); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - - std::string version_path = path_base + "version.info"; - GELOGI("version_path is %s", version_path.c_str()); - std::string version; - if (!GetVersionFromPath(version_path, version)) { - GELOGW("Get atc version information failed!"); - return false; - } - // set version info - if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) { - GELOGW("Ge model set atc version failed!"); - return false; - } - GELOGI("Ge model set atc version information success."); - return true; -} - -// Set package version information in the model -bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { - const char *path_env = std::getenv("ASCEND_OPP_PATH"); - if (path_env == nullptr) { - GELOGW("Get environment 
variable ASCEND_OPP_PATH failed!"); - return false; - } - std::string version_path = path_env; - version_path += "/version.info"; - GELOGI("version_path is %s", version_path.c_str()); - std::string version; - if (!GetVersionFromPath(version_path, version)) { - GELOGW("Get opp version information failed!"); - return false; - } - // set version info - if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) { - GELOGW("Ge model set opp version failed!"); - return false; - } - GELOGI("Ge Model set opp version information success."); - return true; -} - Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { rtContext_t ctx = nullptr; @@ -440,7 +315,6 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr string model_name = ""; Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name); if (name_ret != SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); GELOGE(FAILED, "Get model_name failed. 
Param --output is invalid"); return PARAM_INVALID; } @@ -590,14 +464,6 @@ Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, c } Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) { - // set atc version - if (!SetAtcVersionInfo(*(model.get()))) { - GELOGW("SetPackageVersionInfo of atc failed!"); - } - // set opp version - if (!SetOppVersionInfo(*(model.get()))) { - GELOGW("SetPackageVersionInfo of ops failed!"); - } ModelHelper model_helper; model_helper.SetSaveMode(is_offline_); Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); @@ -660,4 +526,5 @@ Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &g return SUCCESS; } + } // namespace ge diff --git a/src/ge/graph/common/ge_call_wrapper.h b/src/ge/graph/common/ge_call_wrapper.h index a2bb6b88..a21d642e 100644 --- a/src/ge/graph/common/ge_call_wrapper.h +++ b/src/ge/graph/common/ge_call_wrapper.h @@ -18,41 +18,6 @@ #define GE_GE_CALL_WRAPPER_H_ #include "framework/common/debug/ge_log.h" -#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() - -#define GE_TIMESTAMP_END(stage, stage_name) \ - do { \ - uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ - GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ - (endUsec_##stage - startUsec_##stage)); \ - } while (0); - -#define GE_TIMESTAMP_EVENT_END(stage, stage_name) \ - do { \ - uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ - GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ - (endUsec_##stage - startUsec_##stage)); \ - } while (0); - -#define GE_TIMESTAMP_CALLNUM_START(stage) \ - uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ - uint64_t call_num_of##stage = 0; \ - uint64_t time_of##stage = 0 - -#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) - -#define 
GE_TIMESTAMP_ADD(stage) \ - time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ - call_num_of##stage++ - -#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ - GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ - call_num_of##stage) - -#define GE_TIMESTAMP_CALLNUM_EVENT_END(stage, stage_name) \ - GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ - call_num_of##stage) - #define RUN_WITH_TIMESTAMP_NAME(var_name, prefix, func, ...) \ do { \ GE_TIMESTAMP_START(var_name); \ @@ -64,23 +29,10 @@ } \ } while (0) -#define RUN_WITH_PERF_TIMESTAMP_NAME(var_name, prefix, func, ...) \ - do { \ - GE_TIMESTAMP_START(var_name); \ - auto ret_inner_macro = func(__VA_ARGS__); \ - GE_TIMESTAMP_EVENT_END(var_name, #prefix "::" #func) \ - if (ret_inner_macro != ge::SUCCESS) { \ - GELOGE(ret_inner_macro, "Failed to process " #prefix "_" #func); \ - return ret_inner_macro; \ - } \ - } while (0) - #define JOIN_NAME_INNER(a, b) a##b #define JOIN_NAME(a, b) JOIN_NAME_INNER(a, b) #define COUNTER_NAME(a) JOIN_NAME(a, __COUNTER__) #define GE_RUN(prefix, func, ...) \ RUN_WITH_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) -#define GE_RUN_PERF(prefix, func, ...) 
\ - RUN_WITH_PERF_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) #endif // GE_GE_CALL_WRAPPER_H_ diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index 5ff89c07..b021ce55 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -120,7 +120,7 @@ Status GraphExecutor::FreeInOutBuffer() { } } -Status GraphExecutor::MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr) { +Status GraphExecutor::MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr) { if (malloc_flag_) { auto all_size_same = true; if (buffer_size.size() == buffer_size_.size()) { @@ -169,7 +169,7 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor graph_input_data.timestamp = 0; std::size_t inputSize = input_tensor.size(); std::size_t output_size = output_desc.size(); - std::vector bufferSizeVec; + std::vector bufferSizeVec; std::vector addrVec; for (std::size_t i = 0; i < inputSize; ++i) { @@ -211,7 +211,7 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor for (std::size_t j = 0; j < output_size; j++) { auto desc = output_desc[j]; - uint64_t buffer_size = desc.size; + uint32_t buffer_size = desc.size; DataBuffer out_data_buf; out_data_buf.data = reinterpret_cast(addrVec[inputSize + j]); @@ -225,13 +225,6 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor) { - auto model_manager = ge::ModelManager::GetInstance(); - GE_CHECK_NOTNULL(model_manager); - if (model_manager->IsDynamicShape(model_id)) { - GELOGI("[ExecuteGraph] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id); - return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor); - } - // Prepare input and output std::vector inputs_desc; std::vector output_desc; @@ -582,4 
+575,5 @@ Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t inde return SUCCESS; } + } // namespace ge diff --git a/src/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h index 6919a439..0518cf11 100644 --- a/src/ge/graph/execute/graph_execute.h +++ b/src/ge/graph/execute/graph_execute.h @@ -110,7 +110,7 @@ class GraphExecutor { Status FreeInOutBuffer(); - Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); + Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); bool init_flag_; @@ -129,7 +129,7 @@ class GraphExecutor { bool malloc_flag_; std::vector buffer_addr_; - std::vector buffer_size_; + std::vector buffer_size_; }; } // namespace ge diff --git a/src/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc index 4a986308..1f4cbcf9 100644 --- a/src/ge/graph/load/graph_loader.cc +++ b/src/ge/graph/load/graph_loader.cc @@ -350,8 +350,7 @@ Status GraphLoader::GetMemoryInfo(int64_t &free) { return RT_FAILED; } // Add small page memory size - free = - static_cast(free_mem + VarManager::Instance(GetContext().SessionId())->GetUseMaxMemorySize() - total_mem); + free = static_cast(free_mem + VarManager::Instance(0)->GetUseMaxMemorySize() - total_mem); GELOGI("GetMemoryInfo free[%zu], total[%zu], return free[%ld]", free_mem, total_mem, free); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc index a0011b34..06111015 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc @@ -339,7 +339,7 @@ Status CpuTaskActiveEntry::Distribute() { return RT_FAILED; } - GELOGI("Cpu kernel launch active entry task success."); + GELOGI("Cpu kernel launch wait end task success."); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc 
b/src/ge/graph/load/new_model_manager/data_dumper.cc index a4fe8898..653a3fa1 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -21,6 +21,7 @@ #include #include +#include "common/debug/log.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" @@ -36,36 +37,9 @@ namespace { const uint32_t kAicpuLoadFlag = 1; const uint32_t kAicpuUnloadFlag = 0; -const int64_t kOpDebugSize = 2048; -const int64_t kOpDebugShape = 2048; -const int8_t kDecimal = 10; -const uint32_t kAddrLen = sizeof(void *); const char *const kDumpOutput = "output"; const char *const kDumpInput = "input"; const char *const kDumpAll = "all"; - -// parse for format like nodename:input:index -static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, std::string &input_or_output, - size_t &index) { - auto sep = node_name_index.rfind(':'); - if (sep == std::string::npos) { - return false; - } - auto index_str = node_name_index.substr(sep + 1); - index = static_cast(std::strtol(index_str.c_str(), nullptr, kDecimal)); - auto node_name_without_index = node_name_index.substr(0, sep); - sep = node_name_without_index.rfind(':'); - if (sep == std::string::npos) { - return false; - } - node_name = node_name_without_index.substr(0, sep); - input_or_output = node_name_without_index.substr(sep + 1); - return !(input_or_output != kDumpInput && input_or_output != kDumpOutput); -} - -static bool IsTensorDescWithSkipDumpAddrType(bool has_mem_type_attr, vector v_memory_type, size_t i) { - return has_mem_type_attr && (v_memory_type[i] == RT_MEMORY_L1); -} } // namespace static int32_t GetIrDataType(ge::DataType data_type) { @@ -164,13 +138,6 @@ void DataDumper::SaveEndGraphId(uint32_t task_id, uint32_t stream_id) { end_graph_stream_id_ = stream_id; } -void DataDumper::SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, bool is_op_debug) { 
- op_debug_task_id_ = task_id; - op_debug_stream_id_ = stream_id; - op_debug_addr_ = op_debug_addr; - is_op_debug_ = is_op_debug; -} - void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args) { if (op_desc == nullptr) { @@ -235,121 +202,56 @@ static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uin } } -Status DataDumper::GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, - const uintptr_t &addr, size_t index) { - output.set_data_type(static_cast(GetIrDataType(tensor_descs.at(index).GetDataType()))); - output.set_format(static_cast(tensor_descs.at(index).GetFormat())); +Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + GELOGI("Start dump output"); + if (inner_dump_info.is_task) { + // tbe or aicpu op + const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); + const auto input_size = inner_dump_info.op->GetAllInputsDesc().size(); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); + if (output_descs.size() != output_addrs.size()) { + GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), + inner_dump_info.op->GetName().c_str(), output_descs.size()); + return PARAM_INVALID; + } - for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { - output.mutable_shape()->add_dim(dim); - } - int64_t output_size = 0; - if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get output size filed"); - return PARAM_INVALID; - } - GELOGD("Get output size in dump is %ld", output_size); - std::string origin_name; - int32_t origin_output_index = -1; - (void)AttrUtils::GetStr(&tensor_descs.at(index), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - (void)AttrUtils::GetInt(&tensor_descs.at(index), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, 
origin_output_index); - output.set_size(output_size); - output.set_original_name(origin_name); - output.set_original_output_index(origin_output_index); - output.set_original_output_format(static_cast(tensor_descs.at(index).GetOriginFormat())); - output.set_original_output_data_type(static_cast(tensor_descs.at(index).GetOriginDataType())); - output.set_address(static_cast(addr)); - return SUCCESS; -} + for (size_t i = 0; i < output_descs.size(); ++i) { + aicpu::dump::Output output; + output.set_data_type(static_cast(GetIrDataType(output_descs.at(i).GetDataType()))); + output.set_format(static_cast(output_descs.at(i).GetFormat())); -Status DataDumper::DumpRefOutput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Output &output, - size_t i, const std::string &node_name_index) { - std::string dump_op_name; - std::string input_or_output; - size_t index; - // parser and find which node's input or output tensor desc is chosen for dump info - if (!ParseNameIndex(node_name_index, dump_op_name, input_or_output, index)) { - GELOGE(PARAM_INVALID, "Op [%s] output desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s].", - inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str()); - return PARAM_INVALID; - } - GE_CHECK_NOTNULL(compute_graph_); - auto replace_node = compute_graph_->FindNode(dump_op_name); - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(replace_node == nullptr, - "Op [%s] output desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s]," - " cannot find redirect node[%s].", - inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str(), - dump_op_name.c_str()); - auto replace_opdesc = replace_node->GetOpDesc(); - GE_CHECK_NOTNULL(replace_opdesc); - auto iter = ref_info_.find(replace_opdesc); - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(iter == ref_info_.end(), - "Op [%s] output desc[%zu] cannot find any saved redirect node[%s]'s info.", - inner_dump_info.op->GetName().c_str(), i, replace_opdesc->GetName().c_str()); - GE_CHECK_NOTNULL(iter->second); - auto addr 
= reinterpret_cast(iter->second); - if (input_or_output == kDumpInput) { - const auto &replace_input_descs = replace_opdesc->GetAllInputsDesc(); - addr += kAddrLen * index; - GE_CHK_STATUS_RET(GenerateOutput(output, replace_input_descs, addr, index), "Generate output failed"); - } else if (input_or_output == kDumpOutput) { - const auto &replace_output_descs = replace_opdesc->GetAllOutputsDesc(); - const auto replace_input_size = replace_opdesc->GetAllInputsDesc().size(); - addr += (index + replace_input_size) * kAddrLen; - GE_CHK_STATUS_RET(GenerateOutput(output, replace_output_descs, addr, index), "Generate output failed"); - } - GELOGD("Op [%s] output desc[%zu] dump info is replaced by node[%s] [%s] tensor_desc [%zu]", - inner_dump_info.op->GetName().c_str(), i, dump_op_name.c_str(), input_or_output.c_str(), index); - return SUCCESS; -} + for (auto dim : output_descs.at(i).GetShape().GetDims()) { + output.mutable_shape()->add_dim(dim); + } -Status DataDumper::DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { - const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op); - if (output_descs.size() != output_addrs.size()) { - GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), - inner_dump_info.op->GetName().c_str(), output_descs.size()); - return PARAM_INVALID; - } - std::vector v_memory_type; - bool has_mem_type_attr = ge::AttrUtils::GetListInt(inner_dump_info.op, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(has_mem_type_attr && (v_memory_type.size() != output_descs.size()), - "DumpOutputWithTask[%s], output size[%zu], output memory type size[%zu]", - inner_dump_info.op->GetName().c_str(), output_descs.size(), - v_memory_type.size()); - - for (size_t i = 0; i < output_descs.size(); ++i) { - aicpu::dump::Output output; - 
std::string node_name_index; - const auto &output_desc = output_descs.at(i); - // check dump output tensor desc is redirected by attr ATTR_DATA_DUMP_REF - if (AttrUtils::GetStr(&output_desc, ATTR_DATA_DUMP_REF, node_name_index)) { - GE_CHK_STATUS_RET(DumpRefOutput(inner_dump_info, output, i, node_name_index), "DumpRefOutput failed"); - } else { - GE_IF_BOOL_EXEC( - IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), - GELOGD("DumpOutputWithTask[%s] output[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); - continue;); - - const auto input_size = inner_dump_info.op->GetInputsSize(); - auto addr = inner_dump_info.args + (i + input_size) * kAddrLen; - GE_CHK_STATUS_RET(GenerateOutput(output, output_descs, addr, i), "Generate output failed"); + int64_t output_size = 0; + if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get output size filed"); + return PARAM_INVALID; + } + GELOGI("Get output size in dump is %ld", output_size); + std::string origin_name; + int32_t origin_output_index = -1; + (void)AttrUtils::GetStr(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); + (void)AttrUtils::GetInt(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); + GE_IF_BOOL_EXEC(output_size <= 0, GELOGE(PARAM_INVALID, "Output size %ld is less than zero", output_size); + return PARAM_INVALID) + output.set_size(output_size); + output.set_original_name(origin_name); + output.set_original_output_index(origin_output_index); + output.set_original_output_format(static_cast(output_descs.at(i).GetOriginFormat())); + output.set_original_output_data_type(static_cast(output_descs.at(i).GetOriginDataType())); + output.set_address(static_cast(inner_dump_info.args + (i + input_size) * sizeof(void *))); + + task.mutable_output()->Add(std::move(output)); } - task.mutable_output()->Add(std::move(output)); + return SUCCESS; } - return SUCCESS; -} -Status 
DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { - GELOGI("Start dump output"); - if (inner_dump_info.is_task) { - // tbe or aicpu op, these ops are with task - return DumpOutputWithTask(inner_dump_info, task); - } // else data, const or variable op aicpu::dump::Output output; auto output_tensor = inner_dump_info.op->GetOutputDescPtr(inner_dump_info.output_anchor_index); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); if (output_tensor == nullptr) { GELOGE(PARAM_INVALID, "output_tensor is null, index: %d, size: %zu.", inner_dump_info.output_anchor_index, inner_dump_info.op->GetOutputsSize()); @@ -367,6 +269,9 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: int32_t origin_output_index = -1; (void)AttrUtils::GetStr(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); (void)AttrUtils::GetInt(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); + GE_IF_BOOL_EXEC(inner_dump_info.data_size <= 0, + GELOGE(PARAM_INVALID, "The size of data %ld is less than zero", inner_dump_info.data_size); + return PARAM_INVALID) output.set_size(inner_dump_info.data_size); output.set_original_name(origin_name); output.set_original_output_index(origin_output_index); @@ -377,7 +282,7 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: GELOGE(FAILED, "Index is out of range."); return FAILED; } - auto data_addr = inner_dump_info.args + kAddrLen * static_cast(inner_dump_info.input_anchor_index); + auto data_addr = inner_dump_info.args + sizeof(void *) * static_cast(inner_dump_info.input_anchor_index); output.set_address(static_cast(data_addr)); task.mutable_output()->Add(std::move(output)); @@ -385,98 +290,37 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: 
return SUCCESS; } -Status DataDumper::GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor &tensor_descs, - const uintptr_t &addr, size_t index) { - input.set_data_type(static_cast(GetIrDataType(tensor_descs.at(index).GetDataType()))); - input.set_format(static_cast(tensor_descs.at(index).GetFormat())); - - for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { - input.mutable_shape()->add_dim(dim); - } - int64_t input_size = 0; - if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { - GELOGI("Get aipp input size according to attr is %ld", input_size); - } else if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), input_size) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get input size filed"); - return PARAM_INVALID; - } - GELOGD("Get input size in dump is %ld", input_size); - input.set_size(input_size); - input.set_address(static_cast(addr)); - return SUCCESS; -} - -Status DataDumper::DumpRefInput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Input &input, size_t i, - const std::string &node_name_index) { - std::string dump_op_name; - std::string input_or_output; - size_t index; - // parser and find which node's input or output tensor desc is chosen for dump info - if (!ParseNameIndex(node_name_index, dump_op_name, input_or_output, index)) { - GELOGE(PARAM_INVALID, "Op [%s] input desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s].", - inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str()); - return PARAM_INVALID; - } - GE_CHECK_NOTNULL(compute_graph_); - auto replace_node = compute_graph_->FindNode(dump_op_name); - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(replace_node == nullptr, - "Op [%s] input desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s]," - " cannot find redirect node[%s].", - inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str(), - dump_op_name.c_str()); - auto replace_opdesc = replace_node->GetOpDesc(); - GE_CHECK_NOTNULL(replace_opdesc); - auto iter = 
ref_info_.find(replace_opdesc); - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(iter == ref_info_.end(), - "Op [%s] input desc[%zu] cannot find any saved redirect node[%s]'s info.", - inner_dump_info.op->GetName().c_str(), i, replace_opdesc->GetName().c_str()); - GE_CHECK_NOTNULL(iter->second); - auto addr = reinterpret_cast(iter->second); - if (input_or_output == kDumpInput) { - const auto &replace_input_descs = replace_opdesc->GetAllInputsDesc(); - addr += kAddrLen * index; - GE_CHK_STATUS_RET(GenerateInput(input, replace_input_descs, addr, index), "Generate input failed"); - } else if (input_or_output == kDumpOutput) { - const auto &replace_output_descs = replace_opdesc->GetAllOutputsDesc(); - const auto replace_input_size = replace_opdesc->GetAllInputsDesc().size(); - addr += (index + replace_input_size) * kAddrLen; - GE_CHK_STATUS_RET(GenerateInput(input, replace_output_descs, addr, index), "Generate input failed"); - } - GELOGD("Op [%s] input desc[%zu] dump info is replaced by node[%s] [%s] tensor_desc [%zu]", - inner_dump_info.op->GetName().c_str(), i, dump_op_name.c_str(), input_or_output.c_str(), index); - return SUCCESS; -} - Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { GELOGI("Start dump input"); const auto &input_descs = inner_dump_info.op->GetAllInputsDesc(); - const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op); + const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op, false); if (input_descs.size() != input_addrs.size()) { GELOGE(PARAM_INVALID, "Invalid input desc addrs size %zu, op %s has %zu input desc.", input_addrs.size(), inner_dump_info.op->GetName().c_str(), input_descs.size()); return PARAM_INVALID; } - std::vector v_memory_type; - bool has_mem_type_attr = ge::AttrUtils::GetListInt(inner_dump_info.op, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); - GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(has_mem_type_attr && 
(v_memory_type.size() != input_descs.size()), - "DumpInput[%s], input size[%zu], input memory type size[%zu]", - inner_dump_info.op->GetName().c_str(), input_descs.size(), v_memory_type.size()); for (size_t i = 0; i < input_descs.size(); ++i) { aicpu::dump::Input input; - std::string node_name_index; - // check dump input tensor desc is redirected by attr ATTR_DATA_DUMP_REF - if (AttrUtils::GetStr(&input_descs.at(i), ATTR_DATA_DUMP_REF, node_name_index)) { - GE_CHK_STATUS_RET(DumpRefInput(inner_dump_info, input, i, node_name_index), "DumpRefInput failed"); - // normal dump without attr - } else { - GE_IF_BOOL_EXEC(IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), - GELOGD("DumpInput[%s] input[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); - continue;); - - auto addr = inner_dump_info.args + kAddrLen * i; - GE_CHK_STATUS_RET(GenerateInput(input, input_descs, addr, i), "Generate input failed"); + input.set_data_type(static_cast(GetIrDataType(input_descs.at(i).GetDataType()))); + input.set_format(static_cast(input_descs.at(i).GetFormat())); + + for (auto dim : input_descs.at(i).GetShape().GetDims()) { + input.mutable_shape()->add_dim(dim); } + + int64_t input_size = 0; + if (AttrUtils::GetInt(&input_descs.at(i), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { + GELOGI("Get aipp input size according to attr is %ld", input_size); + } else if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get input size filed"); + return PARAM_INVALID; + } + GELOGI("Get input size in dump is %ld", input_size); + GE_IF_BOOL_EXEC(input_size <= 0, GELOGE(PARAM_INVALID, "Input size %ld is less than zero", input_size); + return PARAM_INVALID;) + input.set_size(input_size); + input.set_address(static_cast(inner_dump_info.args + sizeof(void *) * i)); task.mutable_input()->Add(std::move(input)); } return SUCCESS; @@ -556,38 +400,36 @@ Status 
DataDumper::ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_ GELOGI("UnloadDumpInfo success, proto size is: %zu.", proto_size); return SUCCESS; } - Status DataDumper::LoadDumpInfo() { std::string dump_list_key; PrintCheckLog(dump_list_key); if (op_list_.empty()) { - GELOGW("op_list_ is empty"); + return SUCCESS; } aicpu::dump::OpMappingInfo op_mapping_info; - auto dump_path = dump_properties_.GetDumpPath() + std::to_string(device_id_) + "/"; - op_mapping_info.set_dump_path(dump_path); + auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); + op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); op_mapping_info.set_model_name(dump_list_key); op_mapping_info.set_model_id(model_id_); op_mapping_info.set_flag(kAicpuLoadFlag); - op_mapping_info.set_dump_step(dump_properties_.GetDumpStep()); + op_mapping_info.set_dump_step(PropertiesManager::Instance().GetDumpStep()); SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); - GELOGI("Dump step is %s and dump path is %s in load dump info", dump_properties_.GetDumpStep().c_str(), + GELOGI("Dump step is %s and dump path is %s in load dump info", PropertiesManager::Instance().GetDumpStep().c_str(), dump_path.c_str()); for (const auto &op_iter : op_list_) { - auto op_desc = op_iter.op; - GELOGD("Op %s in model %s begin to add task in op_mapping_info", op_desc->GetName().c_str(), dump_list_key.c_str()); aicpu::dump::Task task; + auto op_desc = op_iter.op; task.set_end_graph(false); task.set_task_id(op_iter.task_id); task.set_stream_id(op_iter.stream_id); task.mutable_op()->set_op_name(op_desc->GetName()); task.mutable_op()->set_op_type(op_desc->GetType()); - if (dump_properties_.GetDumpMode() == kDumpOutput) { + if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput) { if (DumpOutput(op_iter, task) != SUCCESS) { GELOGE(FAILED, "Dump output failed"); return FAILED; @@ -595,7 +437,7 @@ Status 
DataDumper::LoadDumpInfo() { op_mapping_info.mutable_task()->Add(std::move(task)); continue; } - if (dump_properties_.GetDumpMode() == kDumpInput) { + if (PropertiesManager::Instance().GetDumpMode() == kDumpInput) { if (op_iter.is_task) { if (DumpInput(op_iter, task) != SUCCESS) { GELOGE(FAILED, "Dump input failed"); @@ -605,7 +447,7 @@ Status DataDumper::LoadDumpInfo() { op_mapping_info.mutable_task()->Add(std::move(task)); continue; } - if (dump_properties_.GetDumpMode() == kDumpAll) { + if (PropertiesManager::Instance().GetDumpMode() == kDumpAll) { auto ret = DumpOutput(op_iter, task); if (ret != SUCCESS) { GELOGE(FAILED, "Dump output failed when in dumping all"); @@ -625,22 +467,19 @@ Status DataDumper::LoadDumpInfo() { SetEndGraphIdToAicpu(end_graph_task_id_, end_graph_stream_id_, op_mapping_info); - SetOpDebugIdToAicpu(op_debug_task_id_, op_debug_stream_id_, op_debug_addr_, op_mapping_info); - - if (!op_list_.empty() || is_op_debug_) { - auto ret = ExecuteLoadDumpInfo(op_mapping_info); - if (ret != SUCCESS) { - GELOGE(FAILED, "Execute load dump info failed"); - return FAILED; - } + auto ret = ExecuteLoadDumpInfo(op_mapping_info); + if (ret != SUCCESS) { + GELOGE(FAILED, "Execute load dump info failed"); + return FAILED; } return SUCCESS; } void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info) { - if (dump_properties_.GetDumpMode() == kDumpOutput || dump_properties_.GetDumpMode() == kDumpInput || - dump_properties_.GetDumpMode() == kDumpAll) { + if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput || + PropertiesManager::Instance().GetDumpMode() == kDumpInput || + PropertiesManager::Instance().GetDumpMode() == kDumpAll) { GELOGI("Add end_graph_info to aicpu, task_id is %u, stream_id is %u", end_graph_task_id_, end_graph_stream_id_); aicpu::dump::Task task; task.set_end_graph(true); @@ -652,37 +491,6 @@ void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, } } 
-void DataDumper::SetOpDebugIdToAicpu(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, - aicpu::dump::OpMappingInfo &op_mapping_info) { - if (is_op_debug_) { - GELOGI("add op_debug_info to aicpu, task_id is %u, stream_id is %u", task_id, stream_id); - aicpu::dump::Task task; - task.set_end_graph(false); - task.set_task_id(task_id); - task.set_stream_id(stream_id); - task.mutable_op()->set_op_name(NODE_NAME_OP_DEBUG); - task.mutable_op()->set_op_type(OP_TYPE_OP_DEBUG); - - // set output - aicpu::dump::Output output; - output.set_data_type(DT_UINT8); - output.set_format(FORMAT_ND); - - output.mutable_shape()->add_dim(kOpDebugShape); - - output.set_original_name(NODE_NAME_OP_DEBUG); - output.set_original_output_index(0); - output.set_original_output_format(FORMAT_ND); - output.set_original_output_data_type(DT_UINT8); - // due to lhisi virtual addr bug, cannot use args now - output.set_address(static_cast(reinterpret_cast(op_debug_addr))); - output.set_size(kOpDebugSize); - - task.mutable_output()->Add(std::move(output)); - op_mapping_info.mutable_task()->Add(std::move(task)); - } -} - Status DataDumper::UnloadDumpInfo() { if (!load_flag_) { GELOGI("No need to UnloadDumpInfo."); @@ -709,17 +517,15 @@ Status DataDumper::UnloadDumpInfo() { } void DataDumper::PrintCheckLog(string &dump_list_key) { - std::set model_list = dump_properties_.GetAllDumpModel(); + std::set model_list = PropertiesManager::Instance().GetAllDumpModel(); if (model_list.empty()) { GELOGI("No model need dump."); return; } + GELOGI("%zu op need dump in %s.", op_list_.size(), model_name_.c_str()); bool not_find_by_omname = model_list.find(om_name_) == model_list.end(); bool not_find_by_modelname = model_list.find(model_name_) == model_list.end(); - dump_list_key = not_find_by_omname ? 
model_name_ : om_name_; - GELOGI("%zu op need dump in %s.", op_list_.size(), dump_list_key.c_str()); - if (model_list.find(DUMP_ALL_MODEL) == model_list.end()) { if (not_find_by_omname && not_find_by_modelname) { std::string model_list_str; @@ -727,12 +533,12 @@ void DataDumper::PrintCheckLog(string &dump_list_key) { model_list_str += "[" + model + "]."; } - GELOGW("Model %s will not be set to dump, dump list: %s", dump_list_key.c_str(), model_list_str.c_str()); + GELOGW("Model %s will not be set to dump, dump list: %s", model_name_.c_str(), model_list_str.c_str()); return; } } - - std::set config_dump_op_list = dump_properties_.GetPropertyValue(dump_list_key); + dump_list_key = not_find_by_omname ? model_name_ : om_name_; + std::set config_dump_op_list = PropertiesManager::Instance().GetDumpPropertyValue(dump_list_key); std::set dump_op_list; for (auto &inner_dump_info : op_list_) { // oplist value OpDescPtr is not nullptr diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h index 0648a8ce..ee5b3241 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.h +++ b/src/ge/graph/load/new_model_manager/data_dumper.h @@ -23,9 +23,7 @@ #include #include "framework/common/ge_inner_error_codes.h" -#include "common/properties_manager.h" #include "graph/node.h" -#include "graph/compute_graph.h" #include "proto/ge_ir.pb.h" #include "proto/op_mapping_info.pb.h" #include "runtime/mem.h" @@ -46,9 +44,7 @@ class DataDumper { device_id_(0), global_step_(0), loop_per_iter_(0), - loop_cond_(0), - compute_graph_(nullptr), - ref_info_() {} + loop_cond_(0) {} ~DataDumper(); @@ -60,10 +56,6 @@ class DataDumper { void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } - void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }; - - void SetRefInfo(const std::map &ref_info) { ref_info_ = ref_info; }; - void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); 
void SaveDumpInput(const std::shared_ptr &node); @@ -73,15 +65,11 @@ class DataDumper { void SaveEndGraphId(uint32_t task_id, uint32_t stream_id); void SetOmName(const std::string &om_name) { om_name_ = om_name; } - void SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, bool is_op_debug); Status LoadDumpInfo(); Status UnloadDumpInfo(); - void SetDumpProperties(const DumpProperties &dump_properties) { dump_properties_ = dump_properties; } - const DumpProperties &GetDumpProperties() const { return dump_properties_; } - private: void ReleaseDevMem(void **ptr) noexcept; @@ -109,32 +97,12 @@ class DataDumper { uintptr_t global_step_; uintptr_t loop_per_iter_; uintptr_t loop_cond_; - ComputeGraphPtr compute_graph_; - std::map ref_info_; - - uint32_t op_debug_task_id_ = 0; - uint32_t op_debug_stream_id_ = 0; - void *op_debug_addr_ = nullptr; - bool is_op_debug_ = false; - - DumpProperties dump_properties_; Status DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); - Status DumpRefOutput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Output &output, size_t i, - const std::string &node_name_index); - Status DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); Status DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); - Status DumpRefInput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Input &input, size_t i, - const std::string &node_name_index); Status ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); void SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info); - void SetOpDebugIdToAicpu(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, - aicpu::dump::OpMappingInfo &op_mapping_info); Status ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); - Status GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor &tensor_descs, - const uintptr_t &addr, size_t 
index); - Status GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, - const uintptr_t &addr, size_t index); }; struct DataDumper::InnerDumpInfo { uint32_t task_id; diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index c43c37eb..a8a11fd9 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -42,11 +42,11 @@ #include "graph/graph.h" #include "graph/load/new_model_manager/cpu_queue_schedule.h" #include "graph/load/new_model_manager/tbe_handle_store.h" +#include "graph/load/output/output.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/trans_var_data_utils.h" #include "graph/manager/util/debug.h" -#include "graph/common/ge_call_wrapper.h" #include "graph/model_serialize.h" #include "graph/node.h" #include "graph/utils/graph_utils.h" @@ -59,7 +59,6 @@ #include "runtime/event.h" #include "runtime/mem.h" #include "runtime/stream.h" -#include "runtime/rt_model.h" #include "securec.h" // create std::thread, catch exceptions using try/catch @@ -79,8 +78,9 @@ namespace { const uint32_t kDataIndex = 0; const uint32_t kOutputNum = 1; const uint32_t kTrueBranchStreamNum = 1; -const uint32_t kThreadNum = 16; +const uint32_t kThreadNum = 1; const uint32_t kAddrLen = sizeof(void *); +const char *const kNeedDestroySpecifiedAicpuKernel = "need_destroy_specified_aicpu_kernel"; const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; @@ -89,10 +89,10 @@ const char *const kDefaultBatchLable = "Batch_default"; inline bool IsDataOp(const std::string &node_type) { return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; } -inline bool IsNoTaskAndDumpNeeded(const OpDescPtr &op_desc) { - bool save_dump_info = false; - (void)ge::AttrUtils::GetBool(op_desc, 
ATTR_NO_TASK_AND_DUMP_NEEDED, save_dump_info); - return save_dump_info; +inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { + bool skip_task_generate = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NO_TASK_AND_DUMP_NEEDED, skip_task_generate); + return skip_task_generate; } } // namespace @@ -125,10 +125,10 @@ DavinciModel::DavinciModel(int32_t priority, const std::shared_ptrGetModelTaskDefPtr(); return SUCCESS; } -/// -/// @ingroup ge -/// @brief Reduce memory usage after task sink. -/// @return: void -/// -void DavinciModel::Shrink() { - ge_model_.reset(); // delete object. - - // Old dump need op list, clear when closed. - char *ge_dump_env = std::getenv("DUMP_OP"); - int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; - if (dump_op_switch == 0) { - op_list_.clear(); - } -} - Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { if (is_model_has_inited_) { GELOGI("call InitModelMem more than once ."); return FAILED; } is_model_has_inited_ = true; - std::size_t data_size = TotalMemSize(); - const Buffer &weights = ge_model_->GetWeight(); + ge::Buffer weights = ge_model_->GetWeight(); + + uint8_t *weights_addr = weights.GetData(); std::size_t weights_size = weights.GetSize(); + GE_CHECK_LE(weights_size, ALLOC_MEMORY_MAX_SIZE); if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { @@ -308,7 +280,7 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p } GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[W] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, weights_mem_base_, weights_size); - GE_CHK_RT_RET(rtMemcpy(weights_mem_base_, weights_size, weights.GetData(), weights_size, RT_MEMCPY_HOST_TO_DEVICE)); + GE_CHK_RT_RET(rtMemcpy(weights_mem_base_, weights_size, weights_addr, weights_size, RT_MEMCPY_HOST_TO_DEVICE)) GELOGI("copy weights data to device"); } @@ -363,15 +335,19 @@ void DavinciModel::InitRuntimeParams() { 
session_id_ = runtime_param_.session_id; GELOGI( - "InitRuntimeParams(), session_id:%u, stream_num:%lu, event_num:%u, label_num:%u, " - "logic_mem_base:0x%lx, logic_weight_base:0x%lx, logic_var_base:0x%lx, " - "memory_size:%lu, weight_size:%lu, var_size:%lu", - runtime_param_.session_id, runtime_param_.stream_num, runtime_param_.event_num, runtime_param_.label_num, - runtime_param_.logic_mem_base, runtime_param_.logic_weight_base, runtime_param_.logic_var_base, - runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.var_size); + "InitRuntimeParams(), memory_size:%lu, weight_size:%lu, session_id:%u, var_size:%lu, logic_var_base:%lu, " + "logic_mem_base:%lu.", + runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.session_id, runtime_param_.var_size, + runtime_param_.logic_var_base, runtime_param_.logic_mem_base); + + GELOGI("InitRuntimeParams(), stream_num:%lu, event_num:%u, label_num:%u", runtime_param_.stream_num, + runtime_param_.event_num, runtime_param_.label_num); } void DavinciModel::CheckHasHcomOp() { + // definiteness queue schedule, all stream by TS. 
+ GE_IF_BOOL_EXEC(!input_queue_ids_.empty() || !output_queue_ids_.empty(), return ); + Graph graph = ge_model_->GetGraph(); auto compute_graph = GraphUtils::GetComputeGraph(graph); if (compute_graph == nullptr) { @@ -387,6 +363,11 @@ void DavinciModel::CheckHasHcomOp() { (op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)), uint32_t stream_id = static_cast(op_desc->GetStreamId()); (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); + + bool is_aicpu_stream = false; + GE_IF_BOOL_EXEC(AttrUtils::GetBool(op_desc, "is_aicpu_stream", is_aicpu_stream) && is_aicpu_stream, + uint32_t stream_id = static_cast(op_desc->GetStreamId()); + (void)aicpu_streams_.emplace(stream_id); GELOGD("aicpu stream: %u.", stream_id); continue); } } @@ -397,13 +378,20 @@ void DavinciModel::CheckHasHcomOp() { /// Status DavinciModel::BindModelStream() { // Stream not in active_stream_indication_ is active stream. - if ((!input_queue_ids_.empty() || !output_queue_ids_.empty()) || (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD)) { + if (!input_queue_ids_.empty() || !output_queue_ids_.empty()) { + // Asynchronous Queue, need add S0, deactive all model stream. for (size_t i = 0; i < stream_list_.size(); ++i) { if (active_stream_indication_.count(i) == 0) { active_stream_list_.push_back(stream_list_[i]); active_stream_indication_.insert(i); // deactive all model stream. } } + } else { + for (size_t i = 0; i < stream_list_.size(); ++i) { + if (active_stream_indication_.count(i) == 0) { + active_stream_list_.push_back(stream_list_[i]); + } + } } for (size_t i = 0; i < stream_list_.size(); ++i) { @@ -421,29 +409,23 @@ Status DavinciModel::BindModelStream() { Status DavinciModel::DoTaskSink() { // task sink is supported as model_task_def is set - const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); - if (model_task_def == nullptr) { - return SUCCESS; - } - - GE_CHK_RT_RET(rtGetAicpuDeploy(&deploy_type_)); - GELOGI("do task_sink. 
AiCpu deploy type is: %x.", deploy_type_); + if (model_task_def_) { + GELOGI("do task_sink."); + GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); - GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); - - if (known_node_) { - GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); - } + if (known_node_) { + GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); + } - GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed."); + GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo failed."); - GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed."); + GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); - GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); + GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); - GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); + GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); + } - SetCopyOnlyOutput(); return SUCCESS; } @@ -461,96 +443,12 @@ Status DavinciModel::SetTSDevice() { return SUCCESS; } -Status DavinciModel::OpDebugRegister() { - bool is_op_debug = false; - (void)ge::AttrUtils::GetBool(ge_model_, ATTR_OP_DEBUG_FLAG, is_op_debug); - GELOGD("The value of op_debug in ge_model_ is %d.", is_op_debug); - if (is_op_debug) { - debug_reg_mutex_.lock(); - rtError_t rt_ret = rtMalloc(&op_debug_addr_, kOpDebugMemorySize, RT_MEMORY_DDR); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); - return RT_FAILED; - } - - uint64_t debug_addrs_tmp = static_cast(reinterpret_cast(op_debug_addr_)); - - // For data dump, aicpu needs the pointer to pointer that save the real debug address. 
- rt_ret = rtMalloc(&p2p_debug_addr_, kDebugP2pSize, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); - return RT_FAILED; - } - rt_ret = rtMemcpy(p2p_debug_addr_, sizeof(uint64_t), &debug_addrs_tmp, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtMemcpy to p2p_addr error: 0x%X", rt_ret); - return FAILED; - } - - uint32_t op_debug_mode = 0; - (void)ge::AttrUtils::GetInt(ge_model_, ATTR_OP_DEBUG_MODE, op_debug_mode); - GELOGD("The value of op_debug_mode in ge_model_ is %u.", op_debug_mode); - uint32_t debug_task_id = 0; - uint32_t debug_stream_id = 0; - rt_ret = rtDebugRegister(rt_model_handle_, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtDebugRegister error, ret: 0x%X", rt_ret); - return RT_FAILED; - } - GELOGI("debug_task_id:%d, debug_stream_id:%u", debug_task_id, debug_stream_id); - is_op_debug_reg_ = true; - - data_dumper_.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, is_op_debug); - } - - return SUCCESS; -} - -void DavinciModel::OpDebugUnRegister() { - GELOGI("OpDebugUnRegister, is_op_debug_reg_ = %d", is_op_debug_reg_); - if (is_op_debug_reg_) { - debug_reg_mutex_.unlock(); - - rtError_t rt_ret = RT_ERROR_NONE; - if (rt_model_handle_ != nullptr) { - rt_ret = rtDebugUnRegister(rt_model_handle_); - if (rt_ret != RT_ERROR_NONE) { - GELOGW("rtDebugUnRegister failed, ret: 0x%X", rt_ret); - } - } - - if (op_debug_addr_ != nullptr) { - rt_ret = rtFree(op_debug_addr_); - if (rt_ret != RT_ERROR_NONE) { - GELOGW("rtFree failed, ret: 0x%X", rt_ret); - } - op_debug_addr_ = nullptr; - } - - if (p2p_debug_addr_ != nullptr) { - rt_ret = rtFree(p2p_debug_addr_); - if (rt_ret != RT_ERROR_NONE) { - GELOGW("rtFree failed, ret: 0x%X", rt_ret); - } - p2p_debug_addr_ = nullptr; - } - - is_op_debug_reg_ = false; - } - - return; -} - // initialize op sequence and call initialization 
function of each op respectively Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { // validating params GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(priority_ < 0 || priority_ > 7, return PARAM_INVALID, "Priority must between 0-7, now is %d", priority_); GE_CHK_BOOL_RET_STATUS(ge_model_ != nullptr, PARAM_INVALID, "GeModel is null."); - Graph graph = ge_model_->GetGraph(); - ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); - GE_CHK_BOOL_RET_STATUS(compute_graph != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); - // Initializing runtime_param_ InitRuntimeParams(); @@ -579,6 +477,8 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size if (hcom_streams_.find(i) != hcom_streams_.end()) { GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_FORCE_COPY)); + } else if (aicpu_streams_.find(i) != aicpu_streams_.end()) { + GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_AICPU)); } else { GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags)); } @@ -599,19 +499,20 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size // create model_handle to load model GE_CHK_RT_RET(rtModelCreate(&rt_model_handle_, 0)); GE_CHK_RT_RET(rtModelGetId(rt_model_handle_, &runtime_model_id_)); - // inference will use default graph_id 0; - runtime_param_.graph_id = compute_graph->GetGraphID(); - // op debug register - GE_CHK_STATUS_RET(OpDebugRegister(), "OpDebugRegister failed"); + Graph graph = ge_model_->GetGraph(); + compute_graph_ = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); + + runtime_param_.graph_id = compute_graph_->GetGraphID(); GE_TIMESTAMP_START(TransAllVarData); - GE_CHK_STATUS_RET(TransAllVarData(compute_graph, runtime_param_.graph_id), "TransAllVarData failed."); + 
GE_CHK_STATUS_RET(TransAllVarData(compute_graph_, runtime_param_.graph_id), "TransAllVarData failed."); GE_TIMESTAMP_END(TransAllVarData, "GraphLoader::TransAllVarData"); - GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_), "copy var data failed."); + GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); GE_TIMESTAMP_START(InitModelMem); - GELOGI("Known node is %d", known_node_); + GELOGI("known_node is %d", known_node_); if (!known_node_) { GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); data_inputer_ = new (std::nothrow) DataInputer(); @@ -619,16 +520,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size } GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); - for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { - auto op_desc = node->GetOpDesc(); - GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - GetFixedAddrAttr(op_desc); - GE_IF_BOOL_EXEC(op_desc->GetType() != VARIABLE, continue); + for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { + GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); + GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); GE_IF_BOOL_EXEC(IsBroadCastOpData(node), - (void)ge::AttrUtils::SetStr(op_desc, VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); + (void)ge::AttrUtils::SetStr(node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); } // for profiling - op_name_map_ = compute_graph->GetGraphOpName(); + op_name_map_ = compute_graph_->GetGraphOpName(); vector op_name; GE_IF_BOOL_EXEC(ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_TASK_INDEX_OP_NAME, op_name), @@ -637,14 +536,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size for (size_t idx = 0; idx < op_name.size(); idx++) { op_name_map_[idx] = op_name[idx]; } - GELOGI("Infer profiling: op_name_size(%zu)", op_name.size()); + GELOGI("infer profiling: op_name_size(%zu)", 
op_name.size()); } - if (InitNodes(compute_graph) != SUCCESS) { + if (InitNodes(compute_graph_) != SUCCESS) { return FAILED; } - SetDataDumperArgs(compute_graph); + SetDataDumperArgs(); GE_TIMESTAMP_START(DoTaskSink); auto ret = DoTaskSink(); GE_TIMESTAMP_END(DoTaskSink, "GraphLoader::DoTaskSink"); @@ -652,23 +551,22 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size /// In zero copy model, if a aicpu operator is connected to the first or last layer, before model execution, /// the aicpu opertor needs to destroy history record, and update operator memory address. /// The model with specified aicpu operators is only marked here, and destruction is in ModelManager::ExecuteModel(). - need_destroy_aicpu_kernel_ = IsAicpuKernelConnectSpecifiedLayer(); - (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name_); + if (MarkSpecifiedAicpuKernel() != SUCCESS) { + GELOGE(FAILED, "Mark model with specified aicpu operators failed."); + return FAILED; + } // collect profiling for ge if (ProfilingManager::Instance().ProfilingOn()) { std::vector compute_graph_desc_info; - Status ret1 = GetComputeGraphInfo(compute_graph, compute_graph_desc_info); + Status ret1 = GetComputeGraphInfo(compute_graph_desc_info); if (ret1 != SUCCESS) { GELOGE(ret1, "GetComputeGraphInfo failed."); return ret1; } ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); - GE_CHK_STATUS(SinkModelProfile(), "Sink model profile failed."); } - - Shrink(); - GELOGI("Davinci model init success."); + GELOGI("davinci model init success."); return ret; } @@ -725,14 +623,26 @@ bool DavinciModel::IsAicpuKernelConnectSpecifiedLayer() { return false; } - -Status DavinciModel::UpdateSessionId(uint64_t session_id) { - GE_CHECK_NOTNULL(ge_model_); - if (!AttrUtils::SetInt(ge_model_, MODEL_ATTR_SESSION_ID, static_cast(session_id))) { - GELOGW("Set attr[%s] failed in updating session_id.", MODEL_ATTR_SESSION_ID.c_str()); +/// 
+/// @ingroup ge +/// @brief mark ge model with specified aicpu operators . +/// @return Status +/// +Status DavinciModel::MarkSpecifiedAicpuKernel() { + bool result = IsAicpuKernelConnectSpecifiedLayer(); + if (!result) { + // No aicpu operator needing destroy. + GELOGD("No specified aicpu operator that connects to data or netoutput."); + return SUCCESS; } - GELOGD("Update session id: %lu.", session_id); + bool ret = ge::AttrUtils::SetBool(ge_model_, kNeedDestroySpecifiedAicpuKernel, result); + if (!ret) { + GELOGW("Add attr[%s] in ge model failed, and may lead to specified aicpu operators destruction failure.", + kNeedDestroySpecifiedAicpuKernel); + } + GELOGI("Mark ge model success, the model has specified aicpu operators, ge model name: %s.", + ge_model_->GetName().c_str()); return SUCCESS; } @@ -779,6 +689,12 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } + if (IsCallDumpInputOp(op_desc)) { + GELOGI("node[%s] is no task op , call SaveDumpInput to save it's output node info", op_desc->GetName().c_str()); + data_dumper_.SaveDumpInput(node); + continue; + } + if (op_desc->GetType() == NETOUTPUT) { if (InitNetOutput(node) != SUCCESS) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); @@ -796,29 +712,6 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } - if (IsNoTaskAndDumpNeeded(op_desc)) { - GELOGD("node[%s] without task, and save op_desc and addr for dump", op_desc->GetName().c_str()); - const RuntimeParam &rts_param = GetRuntimeParam(); - const vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); - const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); - vector tensor_device_addrs; - tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - 
tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); - void *addr = nullptr; - auto size = kAddrLen * tensor_device_addrs.size(); - GE_CHK_RT_RET(rtMalloc(&addr, size, RT_MEMORY_HBM)); - - rtError_t rt_ret = rtMemcpy(addr, size, tensor_device_addrs.data(), size, RT_MEMCPY_HOST_TO_DEVICE); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtMemcpy error"); - GE_CHK_RT(rtFree(addr)); - return FAILED; - } - saved_task_addrs_.emplace(op_desc, addr); - } - GE_TIMESTAMP_RESTART(InitTbeHandle); uint32_t run_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, run_mode) && @@ -848,6 +741,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { /// @brief Data Op Initialize. /// @param [in] NodePtr: Data Op. /// @param [in/out] data_op_index: NetOutput addr size info. +/// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // op_desc Checked by Init: Data, valid. @@ -875,7 +769,7 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // Make information for copy input data. 
const vector output_size_list = ModelUtils::GetOutputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); + const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc, false); if (output_size_list.empty() || virtual_addr_list.empty() || (output_size_list.size() != virtual_addr_list.size())) { GELOGE(PARAM_INVALID, "Data[%s] init failed: Output size is %zu, Output addr is %zu", op_desc->GetName().c_str(), output_size_list.size(), virtual_addr_list.size()); @@ -951,7 +845,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { output_op_list_.push_back(op_desc); // Make information for copy output data. const vector input_size_list = ModelUtils::GetInputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc); + const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); if (input_size_list.empty() && virtual_addr_list.empty()) { GELOGI("NetOutput[%s] is empty.", op_desc->GetName().c_str()); return SUCCESS; @@ -964,15 +858,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { size_t num = output_data_info_.size(); for (size_t idx = 0; idx < input_size_list.size(); ++idx) { - int64_t size = input_size_list[idx]; - auto tensor_desc = op_desc->GetInputDescPtr(idx); - if ((tensor_desc == nullptr) || (TensorUtils::GetTensorSizeInBytes(*tensor_desc, size) != GRAPH_SUCCESS)) { - GELOGE(FAILED, "GetTensorSizeInBytes failed!"); - return FAILED; - } - - GELOGI("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); - output_data_info_[num + idx] = {size, virtual_addr_list[idx]}; + output_data_info_[num + idx] = {input_size_list[idx], virtual_addr_list[idx]}; } SetOutputOutsideAddr(virtual_addr_list); @@ -1082,7 +968,7 @@ Status DavinciModel::InitVariable(const OpDescPtr &op_desc) { Status DavinciModel::SetQueIds(const std::vector &input_queue_ids, const std::vector 
&output_queue_ids) { if (input_queue_ids.empty() && output_queue_ids.empty()) { - GELOGE(PARAM_INVALID, "Param is empty"); + GELOGE(PARAM_INVALID, "Para is empty"); return PARAM_INVALID; } @@ -1115,7 +1001,11 @@ Status DavinciModel::LoadWithQueue() { return PARAM_INVALID; } - GE_CHK_STATUS_RET(AddHeadStream(), "Add head stream failed."); + // create stream instance which rt_model_handel is running on, this is S0. + GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_model_stream_, priority_, RT_STREAM_AICPU)); + is_inner_model_stream_ = true; + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, RT_HEAD_STREAM)); + // Binding input_queue and Data Op. GE_CHK_STATUS_RET(BindInputQueue(), "Launch bind input queue failed."); GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(input_mbuf_list_, input_outside_addrs_), "Launch zero copy failed."); @@ -1124,7 +1014,7 @@ Status DavinciModel::LoadWithQueue() { GE_CHK_STATUS_RET(BindOutputQueue(), "Launch bind output queue failed."); GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(output_mbuf_list_, output_outside_addrs_), "Launch zero copy failed."); - GE_CHK_STATUS_RET(CpuActiveStream(), "Launch active entry stream failed."); + GE_CHK_STATUS_RET(CpuActiveStream(active_stream_list_), "Launch active entry stream failed."); GE_CHK_STATUS_RET(CpuWaitEndGraph(), "Launch wait end graph failed."); GE_CHK_STATUS_RET(BindEnqueue(), "Launch enqueue failed."); GE_CHK_STATUS_RET(CpuModelRepeat(), "Launch model repeat failed."); @@ -1168,7 +1058,7 @@ Status DavinciModel::BindInputQueue() { /// @return: 0 for success / others for failed Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { GELOGI("Set CpuKernel model dequeue task enter."); - std::shared_ptr dequeue_task = MakeShared(rt_entry_stream_); + std::shared_ptr dequeue_task = MakeShared(rt_model_stream_); if (dequeue_task == nullptr) { GELOGE(FAILED, "Make CpuTaskModelDequeue task failed."); return FAILED; @@ -1189,7 +1079,7 @@ Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { Status 
DavinciModel::CpuTaskModelZeroCopy(std::vector &mbuf_list, std::map> &outside_addrs) { GELOGI("Set CpuKernel model zero_copy task enter."); - std::shared_ptr zero_copy = MakeShared(rt_entry_stream_); + std::shared_ptr zero_copy = MakeShared(rt_model_stream_); if (zero_copy == nullptr) { GELOGE(FAILED, "Make CpuTaskZeroCopy task failed."); return FAILED; @@ -1234,6 +1124,7 @@ Status DavinciModel::BindOutputQueue() { /// @ingroup ge /// @brief definiteness queue schedule, bind output queue to task. +/// @param [in] queue_id: output queue id from user. /// @param [in] addr: NetOutput Op input tensor address. /// @param [in] size: NetOutput Op input tensor size. /// @return: 0 for success / others for failed @@ -1244,7 +1135,7 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { return FAILED; } - std::shared_ptr prepare_output = MakeShared(rt_entry_stream_); + std::shared_ptr prepare_output = MakeShared(rt_model_stream_); if (prepare_output == nullptr) { GELOGE(FAILED, "Make CpuTaskPrepareOutput task failed."); return FAILED; @@ -1264,21 +1155,25 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { /// /// @ingroup ge /// @brief definiteness queue schedule, active original model stream. +/// @param [in] streams: streams will active by S0. 
/// @return: 0 for success / others for failed /// -Status DavinciModel::CpuActiveStream() { - GELOGI("Set CpuKernel active stream task enter."); - std::shared_ptr active_entry = MakeShared(rt_entry_stream_); - if (active_entry == nullptr) { - GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); - return FAILED; - } +Status DavinciModel::CpuActiveStream(const std::vector &stream_list) { + GELOGI("Set CpuKernel active stream task:%zu enter.", stream_list.size()); + for (auto s : stream_list) { + std::shared_ptr active_entry = MakeShared(rt_model_stream_); + if (active_entry == nullptr) { + GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); + return FAILED; + } - if (active_entry->Init(rt_head_stream_) != SUCCESS) { - return FAILED; + if (active_entry->Init(s) != SUCCESS) { + return FAILED; + } + + cpu_task_list_.push_back(active_entry); } - cpu_task_list_.push_back(active_entry); GELOGI("Set CpuKernel active stream task success."); return SUCCESS; } @@ -1288,7 +1183,7 @@ Status DavinciModel::CpuActiveStream() { /// @return: 0 for success / others for failed Status DavinciModel::CpuWaitEndGraph() { GELOGI("Set CpuKernel wait end graph task enter."); - std::shared_ptr wait_endgraph = MakeShared(rt_entry_stream_); + std::shared_ptr wait_endgraph = MakeShared(rt_model_stream_); if (wait_endgraph == nullptr) { GELOGE(FAILED, "Make CpuTaskWaitEndGraph task failed."); return FAILED; @@ -1321,7 +1216,7 @@ Status DavinciModel::BindEnqueue() { Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { GELOGI("Set CpuKernel model enqueue task enter."); - std::shared_ptr model_enqueue = MakeShared(rt_entry_stream_); + std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); if (model_enqueue == nullptr) { GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); return FAILED; @@ -1340,7 +1235,7 @@ Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { /// @return: 0 for success / others for failed Status 
DavinciModel::CpuModelRepeat() { GELOGI("Set CpuKernel repeat task enter."); - std::shared_ptr model_repeat = MakeShared(rt_entry_stream_); + std::shared_ptr model_repeat = MakeShared(rt_model_stream_); if (model_repeat == nullptr) { GELOGE(FAILED, "Make CpuTaskModelRepeat task failed."); return FAILED; @@ -1392,8 +1287,36 @@ Status DavinciModel::GetInputOutputDescInfo(vector &input_d /// @param [out] batch_info /// @return execute result /// -Status DavinciModel::GetDynamicBatchInfo(std::vector> &batch_info) const { - batch_info = batch_info_; +Status DavinciModel::GetDynamicBatchInfo(std::vector> &batch_info) { + for (auto &iter : op_list_) { + OpDescPtr op_desc = iter.second; + if (op_desc == nullptr) { + GELOGE(FAILED, "op_desc is null, index=%u.", iter.first); + return FAILED; + } + + if (op_desc->GetType() != STREAMSWITCHN) { + continue; + } + + batch_info.clear(); + uint32_t batch_num = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { + GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); + return FAILED; + } + std::vector batch_shape; + for (uint32_t i = 0; i < batch_num; i++) { + batch_shape.clear(); + const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { + GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); + return FAILED; + } + batch_info.emplace_back(batch_shape); + } + break; + } return SUCCESS; } @@ -1606,7 +1529,7 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD int64_t tensor_size = 0; (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); // no need to check value - output.size = static_cast(tensor_size); + output.size = static_cast(tensor_size); output.data_type = op_desc->GetInputDescPtr(index)->GetDataType(); } @@ -1615,6 +1538,9 @@ Status 
DavinciModel::GetOutputDescInfo(vector &output_desc, for (size_t i = 0; i < output_op_list_.size(); i++) { auto &op_desc = output_op_list_[i]; uint32_t out_size = static_cast(op_desc->GetInputsSize()); + // get real out nodes from model + vector out_node_name; + (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name); for (uint32_t index = 0; index < out_size; index++) { string output_name; InputOutputDescInfo output; @@ -1626,11 +1552,11 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, GE_CHK_BOOL_RET_STATUS(src_name.size() > index && src_index.size() > index, INTERNAL_ERROR, "construct output_name failed."); // forward compatbility, if old om has no out_node_name, need to return output follow origin way - if (out_size == out_node_name_.size()) { + if (out_size == out_node_name.size()) { // neweast plan, the index will add to name during generate model. - bool contains_colon = out_node_name_[index].find(":") != std::string::npos; + bool contains_colon = out_node_name[index].find(":") != std::string::npos; output_name = - contains_colon ? out_node_name_[index] : out_node_name_[index] + ":" + std::to_string(src_index[index]); + contains_colon ? 
out_node_name[index] : out_node_name[index] + ":" + std::to_string(src_index[index]); } else { output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); @@ -1664,12 +1590,12 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data const DataBuffer &data_buf = blobs[data.first]; void *mem_addr = data.second.second; - uint64_t mem_size = static_cast(data.second.first); + uint32_t mem_size = static_cast(data.second.first); GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length, PARAM_INVALID, - "input data size(%lu) does not match model required size(%lu), ret failed.", data_buf.length, + "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, mem_size); - GELOGI("[IMAS]CopyPlainData memcpy graph_%lu type[F] input[%lu] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%u] datasize[%u]", runtime_param_.graph_id, data.first, mem_addr, data_buf.data, mem_size, data_buf.length); if (data_buf.length == 0) { GELOGW("No data need to memcpy!"); @@ -1717,9 +1643,15 @@ inline int64_t SumSize(const vector &size_list) { } Status DavinciModel::SinkModelProfile() { + // not support non-sink model + GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); + // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); + if (reporter == nullptr) { + GELOGI("Profiling report is nullptr!"); + return SUCCESS; + } GELOGI("Start collect model load profiling data."); @@ -1731,19 +1663,15 @@ Status DavinciModel::SinkModelProfile() { return FAILED, "Sink model tag memcpy error."); // Model Header - string name; - if (!om_name_.empty()) { - name = om_name_; - } else { - name = name_; - } - size_t name_len = name.size(); + string 
name = this->Name(); + int32_t name_len = name.size(); // phy device id uint32_t phy_device_id = 0; rtError_t rt_ret = rtGetDevicePhyIdByIndex(device_id_, &phy_device_id); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, - GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%u", phy_device_id); - return FAILED); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%d", phy_device_id); + return FAILED; + } reporter_data.deviceId = phy_device_id; reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); @@ -1780,6 +1708,7 @@ Status DavinciModel::SinkModelProfile() { for (int32_t i = 0; i < task_num; i++) { auto task = task_list_[i]; auto fusion_op_info = task->GetFusionOpInfo(); + // when type is RT_MODEL_TASK_KERNEL, ctx is not null if (fusion_op_info != nullptr) { uint32_t op_num = fusion_op_info->original_op_names.size(); @@ -1898,9 +1827,15 @@ Status DavinciModel::SinkModelProfile() { } Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { + // not support non-sink model + GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); + // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); + if (reporter == nullptr) { + GELOGI("Profiling report is nullptr!"); + return SUCCESS; + } Msprof::Engine::ReporterData reporter_data{}; // report model data tag name @@ -1915,19 +1850,15 @@ Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { // device id uint32_t phy_device_id = 0; rtError_t rt_ret = rtGetDevicePhyIdByIndex(device_id_, &phy_device_id); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, - GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%u", phy_device_id); - return FAILED); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "runtime get phy_device_id failed, current 
phy_device_id:%d", phy_device_id); + return FAILED; + } reporter_data.deviceId = phy_device_id; // Model Header - string name; - if (!om_name_.empty()) { - name = om_name_; - } else { - name = name_; - } - size_t name_len = name.size(); + string name = this->Name(); + int32_t name_len = name.size(); reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", @@ -2005,62 +1936,81 @@ void DavinciModel::SetProfileTime(ModelProcStage stage, int64_t endTime) { /// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again -/// @param [in] data_id: the index of output_data -/// @param [in/out] output_data: real user output_data -/// @param [in] kind: the kind of rtMemcpy +/// @param [in] sink_op Sink Op /// @return Status result /// @author /// -Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind) { +Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { + Status ret = SUCCESS; if (output_op_list_.empty()) { - Status ret = SyncVarData(); - DumpOpInputOutput(); - return ret; - } + ret = SyncVarData(); + } else { + output_data.index = data_id; + output_data.model_id = model_id_; + GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_data_info_.size(), INTERNAL_ERROR, + "output buffer size[%zu] not equal output_size_list[%zu] size!", output_data.blobs.size(), + output_data_info_.size()); - output_data.index = data_id; - output_data.model_id = model_id_; - if (output_data.blobs.size() != output_data_info_.size()) { - GELOGE(FAILED, "Output data buffer num=%zu not equal model data num=%zu", output_data.blobs.size(), - output_data_info_.size()); - return FAILED; + // index of data in output_data + uint32_t output_data_index = 0; + for (auto &op_desc : output_op_list_) { + ret = 
CopyOutputDataToUser(op_desc, output_data.blobs, output_data_index); + GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy output data to model ret failed, index:%u, model id:%u", + output_data.index, output_data.model_id); + } } - std::vector &blobs = output_data.blobs; - for (const auto &output : output_data_info_) { - if (output.first >= blobs.size()) { - GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), - input_data_info_.size(), output.first, output.second.first); - return FAILED; - } + (void)DumpOpInputOutput(); // dump, not care result. + return ret; +} - if ((kind == RT_MEMCPY_DEVICE_TO_DEVICE) && (copy_only_addrs_.count(output.second.second) == 0)) { - continue; // Skip: Feed by zero copy. - } +Status DavinciModel::CopyOutputDataToUser(OpDescPtr &op_desc, std::vector &blobs, uint32_t &data_index) { + Output model_output(op_desc, this); - DataBuffer &buffer = blobs[output.first]; - uint64_t mem_size = static_cast(output.second.first); - if ((buffer.length == 0) || (mem_size == 0)) { - GELOGI("Length of data is zero, No need copy. 
output tensor index=%u", output.first); - continue; - } + GE_CHK_BOOL_RET_STATUS(model_output.Init() == SUCCESS, PARAM_INVALID, "make shared model_output failed"); - if (buffer.length < mem_size) { - GELOGE(FAILED, "Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); - return FAILED; - } else if (buffer.length > mem_size) { - GELOGW("Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); - } + vector v_output_size; + vector v_output_data_addr; + model_output.GetOutputData(v_output_data_addr, v_output_size); + + // for all output tensor, copy output data from op to designated position + for (size_t i = 0; i < v_output_size.size(); ++i) { + GE_CHK_BOOL_RET_STATUS(data_index < blobs.size(), PARAM_INVALID, + "The blobs size:%zu, data_op size:%zu, curr output size:%zu", blobs.size(), + data_op_list_.size(), v_output_size.size()); + + DataBuffer &data_buf = blobs[data_index]; + data_index++; - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%lu] datasize[%u]", - runtime_param_.graph_id, output.first, output.second.second, mem_size, buffer.length); - GE_CHK_RT_RET(rtMemcpy(buffer.data, buffer.length, output.second.second, mem_size, kind)); + uint32_t size = data_buf.length; + GE_CHK_BOOL_RET_STATUS(size <= v_output_size[i], PARAM_INVALID, + "Model output data size(%u) does not match required size(%u).", v_output_size[i], + data_buf.length); + + if (copy_only_addrs_.count(v_output_data_addr[i]) == 0) { + GELOGI("[ZCPY] This addr[%p] has already feed by zero copy.", v_output_data_addr[i]); + continue; // Skip: Feed by zero copy. 
+ } + GELOGI( + "CopyOutputDataToUser memcpy graph_%u type[F] name[%s] output[%lu] dst[%p] src[%p] mem_size[%u] datasize[%u]", + runtime_param_.graph_id, op_desc->GetName().c_str(), i, data_buf.data, v_output_data_addr[i], data_buf.length, + v_output_size[i]); + GE_CHK_RT_RET(rtMemcpy(data_buf.data, size, v_output_data_addr[i], size, RT_MEMCPY_DEVICE_TO_DEVICE)); } - DumpOpInputOutput(); return SUCCESS; } +Status DavinciModel::SyncDataAndDump() { + Status ret = SUCCESS; + if (output_op_list_.empty()) { + ret = SyncVarData(); + } + + (void)DumpOpInputOutput(); // dump, not care result. + return ret; +} + Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, std::vector &outputs) { GE_CHECK_NOTNULL(op_desc); @@ -2092,13 +2042,13 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data GELOGE(GE_GRAPH_MALLOC_FAILED, "Malloc buffer failed."); return GE_GRAPH_MALLOC_FAILED; } - output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); + output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); ge::OutputTensorInfo output; output.dims = shape_info_vec[i]; output.data = std::move(data_buf); output.length = out_buffer_size_vec[i]; outputs.emplace_back(std::move(output)); - GELOGI("Output index:%zu, data_length:%lu.", i, output.length); + GELOGI("Output index:%zu, data_length:%u.", i, output.length); } return SUCCESS; } @@ -2107,10 +2057,7 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data /// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again -/// @param [in] data_id: the index of output_data -/// @param [in] rslt_flg: result flag -/// @param [in] seq_end_flag: sequence end flag -/// @param [out] output_data: real user output_data +/// @param [in] sink_op Sink Op /// @return Status result /// @author /// @@ -2141,17 
+2088,20 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b // copy output data from op to designated position for (auto &op_desc : output_op_list_) { - if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { + Output model_output(op_desc, this); + if (model_output.Init() != SUCCESS || GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { return INTERNAL_ERROR; } - data_index += op_desc->GetInputsSize(); - } - if (CopyOutputData(data_id, *output_data, RT_MEMCPY_DEVICE_TO_HOST) != SUCCESS) { - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); - return INTERNAL_ERROR; + Status ret = model_output.CopyResult(*output_data, data_index, data_index, false); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "CopyResult failed, op name: %s", op_desc->GetName().c_str()); + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); + return INTERNAL_ERROR; + } } + GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); if (seq_end_flag) { GELOGW("End of sequence, model id: %u", model_id_); GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnCompute Done failed."); @@ -2164,7 +2114,6 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b /// /// @ingroup ge /// @brief return not output to upper layer for cloud case -/// @param [in] data_id /// @return Status result /// Status DavinciModel::ReturnNoOutput(uint32_t data_id) { @@ -2176,7 +2125,7 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { op_desc->GetName().c_str()); } - DumpOpInputOutput(); + GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null!"); std::vector outputs; 
GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed."); @@ -2186,40 +2135,41 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { /// /// @ingroup ge /// @brief dump all op input and output information -/// @return void +/// @param [in] op_list model_id +/// @return Status result /// -void DavinciModel::DumpOpInputOutput() { - char *ge_dump_env = std::getenv("DUMP_OP"); - int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; - if (dump_op_switch == 0) { - GELOGI("need to set DUMP_OP for dump op input and output"); - return; - } - +Status DavinciModel::DumpOpInputOutput() { if (op_list_.empty()) { - GELOGW("op list is empty"); - return; + GELOGW("op_list is empty."); + return FAILED; } - - int64_t cnt = 1; - for (auto it : op_list_) { - if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { - GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld", maxDumpOpNum_); - return; - } - - cnt++; - if (DumpSingleOpInputOutput(it.second) != SUCCESS) { - GELOGW("dump single op failed, model_id: %u", model_id_); - return; + char *ge_dump_env = getenv("DUMP_OP"); + int dump_op_switch = + (ge_dump_env != nullptr) ? 
std::strtol(ge_dump_env, nullptr, kDecimal) : 0; // 10 for decimal number + if (dump_op_switch != 0) { + int64_t cnt = 1; + for (auto it : op_list_) { + if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { + GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld.", maxDumpOpNum_); + return SUCCESS; + } + Status ret = DumpSingleOpInputOutput(it.second); + cnt++; + if (ret != SUCCESS) { + GELOGE(FAILED, "dump single op failed, model_id: %u", model_id_); + return FAILED; + } } + } else { + GELOGW("need to set DUMP_OP for dump op input and output."); } + return SUCCESS; } /// /// @ingroup ge /// @brief dump single op input and output information -/// @param [in] op_def: the op_desc which will be dump +/// @param [in] dump_op model_id /// @return Status result /// Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { @@ -2235,7 +2185,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { } } const vector input_size_vec = ModelUtils::GetInputSize(op_def); - const vector input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def); + const vector input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def, false); vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); GELOGD("DumpSingleOp[%s], input size[%zu], input memory type size[%zu]", op_def->GetName().c_str(), @@ -2258,7 +2208,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { } const vector output_size_vec = ModelUtils::GetOutputSize(op_def); - const vector output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def); + const vector output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def, false); v_memory_type.clear(); has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); GELOGD("DumpSingleOp[%s], output size[%zu], output memory type size[%zu]", op_def->GetName().c_str(), @@ -2328,7 
+2278,7 @@ void *DavinciModel::Run(DavinciModel *model) { ret != SUCCESS, (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); continue, "Copy input data to model failed."); // [No need to check value] - GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(Model_SyncVarData, "Model Run SyncVarData")); + GE_TIMESTAMP_END(Model_SyncVarData, "Model Run SyncVarData"); GELOGI("Copy input data, model id:%u", model_id); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_PRE_PROC_START)); @@ -2374,7 +2324,7 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue); GELOGI("rtModelExecute end"); - GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); + GE_TIMESTAMP_END(rtModelExecute, "GraphExcute::rtModelExecute"); GE_TIMESTAMP_START(rtStreamSynchronize); GELOGI("rtStreamSynchronize start."); @@ -2389,8 +2339,7 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue); GELOGI("rtStreamSynchronize end."); - GE_IF_BOOL_EXEC(model->is_first_execute_, - GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); + GE_TIMESTAMP_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize"); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_INFER_END)); } @@ -2401,13 +2350,11 @@ void *DavinciModel::Run(DavinciModel *model) { (void)model->ReturnResult(current_data.index, rslt_flg, false, data_wrapper->GetOutput())) // copy output data from device to host for variable graph GE_IF_BOOL_EXEC(model->output_op_list_.empty(), 
(void)model->ReturnNoOutput(current_data.index)); - GE_IF_BOOL_EXEC(model->is_first_execute_, - GE_TIMESTAMP_EVENT_END(ReturnResult3, "GraphExcute::CopyDataFromDeviceToHost")); + GE_TIMESTAMP_END(ReturnResult3, "GraphExcute::CopyDataFromDeviceToHost"); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_AFTER_PROC_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), (void)model->SinkTimeProfile(current_data)); model->iterator_count_++; - model->is_first_execute_ = false; GELOGI("run iterator count is %lu", model->iterator_count_); } @@ -2460,7 +2407,7 @@ Status DavinciModel::ModelRunStart() { is_inner_model_stream_ = true; string opt = "0"; - (void)ge::GetContext().GetOption(OPTION_GE_MAX_DUMP_OP_NUM, opt); // option may not be set up, no need to check value + (void)ge::GetContext().GetOption("ge.maxDumpOpNum", opt); // option may not be set up, no need to check value int64_t maxDumpOpNum = std::strtol(opt.c_str(), nullptr, kDecimal); maxDumpOpNum_ = maxDumpOpNum; @@ -2503,18 +2450,7 @@ void DavinciModel::UnbindTaskSinkStream() { // destroy stream that is bound with rt_model GE_LOGW_IF(rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed.") } - - if (is_pure_head_stream_ && rt_head_stream_ != nullptr) { - GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, rt_head_stream_) != RT_ERROR_NONE, "Unbind stream failed!"); - GE_LOGW_IF(rtStreamDestroy(rt_head_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed."); - rt_head_stream_ = nullptr; - } - - if (rt_entry_stream_ != nullptr) { - GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, rt_entry_stream_) != RT_ERROR_NONE, "Unbind stream failed!"); - GE_LOGW_IF(rtStreamDestroy(rt_entry_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed."); - rt_entry_stream_ = nullptr; - } + return; } Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs) { @@ -2523,9 +2459,6 @@ Status 
DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const GELOGE(FAILED, "input data addr %u is not equal to input op number %u.", inputs.size(), data_op_list_.size()); return FAILED; } - // remove zero copy addr in last iteration - knonw_input_data_info_.clear(); - knonw_output_data_info_.clear(); for (size_t i = 0; i < data_op_list_.size(); ++i) { const vector addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, data_op_list_[i]); knonw_input_data_info_[addr_list[kDataIndex]] = inputs[i]; @@ -2607,9 +2540,7 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { for (int i = 0; i < model_task_def.task_size(); ++i) { // dynamic shape will create task_list_ before const domi::TaskDef &task = model_task_def.task(i); - if (this->task_list_[i] == nullptr) { - task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(task.type())); - } + task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(task.type())); GE_CHECK_NOTNULL(task_list_[i]); Status ret = task_list_[i]->Init(task, this); if (ret != SUCCESS) { @@ -2623,14 +2554,13 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { Status DavinciModel::MallocKnownArgs() { GELOGI("DavinciModel::MallocKnownArgs in"); - const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); - if (model_task_def->task_size() == 0) { + if (model_task_def_->task_size() == 0) { GELOGW("DavinciModel::MallocKnownArgs davincimodel has no task info."); return SUCCESS; } - task_list_.resize(model_task_def->task_size()); - for (int32_t i = 0; i < model_task_def->task_size(); ++i) { - const domi::TaskDef &taskdef = model_task_def->task(i); + task_list_.resize(model_task_def_->task_size()); + for (int32_t i = 0; i < model_task_def_->task_size(); ++i) { + const domi::TaskDef &taskdef = model_task_def_->task(i); task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(taskdef.type())); GE_CHECK_NOTNULL(task_list_[i]); Status ret = task_list_[i]->CalculateArgs(taskdef, this); @@ 
-2651,19 +2581,7 @@ Status DavinciModel::MallocKnownArgs() { GELOGE(RT_FAILED, "Call rtMallocHost failed, ret: 0x%X", rt_ret); return RT_FAILED; } - - // malloc fixed addr memory, eg: rts op - if (total_fixed_addr_size_ != 0) { - GELOGI("Begin to allocate fixed addr."); - rt_ret = rtMalloc(&fixed_addrs_, total_fixed_addr_size_, RT_MEMORY_HBM); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); - return RT_FAILED; - } - } - - GELOGI("DavinciModel::MallocKnownArgs success, total args size %u. total fixed addr size %ld", total_args_size_, - total_fixed_addr_size_); + GELOGI("DavinciModel::MallocKnownArgs success, total args size %u.", total_args_size_); return SUCCESS; } @@ -2679,28 +2597,26 @@ Status DavinciModel::DistributeTask() { task_desc_info_.clear(); bool flag = GetL1FusionEnableOption(); - char *skt_enable_env = std::getenv("SKT_ENABLE"); - int64_t env_flag = (skt_enable_env != nullptr) ? std::strtol(skt_enable_env, nullptr, kDecimal) : 0; + char *skt_enable_env = getenv("SKT_ENABLE"); + int64_t env_flag = (skt_enable_env != nullptr) ? 
strtol(skt_enable_env, nullptr, 10) : 0; if (env_flag != 0) { flag = true; } - const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { auto &task = task_list_.at(task_index); GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); // for data dump if (reinterpret_cast(task->GetDumpArgs()) != nullptr) { - auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), - model_task_def->task(task_index).kernel_ex().op_index()); + auto op_index = std::max(model_task_def_->task(task_index).kernel().context().op_index(), + model_task_def_->task(task_index).kernel_ex().op_index()); OpDescPtr op = GetOpByIndex(op_index); if (op == nullptr) { GELOGE(PARAM_INVALID, "Op index %u is null, op list size %zu.", op_index, op_list_.size()); return PARAM_INVALID; } - bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); - if (call_dump) { + if (PropertiesManager::Instance().IsLayerNeedDump(name_, om_name_, op->GetName())) { SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); } } @@ -2715,13 +2631,8 @@ Status DavinciModel::DistributeTask() { // else task index is found in op_name_map_ TaskDescInfo task_desc_info; string op_name = op_name_map_[task_index]; - if (!om_name_.empty()) { - task_desc_info.model_name = om_name_; - } else { - task_desc_info.model_name = name_; - } task_desc_info.op_name = op_name; - task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); + task_desc_info.block_dim = model_task_def_->task(task_index).kernel().block_dim(); task_desc_info.task_id = task->GetTaskID(); task_desc_info.stream_id = task->GetStreamId(); task_desc_info_.emplace_back(task_desc_info); @@ -2742,7 +2653,7 @@ Status DavinciModel::DistributeTask() { } void DavinciModel::SetEndGraphId(uint32_t task_id, uint32_t stream_id) { - auto all_dump_model = 
GetDumpProperties().GetAllDumpModel(); + auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); bool findByOmName = all_dump_model.find(om_name_) != all_dump_model.end(); bool findByModelName = all_dump_model.find(name_) != all_dump_model.end(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || findByOmName || findByModelName) { @@ -2779,27 +2690,12 @@ void DavinciModel::SetOutputOutsideAddr(const std::vector &outside_addrs if (output_outside_addrs_.find(addr) != output_outside_addrs_.end()) { continue; } - DisableZeroCopy(addr); // Data to NetOutput directly. - output_outside_addrs_.emplace(std::pair>(addr, {})); + (void)output_outside_addrs_.emplace(std::pair>(addr, {})); GELOGI("SetOutputOutsideAddr success."); } } -/// -/// @ingroup ge -/// @brief Set copy only for No task feed NetOutput address. -/// @return None. -/// -void DavinciModel::SetCopyOnlyOutput() { - for (const auto &addrs : output_outside_addrs_) { - const auto &used_list = addrs.second; - if (used_list.empty()) { // No task feed Output addr, Need copy directly. - copy_only_addrs_.insert(addrs.first); - } - } -} - /// /// @ingroup ge /// @brief Set disabled input zero copy addr. @@ -2807,8 +2703,8 @@ void DavinciModel::SetCopyOnlyOutput() { /// @return None. /// void DavinciModel::DisableZeroCopy(const void *addr) { - if ((input_outside_addrs_.find(addr) == input_outside_addrs_.end()) && - (output_outside_addrs_.find(addr) == output_outside_addrs_.end())) { + auto it = input_outside_addrs_.find(addr); + if (it == input_outside_addrs_.end()) { return; } @@ -2822,10 +2718,7 @@ void DavinciModel::DisableZeroCopy(const void *addr) { /// @brief Save outside address used info for ZeroCopy. 
/// @param [in] const OpDescPtr &op_desc: current op desc /// @param [in] const std::vector &outside_addrs: address of task -/// @param [in] const void *info: task args -/// @param [in] const char *args: task args -/// @param [in] size_t size: size of task args -/// @param [in] size_t offset: offset of task args +/// @param [in] const char *args_offset: arguments address save the address. /// @return None. /// void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, @@ -2901,7 +2794,7 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 if (input_size > op_size) { GELOGW( - "Input size [%u] is bigger than om size need [%u], " + "Input size [%u] is bigger than om size need [%u]," "MAY cause inference result ERROR, please check model input", input_size, op_size); } @@ -2995,7 +2888,7 @@ Status DavinciModel::UpdateIoTaskArgs(const map> return FAILED; } - GELOGI("[ZCPY] Copy Blobs: %u, addr: %p, size: %ld, data: %p, length: %lu.", data.first, data.second.second, + GELOGI("[ZCPY] Copy Blobs: %u, addr: %p, size: %ld, data: %p, length: %u.", data.first, data.second.second, data.second.first, buffer.data, buffer.length); if (!CheckInputAndModelSize(buffer.length, size, is_dynamic)) { GELOGE(FAILED, "Check input size and model size failed"); @@ -3003,11 +2896,15 @@ Status DavinciModel::UpdateIoTaskArgs(const map> } // For input data, just copy for rts task. 
- if (is_input && copy_only_addrs_.count(addr) > 0) { - if (rtMemcpy(addr, size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { - GELOGE(FAILED, "Non-zero copy data node copy failed"); - return FAILED; + if (copy_only_addrs_.count(addr) > 0) { + if (is_input) { + GELOGI("[IMAS] Find addr %p need direct copy from user malloc input %p.", addr, buffer.data); + if (rtMemcpy(addr, size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { + GELOGE(FAILED, "Non-zero copy data node copy failed"); + return FAILED; + } } + GELOGI("No need to exeucte zero copy task because this addr %p need direct copy.", addr); continue; } @@ -3263,24 +3160,6 @@ Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); } - batch_info_.clear(); - uint32_t batch_num = 0; - if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); - return FAILED; - } - - for (uint32_t i = 0; i < batch_num; i++) { - std::vector batch_shape; - const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); - if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); - batch_info_.clear(); - return FAILED; - } - batch_info_.emplace_back(batch_shape); - } - return SUCCESS; } @@ -3299,6 +3178,20 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } +void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { + if (!is_dynamic_batch) { + zero_copy_batch_label_addrs_.clear(); + } + + for (const auto &addrs : output_outside_addrs_) { + const auto &used_list = addrs.second; + if (used_list.empty()) { + output_zero_copy = false; + 
break; + } + } +} + /// /// @ingroup ge /// @brief Init model stream for NN model. @@ -3346,12 +3239,13 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); - if (!input_data.is_dynamic_batch) { - zero_copy_batch_label_addrs_.clear(); - } + bool input_use_zero_copy = true; + bool output_use_zero_copy = true; + bool is_dynamic_batch = input_data.is_dynamic_batch; + InitZeroCopyUtil(is_dynamic_batch, input_use_zero_copy, output_use_zero_copy); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_PRE_PROC_START)); - Status ret = CopyModelData(input_data, output_data, input_data.is_dynamic_batch); + Status ret = CopyModelData(input_data, output_data, is_dynamic_batch); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy input data to model failed."); GELOGI("current_data.index=%u", input_data.index); @@ -3368,7 +3262,7 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa if (!is_async_mode_) { GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); - ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); + ret = CopyOutputData(input_data.index, output_data); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); } @@ -3379,60 +3273,11 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa return SUCCESS; } -// Add active entry stream for special env. 
-Status DavinciModel::AddHeadStream() { - if (active_stream_list_.empty()) { - GELOGE(INTERNAL_ERROR, "Active stream is empty, stream list size: %zu, stream indication size: %zu.", - stream_list_.size(), active_stream_indication_.size()); - return INTERNAL_ERROR; - } - - if (active_stream_list_.size() == 1) { - GELOGI("Just one active stream, take as head stream."); - rt_head_stream_ = active_stream_list_[0]; - is_pure_head_stream_ = false; - } else { - // Create stream which rt_model_handel running on, this is S0, TS stream. - GELOGI("Multiple active stream: %zu, create head stream.", active_stream_list_.size()); - GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_head_stream_, priority_, RT_STREAM_PERSISTENT)); - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_head_stream_, RT_INVALID_FLAG)); // Not active. - is_pure_head_stream_ = true; - - for (auto s : active_stream_list_) { - std::shared_ptr active_entry = MakeShared(rt_head_stream_); - if (active_entry == nullptr) { - GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); - return FAILED; - } - - if (active_entry->Init(s) != SUCCESS) { - return FAILED; - } - - cpu_task_list_.emplace_back(active_entry); - } - } - - // Create entry stream active head stream. AICPU stream. 
- GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_entry_stream_, priority_, RT_STREAM_AICPU)); - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_entry_stream_, RT_HEAD_STREAM)); - return SUCCESS; -} - -Status DavinciModel::InitEntryTask() { - if (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD) { - GE_CHK_STATUS_RET(AddHeadStream(), "Add head stream failed."); - return CpuActiveStream(); - } else { - return LoadWithQueue(); - } -} - uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { uint8_t *mem_base = nullptr; const string purpose("feature map,used for op input and output."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { - data_size = static_cast(VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()); + data_size = static_cast(VarManager::Instance(0)->GetGraphMemoryMaxSize()); string memory_key = std::to_string(0) + "_f"; mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); } else { @@ -3517,14 +3362,12 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) return SUCCESS; } -void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { +void DavinciModel::SetDataDumperArgs() { GELOGI("set data dumper args, name: %s, id: %u.", name_.c_str(), model_id_); data_dumper_.SetModelName(name_); data_dumper_.SetModelId(model_id_); data_dumper_.SetMemory(runtime_param_); data_dumper_.SetOmName(om_name_); - data_dumper_.SetComputeGraph(compute_graph); - data_dumper_.SetRefInfo(saved_task_addrs_); int32_t device_id = 0; rtError_t rt_ret = rtGetDevice(&device_id); @@ -3580,9 +3423,18 @@ void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { } } -Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info) { +Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { + return TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_); +} + +Status 
DavinciModel::GetComputeGraphInfo(std::vector &compute_graph_desc_info) { GELOGI("GetComputeGraphInfo start."); - for (auto &node : graph->GetAllNodes()) { + if (compute_graph_ == nullptr) { + GELOGE(FAILED, "compute_graph_ is nullptr"); + return FAILED; + } + + for (auto &node : compute_graph_->GetAllNodes()) { ComputeGraphDescInfo compute_graph_info; auto op_desc = node->GetOpDesc(); if (op_desc == nullptr) { @@ -3593,11 +3445,6 @@ Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vector(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, op_mode) && op_mode == static_cast(domi::ImplyType::TVM)) { - if (!om_name_.empty()) { - compute_graph_info.model_name = om_name_; - } else { - compute_graph_info.model_name = name_; - } compute_graph_info.op_name = op_desc->GetName(); compute_graph_info.op_type = op_desc->GetType(); @@ -3615,18 +3462,12 @@ Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vectorHasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR) && op_desc->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX)) { - string tensor_name; - (void)AttrUtils::GetStr(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); - int64_t index = -1; - (void)AttrUtils::GetInt(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, index); - if (index >= 0) { - tensor_name_to_peer_output_index_[tensor_name] = index; - } - } -} } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index 0f0b1e5c..8123b0b8 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -29,7 +29,6 @@ #include "common/helper/om_file_helper.h" #include "common/opskernel/ge_task_info.h" #include "common/types.h" -#include "common/properties_manager.h" #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/data_dumper.h" @@ -48,10 +47,6 @@ #include 
"task_info/task_info.h" namespace ge { -// op debug need 2048 bits buffer -const size_t kOpDebugMemorySize = 2048UL; -const size_t kDebugP2pSize = 8UL; - typedef enum tagModelProcStage { MODEL_LOAD_START = 1, MODEL_LOAD_END, @@ -176,6 +171,13 @@ class DavinciModel { // get session id uint64_t SessionId() const { return runtime_param_.session_id; } + vector GetOpDesc() { + vector opDescVector; + GE_IF_BOOL_EXEC(AttrUtils::GetListOpDesc(GetGeModel(), MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), + GELOGI("get opDesc of opDescVector")); + return opDescVector; + } + // get model priority int32_t Priority() const { return priority_; } @@ -246,9 +248,15 @@ class DavinciModel { /// Format GetFormat(); - rtModel_t GetRtModelHandle() const { return rt_model_handle_; } + rtModel_t GetRtModelHandle() { + rtModel_t res = rt_model_handle_; + return res; + } - rtStream_t GetRtModelStream() const { return rt_model_stream_; } + rtStream_t GetRtModelStream() { + rtModel_t res = rt_model_stream_; + return res; + } uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; } @@ -287,7 +295,7 @@ class DavinciModel { /// @param [out] batch_info /// @return execute result /// - Status GetDynamicBatchInfo(std::vector> &batch_info) const; + Status GetDynamicBatchInfo(std::vector> &batch_info); void GetCurShape(std::vector &batch_info); @@ -336,9 +344,10 @@ class DavinciModel { /// /// @ingroup ge /// @brief dump all op input and output information - /// @return void + /// @param [in] op_list model_id + /// @return Status /// - void DumpOpInputOutput(); + Status DumpOpInputOutput(); /// /// @ingroup ge @@ -394,9 +403,7 @@ class DavinciModel { /// uint32_t GetDeviceId() const { return device_id_; } - bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; } - - Status UpdateSessionId(uint64_t session_id); + GeModelPtr GetGeModel() { return ge_model_; } const RuntimeParam &GetRuntimeParam() { return runtime_param_; } @@ -456,19 +463,6 @@ class DavinciModel { void 
*cur_args = static_cast(args_) + offset; return cur_args; } - void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size); - int64_t GetFixedAddrsSize(string tensor_name); - void *GetCurrentFixedAddr(int64_t offset) const { - void *cur_addr = static_cast(fixed_addrs_) + offset; - return cur_addr; - } - - uint32_t GetFixedAddrOutputIndex(string tensor_name) { - if (tensor_name_to_peer_output_index_.find(tensor_name) != tensor_name_to_peer_output_index_.end()) { - return tensor_name_to_peer_output_index_[tensor_name]; - } - return UINT32_MAX; - } void SetKnownNode(bool known_node) { known_node_ = known_node; } bool IsKnownNode() { return known_node_; } Status MallocKnownArgs(); @@ -483,9 +477,6 @@ class DavinciModel { // om file name void SetOmName(string om_name) { om_name_ = om_name; } - void SetDumpProperties(const DumpProperties &dump_properties) { data_dumper_.SetDumpProperties(dump_properties); } - const DumpProperties &GetDumpProperties() const { return data_dumper_.GetDumpProperties(); } - private: // memory address of weights uint8_t *weights_mem_base_; @@ -502,6 +493,8 @@ class DavinciModel { struct timeInfo time_info_; int32_t dataInputTid; + void InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy); + /// /// @ingroup ge /// @brief Save Batch label Info. @@ -537,13 +530,6 @@ class DavinciModel { /// bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic); - /// - /// @ingroup ge - /// @brief Set copy only for No task feed NetOutput address. - /// @return None. - /// - void SetCopyOnlyOutput(); - /// /// @ingroup ge /// @brief Copy Input/Output to model for direct use. 
@@ -569,10 +555,14 @@ class DavinciModel { Status CopyInputData(const InputData &input_data, bool device_data = false); - Status CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind); + Status CopyOutputData(uint32_t data_id, OutputData &output_data); + + Status CopyOutputDataToUser(OpDescPtr &op_desc, std::vector &blobs, uint32_t &data_index); Status SyncVarData(); + Status SyncDataAndDump(); + Status InitModelMem(void *dev_ptr, size_t memsize, void *weight_ptr, size_t weightsize); void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input); @@ -599,12 +589,7 @@ class DavinciModel { bool IsAicpuKernelConnectSpecifiedLayer(); - /// - /// @ingroup ge - /// @brief Reduce memory usage after task sink. - /// @return: void - /// - void Shrink(); + Status MarkSpecifiedAicpuKernel(); /// /// @ingroup ge @@ -740,9 +725,10 @@ class DavinciModel { /// /// @ingroup ge /// @brief definiteness queue schedule, active original model stream. + /// @param [in] streams: streams will active by S0. /// @return: 0 for success / others for fail /// - Status CpuActiveStream(); + Status CpuActiveStream(const std::vector &stream_list); /// /// @ingroup ge @@ -760,9 +746,6 @@ class DavinciModel { /// Status CpuModelRepeat(); - Status InitEntryTask(); - Status AddHeadStream(); - /// /// @ingroup ge /// @brief set ts device. 
@@ -770,10 +753,6 @@ class DavinciModel { /// Status SetTSDevice(); - Status OpDebugRegister(); - - void OpDebugUnRegister(); - void CheckHasHcomOp(); Status DoTaskSink(); @@ -781,17 +760,17 @@ class DavinciModel { void CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputDescInfo &output, uint32_t &format_result); Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); + Status CopyVarData(ComputeGraphPtr &graph); // get desc info of graph for profiling - Status GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info); + Status GetComputeGraphInfo(vector &compute_graph_desc_info); - void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); + void SetDataDumperArgs(); Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, std::vector &outputs); void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info); - void GetFixedAddrAttr(const OpDescPtr &op_desc); bool is_model_has_inited_; uint32_t model_id_; @@ -804,9 +783,6 @@ class DavinciModel { uint32_t version_; GeModelPtr ge_model_; - bool need_destroy_aicpu_kernel_{false}; - vector out_node_name_; - map op_list_; // data op_desc @@ -867,11 +843,6 @@ class DavinciModel { bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. - bool is_pure_head_stream_{false}; - rtStream_t rt_head_stream_{nullptr}; - rtStream_t rt_entry_stream_{nullptr}; - rtAicpuDeployType_t deploy_type_{AICPU_DEPLOY_RESERVED}; - // ACL queue schedule, save queue ids for Init. std::vector cpu_task_list_; std::vector input_queue_ids_; // input queue ids created by caller. 
@@ -893,6 +864,8 @@ class DavinciModel { std::vector active_stream_list_; std::set active_stream_indication_; + std::shared_ptr model_task_def_; + std::set aicpu_streams_; std::set hcom_streams_; RuntimeParam runtime_param_; @@ -904,39 +877,22 @@ class DavinciModel { // for profiling task and graph info std::map op_name_map_; std::vector task_desc_info_; + ComputeGraphPtr compute_graph_; int64_t maxDumpOpNum_; // for data dump DataDumper data_dumper_; uint64_t iterator_count_; bool is_l1_fusion_enable_; - std::map saved_task_addrs_; bool known_node_ = false; uint32_t total_args_size_ = 0; void *args_ = nullptr; void *args_host_ = nullptr; - void *fixed_addrs_ = nullptr; - int64_t total_fixed_addr_size_ = 0; std::map knonw_input_data_info_; std::map knonw_output_data_info_; - vector> batch_info_; - vector batch_size_; - // key: input tensor name, generally rts op; - // value: the fixed addr of input anchor, same as the peer output anchor addr of the peer op - std::map tensor_name_to_fixed_addr_size_; - - // key: input tensor name, generally rts op; value: the peer output anchor of the peer op - std::map tensor_name_to_peer_output_index_; - // if model is first execute - bool is_first_execute_; - // for op debug - std::mutex debug_reg_mutex_; - bool is_op_debug_reg_ = false; - void *op_debug_addr_ = nullptr; - void *p2p_debug_addr_ = nullptr; bool is_new_model_desc_{false}; }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index 04c836dd..d98ad8de 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -22,9 +22,8 @@ #include "common/profiling/profiling_manager.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" -#include "framework/common/util.h" -#include "graph/common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" +#include "framework/common/util.h" #include 
"graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/davinci_model_parser.h" #include "model/ge_root_model.h" @@ -34,10 +33,9 @@ thread_local uint32_t device_count = 0; namespace { const int kCmdParSize = 2; const int kDumpCmdPairSize = 2; +const char *const kNeedDestroySpecifiedAicpuKernel = "need_destroy_specified_aicpu_kernel"; } // namespace -DumpProperties ModelManager::dump_properties_; - std::shared_ptr ModelManager::GetInstance() { static const std::shared_ptr instance_ptr = shared_ptr(new (std::nothrow) ModelManager(), ModelManager::FinalizeForPtr); @@ -274,10 +272,6 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetId(model_id); davinci_model->SetDeviceId(GetContext().DeviceId()); - const DumpProperties &dump_properties = PropertiesManager::Instance().GetDumpProperties(GetContext().SessionId()); - davinci_model->SetDumpProperties(dump_properties); - dump_properties_ = dump_properties; - auto root_graph = ge_root_model->GetRootGraph(); GE_CHECK_NOTNULL(root_graph); string root_model_name = root_graph->GetName(); @@ -302,6 +296,9 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); + if (davinci_model->SinkModelProfile() != SUCCESS) { + GELOGW("Sink model profile failed."); + } } } while (0); @@ -614,10 +611,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { GELOGE(PARAM_INVALID, "parser dump model failed"); return FAILED; } - GELOGI("dump model = %s.", dump_model.c_str()); + GELOGI("dump status = %s.", dump_model.c_str()); if (dump_status == "off" || dump_status == "OFF") { - dump_properties_.DeletePropertyValue(dump_model); + PropertiesManager::Instance().DeleteDumpPropertyValue(dump_model); return SUCCESS; } @@ -634,10 +631,9 @@ Status ModelManager::HandleDumpCommand(const 
Command &command) { return FAILED; } if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path = dump_path + "/"; + dump_path = dump_path + "/" + CurrentTimeInStr() + "/"; } - dump_path = dump_path + CurrentTimeInStr() + "/"; - GELOGI("dump path = %s.", dump_path.c_str()); + GELOGI("dump status = %s.", dump_path.c_str()); ret = ParserPara(command, DUMP_MODE, dump_mode); if (ret != SUCCESS) { @@ -646,10 +642,20 @@ Status ModelManager::HandleDumpCommand(const Command &command) { } GELOGI("dump mode = %s", dump_mode.c_str()); - dump_properties_.AddPropertyValue(dump_model, dump_layers); - dump_properties_.SetDumpPath(dump_path); - dump_properties_.SetDumpMode(dump_mode); + auto iter_dump_mode = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_MODE); + if (iter_dump_mode != command.cmd_params.end()) { + ++iter_dump_mode; + if (iter_dump_mode == command.cmd_params.end()) { + GELOGE(PARAM_INVALID, "Invalid access."); + return PARAM_INVALID; + } + dump_mode = *iter_dump_mode; + GELOGI("dump mode = %s", dump_mode.c_str()); + } + PropertiesManager::Instance().AddDumpPropertyValue(dump_model, dump_layers); + PropertiesManager::Instance().SetDumpOutputPath(dump_path); + PropertiesManager::Instance().SetDumpMode(dump_mode); return SUCCESS; } @@ -765,6 +771,17 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { return SUCCESS; } +Status ModelManager::UpdateSessionId(std::shared_ptr &davinci_model, uint64_t session_id) { + GeModelPtr ge_model_current = davinci_model->GetGeModel(); + GE_CHECK_NOTNULL(ge_model_current); + if (!ge::AttrUtils::SetInt(ge_model_current, ge::MODEL_ATTR_SESSION_ID, static_cast(session_id))) { + GELOGW("Set attr[%s] failed in updating session_id.", MODEL_ATTR_SESSION_ID.c_str()); + } + + GELOGD("Update session id: %lu.", session_id); + return SUCCESS; +} + Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr listener, void *dev_ptr, size_t mem_size, void 
*weight_ptr, size_t weight_size) { GE_CHK_BOOL_RET_STATUS(model.key.empty() || access(model.key.c_str(), F_OK) == 0, PARAM_INVALID, @@ -807,7 +824,6 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model } davinci_model->SetDeviceId(device_id); davinci_model->SetOmName(model.om_name); - davinci_model->SetDumpProperties(dump_properties_); /// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. /// These session_ids come from the same model, so the values of session_id are the same. @@ -815,7 +831,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model uint64_t new_session_id; ret = GenSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "Generate session_id for infer failed."); - ret = davinci_model->UpdateSessionId(new_session_id); + ret = UpdateSessionId(davinci_model, new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "Update session_id for infer failed."); ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); @@ -830,6 +846,9 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); + if (davinci_model->SinkModelProfile() != SUCCESS) { + GELOGW("Sink model profile failed."); + } } GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); @@ -879,7 +898,7 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d uint64_t new_session_id; ret = GenSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); - ret = davinci_model->UpdateSessionId(new_session_id); + ret = UpdateSessionId(davinci_model, new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update 
session_id for infer failed."); GenModelId(&model_id); @@ -890,8 +909,6 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d return ret; } - davinci_model->SetDumpProperties(dump_properties_); - ret = davinci_model->Init(); if (ret != SUCCESS) { GELOGE(ret, "init model failed."); @@ -918,8 +935,12 @@ Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asy std::shared_ptr davinci_model = GetModel(model_id); GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); - if (davinci_model->NeedDestroyAicpuKernel()) { - GELOGI("Start to destroy specified aicpu kernel."); + GeModelPtr ge_model_current = davinci_model->GetGeModel(); + bool need_destroy_aicpu_kernel = false; + bool result = ge::AttrUtils::GetBool(ge_model_current, kNeedDestroySpecifiedAicpuKernel, need_destroy_aicpu_kernel); + if (result && need_destroy_aicpu_kernel) { + GELOGI("Get attr %s successfully, start to destroy specified aicpu kernel.", kNeedDestroySpecifiedAicpuKernel); + // Zero copy is enabled by default, no need to judge. uint64_t session_id_davinci = davinci_model->GetSessionId(); uint32_t model_id_davinci = davinci_model->GetModelId(); @@ -1029,19 +1050,4 @@ Status ModelManager::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index return davinci_model->GetAllAippInputOutputDims(index, input_dims, output_dims); } -bool ModelManager::IsDynamicShape(uint32_t model_id) { - auto model = GetHybridModel(model_id); - return model != nullptr; -} - -ge::Status ModelManager::SyncExecuteModel(uint32_t model_id, const vector &inputs, - vector &outputs) { - auto model = GetHybridModel(model_id); - if (model == nullptr) { - GELOGE(FAILED, "Hybrid model not found. 
model id = %u.", model_id); - return FAILED; - } - - return model->Execute(inputs, outputs); -} } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index 2ba23d7c..8e2424bf 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -31,7 +31,6 @@ #include "common/ge_types.h" #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" -#include "common/properties_manager.h" #include "common/types.h" #include "ge/ge_api_types.h" #include "graph/ge_context.h" @@ -142,8 +141,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status ExecuteModel(uint32_t model_id, rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data); - ge::Status SyncExecuteModel(uint32_t model_id, const std::vector &inputs, std::vector &outputs); - /// /// @ingroup domi_ome /// @brief model stop @@ -252,8 +249,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, std::vector &output_dims); - bool IsDynamicShape(uint32_t model_id); - private: /// /// @ingroup domi_ome @@ -281,6 +276,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status DeleteModel(uint32_t id); void GenModelId(uint32_t *id); + ge::Status UpdateSessionId(std::shared_ptr &davinci_model, uint64_t session_id); std::map> model_map_; std::map> hybrid_model_map_; @@ -291,8 +287,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::mutex session_id_create_mutex_; uint64_t session_id_bias_; std::set sess_ids_; - - static DumpProperties dump_properties_; }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc index bd684b9d..a807f2a3 100644 --- 
a/src/ge/graph/load/new_model_manager/model_utils.cc +++ b/src/ge/graph/load/new_model_manager/model_utils.cc @@ -31,7 +31,7 @@ namespace ge { /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get input size. /// @return vector /// @@ -43,26 +43,22 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { - const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - - int64_t tensor_size = 0; if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights size to input - GE_CHK_STATUS(TensorUtils::GetSize(*tensor_desc, tensor_size)); + GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); if (tensor_size) { v_input_size.push_back(tensor_size); } continue; } + int64_t tensor_size = 0; GE_IF_BOOL_EXEC( - TensorUtils::GetSize(*tensor_desc, tensor_size) != GRAPH_SUCCESS, + TensorUtils::GetSize(op_desc->GetInputDesc(i), tensor_size) != GRAPH_SUCCESS, GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - continue); + continue;); v_input_size.push_back(tensor_size); } @@ -71,7 +67,7 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get output size. 
/// @return vector /// @@ -86,17 +82,11 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { return v_output_size;); for (size_t i = 0; i < outputs_size; ++i) { - const GeTensorDescPtr tensor_desc = op_desc->MutableOutputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - int64_t tensor_size = 0; GE_IF_BOOL_EXEC( - TensorUtils::GetSize(*tensor_desc, tensor_size) != GRAPH_SUCCESS, + TensorUtils::GetSize(op_desc->GetOutputDesc(i), tensor_size) != GRAPH_SUCCESS, GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); - continue); + continue;); v_output_size.push_back(tensor_size); } @@ -105,7 +95,7 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get workspace size. /// @return vector /// @@ -128,7 +118,7 @@ vector ModelUtils::GetWorkspaceSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get weight size. /// @return vector /// @@ -152,14 +142,8 @@ vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i]) { - const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - int64_t tensor_size = 0; - (void)TensorUtils::GetSize(*tensor_desc, tensor_size); + (void)TensorUtils::GetSize(op_desc->GetInputDesc(i), tensor_size); v_weight_size.push_back(tensor_size); } } @@ -168,7 +152,7 @@ vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get weights. 
/// @return vector /// @@ -192,14 +176,9 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i]) { - const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - ConstGeTensorPtr weight = nullptr; - if (AttrUtils::GetTensor(*tensor_desc, ATTR_NAME_WEIGHTS, weight)) { + GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); + if (AttrUtils::GetTensor(tensor_desc, ATTR_NAME_WEIGHTS, weight)) { v_weights.push_back(weight); } } @@ -209,7 +188,7 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get AiCpuOp Input descriptor. /// @return vector<::tagCcAICPUTensor> /// @@ -226,25 +205,20 @@ vector<::tagCcAICPUTensor> ModelUtils::GetInputDescs(ConstOpDescPtr op_desc) { continue; } - const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - uint32_t dim_cnt = 0; - GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(*tensor_desc, dim_cnt) == GRAPH_SUCCESS, continue, + const auto &descriptor = op_desc->GetInputDesc(i); + GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(descriptor, dim_cnt) == GRAPH_SUCCESS, continue, "Get dim_cnt failed"); opTensor_t tmp; - uint32_t tmp_fmt = tensor_desc->GetFormat(); + uint32_t tmp_fmt = descriptor.GetFormat(); tmp.format = tagOpTensorFormat(tmp_fmt); tmp.dim_cnt = static_cast(dim_cnt); - uint32_t tmp_type = tensor_desc->GetDataType(); + uint32_t tmp_type = descriptor.GetDataType(); tmp.data_type = tagOpDataType(tmp_type); for (int32_t j = 0; j < 4; j++) { // 4 dims - tmp.dim[j] = (j < tmp.dim_cnt ? 
tensor_desc->GetShape().GetDim(j) : 1); + tmp.dim[j] = (j < tmp.dim_cnt ? descriptor.GetShape().GetDim(j) : 1); } v_input_descs.push_back(tmp); @@ -254,7 +228,7 @@ vector<::tagCcAICPUTensor> ModelUtils::GetInputDescs(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get AiCpuOp Output descriptor. /// @return vector<::tagCcAICPUTensor> /// @@ -266,25 +240,20 @@ vector<::tagCcAICPUTensor> ModelUtils::GetOutputDescs(ConstOpDescPtr op_desc) { // init op output opTensor_t struct const size_t output_num = op_desc->GetOutputsSize(); for (size_t i = 0; i < output_num; ++i) { - const GeTensorDescPtr tensor_desc = op_desc->MutableOutputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - uint32_t dim_cnt = 0; - GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(*tensor_desc, dim_cnt) == GRAPH_SUCCESS, continue, + const auto &descriptor = op_desc->GetOutputDesc(i); + GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(descriptor, dim_cnt) == GRAPH_SUCCESS, continue, "Get dim_cnt failed"); opTensor_t tmp; - uint32_t tmp_fmt = tensor_desc->GetFormat(); + uint32_t tmp_fmt = descriptor.GetFormat(); tmp.format = tagOpTensorFormat(tmp_fmt); tmp.dim_cnt = static_cast(dim_cnt); - uint32_t tmp_type = tensor_desc->GetDataType(); + uint32_t tmp_type = descriptor.GetDataType(); tmp.data_type = tagOpDataType(tmp_type); for (int32_t j = 0; j < 4; j++) { // 4 dims - tmp.dim[j] = (j < tmp.dim_cnt ? tensor_desc->GetShape().GetDim(j) : 1); + tmp.dim[j] = (j < tmp.dim_cnt ? descriptor.GetShape().GetDim(j) : 1); } v_output_descs.push_back(tmp); @@ -294,14 +263,44 @@ vector<::tagCcAICPUTensor> ModelUtils::GetOutputDescs(ConstOpDescPtr op_desc) { } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get input data address. 
/// @return vector /// -vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { +vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, + bool need_convert) { vector v_input_data_addr; // init as:buf_base + op_def_->input(i)); GE_CHECK_NOTNULL_EXEC(op_desc, return v_input_data_addr); uint64_t session_id = model_param.session_id; + uint8_t *mem_base = model_param.mem_base; + uint8_t *var_base = model_param.var_base; + uint8_t *weight_base = model_param.weight_base; + const uint64_t logic_mem_base = 0; + uint64_t logic_weight_base = 0; + uint64_t logic_var_base = model_param.logic_var_base; + uint64_t mem_size = model_param.mem_size; + uint64_t weight_size = model_param.weight_size; + uint64_t var_size = model_param.var_size; + + if (need_convert) { + Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); + if (status != SUCCESS) { + GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); + return v_input_data_addr; + } + + status = ConvertVirtualAddressToPhysical(weight_base, weight_size, weight_base); + if (status != SUCCESS) { + GELOGE(RT_FAILED, "Convert virtual address to physical for weight_base failed."); + return v_input_data_addr; + } + + status = ConvertVirtualAddressToPhysical(var_base, var_size, var_base); + if (status != SUCCESS) { + GELOGE(RT_FAILED, "Convert virtual address to physical for var_base failed."); + return v_input_data_addr; + } + } const size_t inputs_size = op_desc->GetInputsSize(); const vector v_input_offset = op_desc->GetInputOffset(); @@ -320,18 +319,13 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights address to input - const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); - if (tensor_desc == nullptr) { - GELOGW("Op: %s, 
Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); - continue; - } - + GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(*tensor_desc, tensor_size)); + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); if (tensor_size) { int64_t data_offset = 0; - GE_CHK_STATUS(TensorUtils::GetDataOffset(*tensor_desc, data_offset)); - uint8_t *weight_addr = model_param.weight_base + data_offset; + GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); + uint8_t *weight_addr = static_cast(weight_base + data_offset - logic_weight_base); v_input_data_addr.push_back(weight_addr); GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, weight_addr); @@ -346,13 +340,17 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co int64_t input_offset = v_input_offset[non_const_index]; non_const_index++; - GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), - uint8_t *variable_addr = model_param.var_base + input_offset - model_param.logic_var_base; + GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), + uint8_t *variable_addr = var_base + input_offset - logic_var_base; v_input_data_addr.push_back(variable_addr); GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); - continue); + continue;); + bool input_tensor = false; + GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, + GELOGW("get size from TensorDesc failed, op: %s, input index: %zu", op_desc->GetName().c_str(), i); + continue;); // feature maps uint8_t *mem_addr = nullptr; // fusion @@ -360,7 +358,7 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co mem_addr = 
reinterpret_cast(reinterpret_cast(input_offset)); v_input_data_addr.push_back(mem_addr); } else { - mem_addr = model_param.mem_base + input_offset; + mem_addr = static_cast(mem_base + input_offset - logic_mem_base); v_input_data_addr.push_back(mem_addr); } GELOGI("[IMAS]GetInputDataAddrs graph_%u type[F] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, @@ -371,20 +369,41 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get output data address. /// @return vector /// -vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { +vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, + bool need_convert) { vector v_output_data_addr; // init as:buf_base + op_def_->output(i) GE_CHECK_NOTNULL_EXEC(op_desc, return v_output_data_addr); uint64_t session_id = model_param.session_id; + uint8_t *mem_base = model_param.mem_base; + uint8_t *var_base = model_param.var_base; + const uint64_t logic_mem_base = 0; + uint64_t logic_var_base = model_param.logic_var_base; + uint64_t mem_size = model_param.mem_size; + uint64_t var_size = model_param.var_size; + + if (need_convert) { + Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); + if (status != SUCCESS) { + GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); + return v_output_data_addr; + } + + status = ConvertVirtualAddressToPhysical(var_base, var_size, var_base); + if (status != SUCCESS) { + GELOGE(RT_FAILED, "Convert virtual address to physical for var_base failed."); + return v_output_data_addr; + } + } const size_t outputs_size = op_desc->GetOutputsSize(); const vector v_output_offset = op_desc->GetOutputOffset(); GE_IF_BOOL_EXEC(v_output_offset.size() != outputs_size, GELOGW("Output param invalid: output_offset=%zu, outputs=%zu.", v_output_offset.size(), outputs_size); - return v_output_data_addr); + return 
v_output_data_addr;); vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); if (has_mem_type_attr && (v_memory_type.size() != outputs_size)) { @@ -394,12 +413,12 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C return v_output_data_addr; } for (size_t i = 0; i < outputs_size; ++i) { - GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]), - uint8_t *variable_addr = model_param.var_base + v_output_offset[i] - model_param.logic_var_base; + GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]), + uint8_t *variable_addr = static_cast(var_base + v_output_offset[i] - logic_var_base); v_output_data_addr.push_back(variable_addr); GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[V] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); - continue); + continue;); // feature maps uint8_t *mem_addr = nullptr; // fusion @@ -407,7 +426,7 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C mem_addr = reinterpret_cast(reinterpret_cast(v_output_offset[i])); v_output_data_addr.push_back(mem_addr); } else { - mem_addr = static_cast(model_param.mem_base + v_output_offset[i]); + mem_addr = static_cast(mem_base + v_output_offset[i] - logic_mem_base); v_output_data_addr.push_back(mem_addr); } GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[F] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, @@ -417,13 +436,24 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C } /// -/// @ingroup ge +/// @ingroup domi_ome /// @brief Get workspace data address. 
/// @return vector /// -vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { +vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, + bool need_convert) { vector v_workspace_data_addr; GE_CHECK_NOTNULL_EXEC(op_desc, return v_workspace_data_addr); + uint8_t *mem_base = model_param.mem_base; + uint64_t mem_size = model_param.mem_size; + + if (need_convert) { + Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); + if (status != SUCCESS) { + GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); + return v_workspace_data_addr; + } + } const vector v_workspace_offset = op_desc->GetWorkspace(); const vector v_workspace_bytes = op_desc->GetWorkspaceBytes(); @@ -436,13 +466,13 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, v_memory_type); for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { - v_workspace_data_addr.push_back(reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); + v_workspace_data_addr.push_back(reinterpret_cast(v_workspace_offset[i])); GELOGI("Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); } else { int64_t workspace_offset = v_workspace_offset[i]; int64_t workspace_bytes = v_workspace_bytes[i]; - uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : model_param.mem_base + workspace_offset; + uint8_t *mem_addr = workspace_bytes == 0 ? 
nullptr : mem_base + workspace_offset; v_workspace_data_addr.push_back(mem_addr); GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, workspace_offset, workspace_bytes, mem_addr); @@ -452,32 +482,21 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param return v_workspace_data_addr; } -/// -/// @ingroup ge -/// @brief Get runtime memory address. -/// @return Status -/// -Status ModelUtils::GetRtAddress(const RuntimeParam ¶m, uintptr_t logic_addr, uint8_t *&mem_addr) { - uint8_t *runtime_base_addr = nullptr; - if ((param.logic_mem_base <= logic_addr) && (logic_addr < param.logic_mem_base + param.mem_size)) { - runtime_base_addr = param.mem_base - param.logic_mem_base; - GELOGI("The logic addr:0x%lx is data address, base:0x%lx, size:%lu", logic_addr, param.logic_mem_base, - param.mem_size); - } else if ((param.logic_weight_base <= logic_addr) && (logic_addr < param.logic_weight_base + param.weight_size)) { - runtime_base_addr = param.weight_base - param.logic_weight_base; - GELOGI("The logic addr:0x%lx is weight address, base:0x%lx, size:%lu", logic_addr, param.logic_weight_base, - param.weight_size); - } else if ((param.logic_var_base <= logic_addr) && (logic_addr < param.logic_var_base + param.var_size)) { - runtime_base_addr = param.var_base - param.logic_var_base; - GELOGI("The logic addr:0x%lx is variable address, base:0x%lx, size:%lu", logic_addr, param.logic_var_base, - param.var_size); - } else if (logic_addr != 0) { - mem_addr = nullptr; - GELOGE(PARAM_INVALID, "The logic addr:0x%lx is abnormal", logic_addr); - return PARAM_INVALID; +Status ModelUtils::ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uint64_t size, + uint8_t *&physical_address) { + // Indicates whether use physical address. 
+ const char *use_physical_address = std::getenv("GE_USE_PHYSICAL_ADDRESS"); + if (use_physical_address == nullptr || virtual_address == 0 || size == 0) { + return SUCCESS; + } + + rtError_t ret = rtKernelConfigTransArg(virtual_address, size, 0, reinterpret_cast(&physical_address)); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtKernelConfigTransArg failed, ret: 0x%X", ret); + return RT_FAILED; } - mem_addr = runtime_base_addr + logic_addr; + GELOGD("virtual_address=%p, physical_address=%p", virtual_address, physical_address); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.h b/src/ge/graph/load/new_model_manager/model_utils.h index 8474a987..d6afd5c8 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.h +++ b/src/ge/graph/load/new_model_manager/model_utils.h @@ -34,79 +34,78 @@ class ModelUtils { ~ModelUtils() = default; /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get input size. /// @return vector /// static vector GetInputSize(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get output size. /// @return vector /// static vector GetOutputSize(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get workspace size. /// @return vector /// static vector GetWorkspaceSize(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get weight size. /// @return vector /// static vector GetWeightSize(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get weights. /// @return vector /// static vector GetWeights(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get AiCpuOp Input descriptor. /// @return vector<::tagCcAICPUTensor> /// static vector<::tagCcAICPUTensor> GetInputDescs(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get AiCpuOp Output descriptor. 
/// @return vector<::tagCcAICPUTensor> /// static vector<::tagCcAICPUTensor> GetOutputDescs(ConstOpDescPtr op_desc); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get input data address. /// @return vector /// - static vector GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); + static vector GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, + bool need_convert = true); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get output data address. /// @return vector /// - static vector GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); + static vector GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, + bool need_convert = true); /// - /// @ingroup ge + /// @ingroup domi_ome /// @brief Get workspace data address. /// @return vector /// - static vector GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); + static vector GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, + bool need_convert = true); - /// - /// @ingroup ge - /// @brief Get memory runtime base. 
- /// @return Status - /// - static Status GetRtAddress(const RuntimeParam &model_param, uintptr_t logic_addr, uint8_t *&mem_addr); + static ge::Status ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uint64_t size, + uint8_t *&physical_address); }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc index 920b52e6..077ae827 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc @@ -45,7 +45,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin Status EndGraphTaskInfo::Distribute() { GELOGI("EndGraphTaskInfo Distribute Start."); GE_CHECK_NOTNULL(davinci_model_); - auto all_dump_model = davinci_model_->GetDumpProperties().GetAllDumpModel(); + auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || all_dump_model.find(davinci_model_->Name()) != all_dump_model.end() || all_dump_model.find(davinci_model_->OmName()) != all_dump_model.end()) { @@ -80,4 +80,5 @@ Status EndGraphTaskInfo::Distribute() { } REGISTER_TASK_INFO(RT_MODEL_TASK_MODEL_END_GRAPH, EndGraphTaskInfo); + } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h index 82e228e6..49bef082 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h @@ -22,7 +22,7 @@ namespace ge { class EndGraphTaskInfo : public TaskInfo { public: - EndGraphTaskInfo() {} + EndGraphTaskInfo() : model_(0) {} ~EndGraphTaskInfo() override { model_ = nullptr; } @@ -35,10 +35,10 @@ class EndGraphTaskInfo : public TaskInfo { uint32_t GetStreamId() override { return stream_id_; } private: 
- rtModel_t model_{nullptr}; - DavinciModel *davinci_model_{nullptr}; - uint32_t task_id_{0}; - uint32_t stream_id_{0}; + rtModel_t model_; + DavinciModel *davinci_model_; + uint32_t task_id_; + uint32_t stream_id_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_END_GRAPH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index 2a79997f..0ee9727a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -42,7 +42,6 @@ HcclTaskInfo::~HcclTaskInfo() { davinci_model_ = nullptr; ops_kernel_store_ = nullptr; max_node_of_hccl_stream_ = 0; - args_ = nullptr; } Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("HcclTaskInfo Init Start."); @@ -61,61 +60,54 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); // Get HCCL op - op_desc_ = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc_); + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); // Create the kernel hccl infos - CreateKernelHcclInfo(op_desc_); + CreateKernelHcclInfo(op_desc); // Initialize the hccl_type of all kernel hccl info HcomOmeUtil::GetHcclType(task_def, kernel_hccl_infos_); // Only in Horovod scenario should get the inputName and GeShape - ret = HcomOmeUtil::GetHorovodInputs(op_desc_, kernel_hccl_infos_); + ret = HcomOmeUtil::GetHorovodInputs(op_desc, kernel_hccl_infos_); if (ret != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHorovodInputs fail! 
domi error: %u", ret); return FAILED; } - Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc_, kernel_hccl_infos_); + Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); return FAILED; } - dmrt = HcomOmeUtil::GetHcclCount(op_desc_, kernel_hccl_infos_); + dmrt = HcomOmeUtil::GetHcclCount(op_desc, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); return FAILED; } // Only HCOMBROADCAST and HVDCALLBACKBROADCAST need to get the rootId - dmrt = HcomOmeUtil::GetAllRootId(op_desc_, kernel_hccl_infos_); + dmrt = HcomOmeUtil::GetAllRootId(op_desc, kernel_hccl_infos_); if (dmrt != SUCCESS) { GELOGE(FAILED, "davinci_model: Get rootId fail! domi error: %u", dmrt); return FAILED; } - - // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl - ret = SetFollowStream(op_desc_, davinci_model); - if (ret != SUCCESS) { - GELOGE(ret, "SetStream Fail."); - return ret; - } - - if (davinci_model_->IsKnownNode()) { - args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); - GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); - } - - ret = SetAddrs(op_desc_, kernel_hccl_infos_); + ret = SetAddrs(op_desc, kernel_hccl_infos_); if (ret != SUCCESS) { GELOGE(ret, "Setaddrs Fail."); return ret; } // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace - ret = SetWorkspace(op_desc_, kernel_hccl_infos_); + ret = SetWorkspace(op_desc, kernel_hccl_infos_); if (ret != SUCCESS) { GELOGE(ret, "SetWorkspace Fail."); return ret; } + // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl + ret = SetFollowStream(op_desc, davinci_model); + if (ret != SUCCESS) { + GELOGE(ret, "SetStream Fail."); + return ret; + } 
GELOGI("HcclTaskInfo Init Success"); return SUCCESS; @@ -217,83 +209,40 @@ Status HcclTaskInfo::Distribute() { GELOGI("HcclTaskInfo Distribute Success."); return SUCCESS; } - -Status HcclTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GE_CHECK_NOTNULL(davinci_model); - auto hccl_def = task_def.kernel_hccl(); - uint32_t op_index = hccl_def.op_index(); - GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); - // Get HCCL op - auto op_desc = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc); - GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - // Only need the number of addr to allocate args memory - auto input_size = op_desc->GetInputsSize(); - auto output_size = op_desc->GetOutputsSize(); - auto workspace_size = op_desc->GetWorkspaceBytes().size(); - uint32_t args_size = sizeof(void *) * (input_size + output_size + workspace_size); - args_offset_ = davinci_model->GetTotalArgsSize(); - davinci_model->SetTotalArgsSize(args_size); - GELOGI("Calculate hccl task args , args_size %u, args_offset %u", args_size, args_offset_); - return SUCCESS; -} - -Status HcclTaskInfo::UpdateArgs() { - GELOGI("HcclTaskInfo::UpdateArgs in."); - const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); - input_data_addrs_ = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); - output_data_addrs_ = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); - workspace_data_addrs_ = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); - - vector io_addrs; - io_addrs.insert(io_addrs.end(), input_data_addrs_.begin(), input_data_addrs_.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs_.begin(), output_data_addrs_.end()); - io_addrs.insert(io_addrs.end(), workspace_data_addrs_.begin(), workspace_data_addrs_.end()); - - GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), - "update known node %s zero copy addr failed.", 
op_desc_->GetName().c_str()); - - GELOGI("HcclTaskInfo::UpdateArgs success."); - return SUCCESS; -} - Status HcclTaskInfo::SetAddrs(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); - GE_CHK_STATUS_RET(HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos), - "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); + if (HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { + GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); + return PARAM_INVALID; + } GELOGI("Set hccl task input output address, node[%s}, type[%s] kernel_hccl_infos.size[%zu].", op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); if (op_desc->GetType() == HVDWAIT) { return SUCCESS; } - + domi::Status dmrt; hcclRedOp_t op_type = HCCL_REP_OP_SUM; GE_CHECK_NOTNULL(davinci_model_); GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - if (!davinci_model_->IsKnownNode()) { - input_data_addrs_ = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - output_data_addrs_ = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - } - void *input_data_addr = nullptr; - void *output_data_addr = nullptr; + auto input_data_addr_list = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + + auto output_data_addr_list = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); // initialize every kernel_hccl_info inputDataAddr for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { std::string hccl_type = kernel_hccl_infos[i].hccl_type; - if (davinci_model_->IsKnownNode()) { - input_data_addr = reinterpret_cast(reinterpret_cast(args_) + i); - output_data_addr = reinterpret_cast(reinterpret_cast(args_) + op_desc->GetInputsSize() + i); - GELOGI("Hccl task info known input addr %p, output addr %p.", input_data_addr, 
output_data_addr); - } else { - input_data_addr = input_data_addrs_.empty() ? nullptr : input_data_addrs_[i]; - output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; - } + void *input_data_addr = input_data_addr_list.empty() ? nullptr : input_data_addr_list[i]; kernel_hccl_infos[i].inputDataAddr = input_data_addr; + + void *output_data_addr = output_data_addr_list.empty() ? nullptr : output_data_addr_list[i]; if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { kernel_hccl_infos[i].outputDataAddr = output_data_addr; } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { - GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), - "davinci_model: GetHcomOperationType fail!"); + dmrt = HcomOmeUtil::GetHcclOperationType(op_desc, op_type); + if (dmrt != SUCCESS) { + GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); + return FAILED; + } kernel_hccl_infos[i].outputDataAddr = output_data_addr; kernel_hccl_infos[i].opType = op_type; } @@ -361,7 +310,6 @@ void HcclTaskInfo::CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc) { Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(davinci_model_); GELOGI("SetWorkspace Node[%s] opType[%s] set workspace.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); uint64_t workspace_mem_size = 0; void *workspace_addr = nullptr; @@ -371,12 +319,11 @@ Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); if (workspace_mem_size_tmp != 0) { workspace_mem_size = workspace_mem_size_tmp; - if (davinci_model_->IsKnownNode()) { - workspace_addr = reinterpret_cast(reinterpret_cast(args_) + op_desc->GetInputsSize() + - op_desc->GetOutputsSize()); - } else { - workspace_data_addrs_ = 
ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - workspace_addr = workspace_data_addrs_.empty() ? nullptr : workspace_data_addrs_[0]; + vector workspace_data_addrs = + ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + if (!workspace_data_addrs.empty()) { + GELOGI("Get workSpaceAddr"); + workspace_addr = workspace_data_addrs[0]; } } } diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h index cc3109f4..bb0a88de 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h @@ -34,10 +34,7 @@ class HcclTaskInfo : public TaskInfo { hccl_stream_list_(), ops_kernel_store_(nullptr), private_def_(nullptr), - private_def_len_(0), - op_desc_(nullptr), - args_(nullptr), - args_offset_(0) {} + private_def_len_(0) {} ~HcclTaskInfo() override; @@ -47,10 +44,6 @@ class HcclTaskInfo : public TaskInfo { uint32_t GetTaskID() override { return id_; } - Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; - - Status UpdateArgs() override; - private: ge::Status SetAddrs(const std::string &hccl_type, const std::shared_ptr &op); @@ -79,12 +72,6 @@ class HcclTaskInfo : public TaskInfo { static std::mutex hccl_follow_stream_mutex_; static uint32_t max_node_of_hccl_stream_; vector kernel_hccl_infos_; - vector input_data_addrs_; - vector output_data_addrs_; - vector workspace_data_addrs_; - OpDescPtr op_desc_; - void *args_; - uint32_t args_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc index a241e129..79971529 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc +++ 
b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc @@ -79,9 +79,6 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin return FAILED;) } - GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, ext_info_addr_=%p", op_desc_->GetName().c_str(), - op_desc_->GetType().c_str(), ext_info.size(), ext_info_addr_); - // 2.1 get loop cond variable for tensor array write uint64_t step_id_addr = 0; OpDescPtr step_id_node = davinci_model_->GetVariableOp(NODE_NAME_GLOBAL_STEP); @@ -100,11 +97,6 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, GELOGE(FAILED, "CreateAicpuKernel error."); return FAILED;) - // 2.3 Create session - GE_CHECK_NOTNULL(ModelManager::GetInstance()); - GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, - GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); - return FAILED;) kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); if (davinci_model_->IsKnownNode()) { @@ -161,8 +153,8 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); return FAILED;) - if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = input_output_addr_; } @@ -175,7 +167,12 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info.size(); fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_); - // 4. Return result + // 4. 
Create session + GE_CHECK_NOTNULL(ModelManager::GetInstance()); + GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, + GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); + return FAILED;) + // 5. Return result rtError_t rt_ret = rtMalloc(&kernel_buf_, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) @@ -183,7 +180,12 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) - davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); + vector virtual_io_addrs; // use virtual address for zero copy key. + const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); GELOGI("KernelExTaskInfo Init Success. 
session id: %lu", session_id); return SUCCESS; @@ -205,55 +207,19 @@ Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciMod uint32_t mem_size = sizeof(uint64_t) * mem_length; davinci_model->SetTotalArgsSize(mem_size); GELOGI("kernel task name %s, args_size %u, args_offset %u", op_desc->GetName().c_str(), mem_size, args_offset_); - - // alloc fixed addr - string peer_input_name; - if (AttrUtils::GetStr(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { - uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); - if (output_index > outputs_size) { - GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", outputs_size, output_index); - return FAILED; - } - fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); - auto tensor_desc = op_desc->GetOutputDesc(output_index); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); - davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); - GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr offset %ld", tensor_size, - fixed_addr_offset_); - } return SUCCESS; } Status KernelExTaskInfo::UpdateArgs() { GELOGI("KernelExTaskInfo::UpdateArgs in."); const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + vector io_addrs; vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); - vector io_addrs; - if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - } else { - string peer_input_name; - if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { - uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); - 
if (output_index > output_data_addrs.size()) { - GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", - output_data_addrs.size(), output_index); - return FAILED; - } - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - for (size_t i = 0; i < output_data_addrs.size(); ++i) { - if (i == output_index) { - void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); - io_addrs.emplace_back(fixed_addr); - continue; - } - io_addrs.emplace_back(output_data_addrs[i]); - } - } - } + + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); @@ -265,7 +231,7 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const const OpDescPtr &op_desc) { // Userspace copy need virtual address. 
const vector workspace_data_sizes = ModelUtils::GetWorkspaceSize(op_desc); - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc, false); if (workspace_data_addrs.empty() || workspace_data_sizes.empty()) { GELOGE(FAILED, "Node:%s invalid workspace, addrs is %zu, size is %zu.", op_desc->GetName().c_str(), workspace_data_addrs.size(), workspace_data_sizes.size()); diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h index b26a95ac..ff8f3119 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h @@ -54,7 +54,6 @@ class KernelExTaskInfo : public TaskInfo { auto ret = reinterpret_cast(dump_args_); return ret; } - bool CallSaveDumpInfo() override { return true; }; private: Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); @@ -70,7 +69,6 @@ class KernelExTaskInfo : public TaskInfo { void *dump_args_; OpDescPtr op_desc_ = nullptr; uint32_t args_offset_ = 0; - int64_t fixed_addr_offset_ = 0; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index 12fe0206..7ef65555 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -47,16 +47,16 @@ const uint32_t kAddrLen = sizeof(void *); namespace ge { KernelTaskInfo::SuperKernelTaskInfo KernelTaskInfo::skt_info_ = { - 0, 0, 0, 0, nullptr, nullptr, {}, {}, {}, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; + 0, 0, 0, 0, nullptr, nullptr, {}, {}, 
RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci model is null!"); + GELOGE(PARAM_INVALID, "davinci_model is null!"); return PARAM_INVALID; } davinci_model_ = davinci_model; is_l1_fusion_enable_ = davinci_model_->GetL1FusionEnableOption(); - GELOGD("KernelTaskInfo init start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); + GELOGD("KernelTaskInfo Init Start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); Status ret = SetStream(task_def.stream_id(), davinci_model_->GetStreamList()); if (ret != SUCCESS) { @@ -73,7 +73,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci // get opdesc op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); if (op_desc_ == nullptr) { - GELOGE(INTERNAL_ERROR, "Get op desc failed, index is out of range!"); + GELOGE(INTERNAL_ERROR, "Get op_desc failed, index is out of range!"); return INTERNAL_ERROR; } (void)AttrUtils::GetBool(*op_desc_, ATTR_N_BATCH_SPILT, is_n_batch_spilt_); @@ -138,21 +138,14 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci ret = InitCceTask(kernel_def); } - GELOGD("KernelTaskInfo init finish, result=%u.", ret); + GELOGD("KernelTaskInfo Init finish, result=%u.", ret); return ret; } Status KernelTaskInfo::SaveSKTDumpInfo() { GE_CHECK_NOTNULL(davinci_model_); - if (skt_dump_flag_ == RT_KERNEL_DEFAULT) { - GELOGD("no need save skt dump info"); - return SUCCESS; - } - // all op in super kernel share one taskid and streamid - for (size_t i = 0; i < skt_info_.op_desc_list.size(); i++) { - davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_stream_id, skt_info_.op_desc_list[i], - skt_info_.dump_args_list[i]); - } + davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_stream_id, skt_info_.last_op, + skt_info_.last_dump_args); return 
SUCCESS; } @@ -194,9 +187,6 @@ Status KernelTaskInfo::SKTFinalize() { GELOGI("SuperKernel Distribute [skt_id:%u]", skt_id_); skt_info_.kernel_list.clear(); skt_info_.arg_list.clear(); - skt_info_.dump_flag_list.clear(); - skt_info_.op_desc_list.clear(); - skt_info_.dump_args_list.clear(); skt_info_.last_stream = nullptr; skt_info_.last_block_dim = 0; skt_info_.last_sm_desc = sm_desc_; @@ -207,15 +197,6 @@ Status KernelTaskInfo::SKTFinalize() { return SUCCESS; } -uint32_t KernelTaskInfo::GetDumpFlag() { - for (auto flag : skt_info_.dump_flag_list) { - if (flag == RT_KERNEL_DUMPFLAG) { - return RT_KERNEL_DUMPFLAG; - } - } - return RT_KERNEL_DEFAULT; -} - Status KernelTaskInfo::SuperKernelLaunch() { if (skt_info_.kernel_list.empty()) { GELOGI("SuperKernelLaunch: Skt_kernel_list has no task, just return"); @@ -225,7 +206,7 @@ Status KernelTaskInfo::SuperKernelLaunch() { auto &skt_kernel_list = skt_info_.kernel_list; auto &skt_arg_list = skt_info_.arg_list; GELOGI("SuperKernelLaunch: Skt_kernel_list size[%d] skt_arg_list[%d]", skt_kernel_list.size(), skt_arg_list.size()); - if (skt_kernel_list.size() == kSKTSingleSize && skt_arg_list.size() == kSKTSingleSize) { + if (skt_kernel_list.size() == kSKTSingleSize) { rt_ret = rtKernelLaunchWithFlag(skt_info_.kernel_list[0], static_cast(skt_info_.last_block_dim), skt_info_.arg_list[0], skt_info_.last_args_size, static_cast(skt_info_.last_sm_desc), skt_info_.last_stream, @@ -234,7 +215,6 @@ Status KernelTaskInfo::SuperKernelLaunch() { GELOGE(RT_FAILED, "SuperKernelLaunch: Call rt api failed, ret: 0x%X", rt_ret); return RT_FAILED; } - call_save_dump_ = true; GE_CHK_STATUS_RET(SKTFinalize(), "Skt finalize failed"); return SUCCESS; } @@ -246,22 +226,18 @@ Status KernelTaskInfo::SuperKernelLaunch() { return RT_FAILED; } // Call the fuse API - std::unique_ptr superKernel = nullptr; + skt::SuperKernel *superKernel = nullptr; if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { 
GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); return RT_FAILED; } // Launch a super kernel - skt_dump_flag_ = GetDumpFlag(); - if (superKernel->Launch(skt_info_.last_stream, skt_dump_flag_) != SUCCESS) { + if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); return RT_FAILED; } GELOGI("SuperKernelLaunch: success[skt_kernel_list size[%zu] skt_arg_list[%zu]]", skt_kernel_list.size(), skt_arg_list.size()); - // record skt addr for release - superkernel_dev_nav_table_ = superKernel->GetNavTablePtr(); - superkernel_device_args_addr_ = superKernel->GetDeviceArgsPtr(); GE_CHK_STATUS_RET(SKTFinalize(), "Skt finalize failed"); return SUCCESS; } @@ -274,9 +250,6 @@ Status KernelTaskInfo::SaveSuperKernelInfo() { skt_info_.last_args_size = args_size_; skt_info_.last_sm_desc = sm_desc_; skt_info_.last_dump_flag = dump_flag_; - skt_info_.dump_flag_list.push_back(dump_flag_); - skt_info_.op_desc_list.push_back(op_desc_); - skt_info_.dump_args_list.push_back(reinterpret_cast(dump_args_)); skt_info_.last_group_key = group_key_; skt_info_.last_dump_args = reinterpret_cast(dump_args_); skt_info_.last_op = op_desc_; @@ -355,7 +328,6 @@ Status KernelTaskInfo::SuperKernelDistribute() { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return FAILED; } - call_save_dump_ = true; UpdateTaskId(); GELOGI("Current Common Task Distribute [taskid:%u]", task_id_); } else { @@ -384,7 +356,6 @@ Status KernelTaskInfo::Distribute() { rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), reinterpret_cast(kernel_name_.c_str()), 1, args_, args_size_, nullptr, stream_, dump_flag_); - call_save_dump_ = true; } else { /* default: not skt launch */ GELOGI( @@ -398,7 +369,6 @@ Status KernelTaskInfo::Distribute() { // call rtKernelLaunch for current task rt_ret = rtKernelLaunchWithFlag(stub_func_, block_dim_, args_, args_size_, static_cast(sm_desc_), stream_, dump_flag_); - 
call_save_dump_ = true; } } if (rt_ret != RT_ERROR_NONE) { @@ -422,31 +392,9 @@ Status KernelTaskInfo::UpdateArgs() { vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); vector io_addrs; - if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); - } else { - string peer_input_name; - if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { - uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); - if (output_index > output_data_addrs.size()) { - GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", - output_data_addrs.size(), output_index); - return FAILED; - } - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - for (size_t i = 0; i < output_data_addrs.size(); ++i) { - if (i == output_index) { - void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); - io_addrs.emplace_back(fixed_addr); - continue; - } - io_addrs.emplace_back(output_data_addrs[i]); - } - io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); - } - } + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); @@ -460,8 +408,6 @@ Status KernelTaskInfo::Release() { return SUCCESS; } FreeRtMem(&args_); - FreeRtMem(&superkernel_device_args_addr_); - 
FreeRtMem(&superkernel_dev_nav_table_); FreeRtMem(&flowtable_); FreeRtMem(&custom_info_.input_descs); FreeRtMem(&custom_info_.input_addrs); @@ -526,29 +472,6 @@ Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel args_offset_ = davinci_model->GetTotalArgsSize(); davinci_model->SetTotalArgsSize(args_size); GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_); - - // get opcontext stored in model - const domi::KernelContext &context = kernel_def.context(); - // get opdesc - op_desc_ = davinci_model->GetOpByIndex(context.op_index()); - GE_CHECK_NOTNULL(op_desc_); - // alloc fixed addr - string peer_input_name; - if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { - uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); - if (output_index > op_desc_->GetOutputsSize()) { - GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", op_desc_->GetOutputsSize(), - output_index); - return FAILED; - } - fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); - auto tensor_desc = op_desc_->GetOutputDesc(output_index); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); - davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); - GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr offset %ld", tensor_size, - fixed_addr_offset_); - } return SUCCESS; } @@ -626,8 +549,8 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne return FAILED; } - if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = static_cast(args_) + offset; } @@ -638,8 +561,10 @@ Status 
KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne } vector virtual_io_addrs; // use virtual address for zero copy key. - virtual_io_addrs.insert(virtual_io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_info.data(), args_, args_size_, offset); GELOGD("Do InitTVMTask end"); @@ -677,6 +602,7 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel const std::vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), ModelUtils::GetOutputDescs(op_desc)); + if (ret != SUCCESS) { GELOGE(ret, "StoreInputOutputTensor Failed"); return ret; @@ -741,9 +667,11 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel return RT_FAILED; } - davinci_model_->SetZeroCopyAddr(op_desc, input_data_addrs, input_data_addrs.data(), custom_info_.input_addrs, - input_data_addrs.size() * kAddrLen, 0); - davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, output_data_addrs.data(), custom_info_.output_addrs, + const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, input_data_addrs.data(), custom_info_.input_addrs, + 
virtual_in_addrs.size() * kAddrLen, 0); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, output_data_addrs.data(), custom_info_.output_addrs, output_data_addrs.size() * kAddrLen, 0); return SUCCESS; } @@ -873,9 +801,6 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size()); return init_ret; } - GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(), - op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_); - aicpu_param_head->extInfoAddr = reinterpret_cast(aicpu_ext_info_addr_); aicpu_param_head->extInfoLength = reinterpret_cast(ext_info.size()); @@ -894,13 +819,19 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k return RT_FAILED; } - if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead); } - davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, args_addr.get(), args_, args_size_, sizeof(aicpu::AicpuParamHead)); + vector virtual_io_addrs; // use virtual address for zero copy key. 
+ const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); + const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_addr.get(), args_, args_size_, + sizeof(aicpu::AicpuParamHead)); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index 04cd6312..41ed5728 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -61,8 +61,6 @@ class KernelTaskInfo : public TaskInfo { sm_desc_ = nullptr; flowtable_ = nullptr; args_ = nullptr; - superkernel_device_args_addr_ = nullptr; - superkernel_dev_nav_table_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -90,8 +88,6 @@ class KernelTaskInfo : public TaskInfo { uint32_t GetSktTaskID() override { return skt_id_; } - bool CallSaveDumpInfo() override { return call_save_dump_; }; - cce::ccOpContext ctx_; FusionOpInfo fusion_op_info_; @@ -134,7 +130,6 @@ class KernelTaskInfo : public TaskInfo { void UpdateSKTTaskId(); Status SKTFinalize(); Status SuperKernelLaunch(); - uint32_t GetDumpFlag(); Status SaveSuperKernelInfo(); bool IsMarkedLastNode(); bool IsMarkedFirstNode(); @@ -158,8 +153,6 @@ class KernelTaskInfo : public TaskInfo { OpDescPtr op_desc_; DavinciModel *davinci_model_; uint32_t args_offset_ = 0; - int64_t fixed_addr_offset_ = 0; - bool call_save_dump_ = false; // aicpu ext_info device mem void *aicpu_ext_info_addr_ = nullptr; @@ -171,9 +164,6 @@ class KernelTaskInfo : public TaskInfo { bool is_n_batch_spilt_; int64_t group_key_; bool 
has_group_key_; - uint32_t skt_dump_flag_ = RT_KERNEL_DEFAULT; - void *superkernel_device_args_addr_ = nullptr; - void *superkernel_dev_nav_table_ = nullptr; struct AICPUCustomInfo { void *input_descs = nullptr; @@ -193,9 +183,6 @@ class KernelTaskInfo : public TaskInfo { void *last_sm_desc; std::vector kernel_list; std::vector arg_list; - std::vector dump_flag_list; - std::vector op_desc_list; - std::vector dump_args_list; uint32_t last_dump_flag; int64_t last_group_key; uintptr_t last_dump_args; diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc index 162cf00d..818307eb 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc @@ -16,8 +16,8 @@ #include "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h" -#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/davinci_model.h" +#include "graph/debug/ge_attr_define.h" namespace ge { constexpr uint8_t kLabelSwitchIndexNum = 1; @@ -59,13 +59,7 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo op_desc->GetName().c_str(), input_data_addr.size(), kLabelSwitchIndexNum); return INTERNAL_ERROR; } - - if (davinci_model->IsKnownNode()) { - index_value_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_); - } else { - index_value_ = input_data_addr[0]; - } - + index_value_ = input_data_addr[0]; davinci_model->DisableZeroCopy(index_value_); std::vector label_idx_list; @@ -130,28 +124,5 @@ Status LabelSwitchByIndexTaskInfo::Distribute() { return SUCCESS; } -Status LabelSwitchByIndexTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GE_CHECK_NOTNULL(davinci_model); - auto label_switch = task_def.label_switch_by_index(); - uint32_t op_index = 
label_switch.op_index(); - GELOGI("Begin to calculate args, op_index is: %u", op_index); - auto op_desc = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc); - GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - if (op_desc->GetInputsSize() != kLabelSwitchIndexNum) { - GELOGE(FAILED, "Label switch op only have one data input. Now input size is %zu", op_desc->GetInputsSize()); - return FAILED; - } - string input_tensor_name = op_desc->GetInputNameByIndex(0); - fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(input_tensor_name); - auto tensor_desc = op_desc->GetInputDesc(0); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); - davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); - GELOGI("Calculate stream switchn task args , tensor_size %ld, fixed_addr_offset %ld", tensor_size, - fixed_addr_offset_); - return SUCCESS; -} - REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, LabelSwitchByIndexTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h index 4cb39c95..1a644736 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h @@ -22,8 +22,7 @@ namespace ge { class LabelSwitchByIndexTaskInfo : public TaskInfo { public: - LabelSwitchByIndexTaskInfo() - : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {} + LabelSwitchByIndexTaskInfo() : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0) {} ~LabelSwitchByIndexTaskInfo() override; @@ -31,15 +30,13 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo { Status Distribute() override; - Status CalculateArgs(const domi::TaskDef &task_def, 
DavinciModel *davinci_model) override; - private: void *index_value_; // switch index input. uint32_t branch_max_; // max branch count. void *args_; // label info memory. uint32_t args_size_; // label info length. + std::vector label_list_; - int64_t fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc index af32b44f..e9d99189 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc @@ -21,9 +21,9 @@ namespace ge { Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAddrAsyncTaskInfo Init Start"); + GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null"); + GELOGE(PARAM_INVALID, "davinci_model is null!"); return PARAM_INVALID; } @@ -32,27 +32,45 @@ Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel return ret; } - const auto &memcpy_async = task_def.memcpy_async(); - OpDescPtr op_desc = davinci_model->GetOpByIndex(memcpy_async.op_index()); + auto memcpy_async_def = task_def.memcpy_async(); + uint32_t op_index = memcpy_async_def.op_index(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); if (op_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "Task op index:%u out of range", memcpy_async.op_index()); + GELOGE(INTERNAL_ERROR, "Init MemcpyAddrAsyncTaskInfo error, index is out of range!"); return INTERNAL_ERROR; } - ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.src(), src_); + uint64_t logic_dst = memcpy_async_def.dst(); + uint64_t logic_src = memcpy_async_def.src(); + + dst_max_ = 
memcpy_async_def.dst_max(); + + uint64_t update_base_addr = 0; + ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); if (ret != SUCCESS) { return ret; } + src_ = reinterpret_cast(update_base_addr + logic_src); + if (src_ == nullptr) { + GELOGE(PARAM_INVALID, "src_ is null!"); + return PARAM_INVALID; + } - ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); - if (ret != SUCCESS) { - return ret; + uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); + uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); + dst_ = reinterpret_cast(reinterpret_cast(mem_base + (logic_dst - logic_mem_base))); + if (dst_ == nullptr) { + GELOGE(PARAM_INVALID, "dst_ is null!"); + return PARAM_INVALID; } vector io_addrs; io_addrs.emplace_back(src_); io_addrs.emplace_back(dst_); + count_ = memcpy_async_def.count(); + kind_ = memcpy_async_def.kind(); + // malloc args memory size_t args_size = sizeof(void *) * io_addrs.size(); rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); @@ -70,18 +88,20 @@ Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel return RT_FAILED; } - count_ = memcpy_async.count(); - kind_ = memcpy_async.kind(); - dst_max_ = memcpy_async.dst_max(); - GELOGI("InitMemcpyAddrAsyncTaskInfo, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu, args:%p, size:%zu", - memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_, args_, args_size); + // Just dest addr need zero copy. 
+ davinci_model->SetZeroCopyAddr(op_desc, {dst_}, io_addrs.data(), args_, args_size, sizeof(void *)); + + GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", + reinterpret_cast(reinterpret_cast(logic_src)), + reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_, args_, + reinterpret_cast(reinterpret_cast(args_) + args_size)); - davinci_model->SetZeroCopyAddr(op_desc, io_addrs, io_addrs.data(), args_, args_size, 0); return SUCCESS; } Status MemcpyAddrAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start, dst_max:%lu, count:%lu, kind:%u", dst_max_, count_, kind_); + GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); + GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast(reinterpret_cast(args_) + sizeof(void *)), dst_max_, args_, count_, static_cast(kind_), stream_); @@ -93,5 +113,39 @@ Status MemcpyAddrAsyncTaskInfo::Distribute() { return SUCCESS; } +Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, + uint64_t &base_addr) { + GE_CHECK_NOTNULL(davinci_model); + uint64_t data_base_addr = + static_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); + uint64_t weight_base_addr = static_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - + davinci_model->GetRtWeightAddr(); + uint64_t var_base_addr = + static_cast(reinterpret_cast(davinci_model->VarMemBase())) - davinci_model->GetRtVarAddr(); + + uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); + uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); + uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); + uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); + uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); + uint64_t 
varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); + + if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { + base_addr = data_base_addr; + GELOGI("The update_addr is data address."); + } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { + base_addr = weight_base_addr; + GELOGI("The update_addr is weight address."); + } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { + base_addr = var_base_addr; + GELOGI("The update_addr is variable address."); + } else if (update_addr != 0) { + base_addr = 0; + GELOGE(PARAM_INVALID, "The update_addr is abnormal."); + return PARAM_INVALID; + } + return SUCCESS; +} + REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h index f8bf8a90..9252e43a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h @@ -16,7 +16,6 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ - #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { @@ -33,8 +32,9 @@ class MemcpyAddrAsyncTaskInfo : public TaskInfo { if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); } - args_ = nullptr; } + + args_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -42,9 +42,11 @@ class MemcpyAddrAsyncTaskInfo : public TaskInfo { Status Distribute() override; private: - uint8_t *dst_; + Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); + + void 
*dst_; uint64_t dst_max_; - uint8_t *src_; + void *src_; void *args_; uint64_t count_; uint32_t kind_; diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc index c2b56436..82eabe69 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc @@ -21,9 +21,9 @@ namespace ge { Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAsyncTaskInfo Init Start"); + GELOGI("MemcpyAsyncTaskInfo Init Start."); if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null"); + GELOGE(PARAM_INVALID, "davinci_model is null!"); return PARAM_INVALID; } @@ -32,38 +32,35 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da return ret; } - memcpy_async = task_def.memcpy_async(); - count_ = memcpy_async.count(); - kind_ = memcpy_async.kind(); - dst_max_ = memcpy_async.dst_max(); - if (davinci_model->IsKnownNode()) { - src_ = reinterpret_cast(davinci_model_->GetCurrentArgsAddr(args_offset_)); - dst_ = reinterpret_cast(reinterpret_cast(src_) + sizeof(void *)); - // for zero copy - kind_ = RT_MEMCPY_ADDR_DEVICE_TO_DEVICE; - GELOGI("MemcpyAsyncTaskInfo src_ %p, dst_ %p, args_offset %u.", src_, dst_, args_offset_); - return SUCCESS; - } - ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.src(), src_); - if (ret != SUCCESS) { - return ret; - } + auto memcpy_async_def = task_def.memcpy_async(); + uint64_t logic_dst = memcpy_async_def.dst(); + uint64_t logic_src = memcpy_async_def.src(); + + dst_max_ = memcpy_async_def.dst_max(); - ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); + uint64_t update_base_addr = 0; + ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); if (ret != SUCCESS) { return ret; } 
+ src_ = reinterpret_cast(update_base_addr + logic_src); + davinci_model->DisableZeroCopy(src_); - GELOGI("MemcpyAsyncTaskInfo Init Success, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu", - memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_); + uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); + uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); + dst_ = reinterpret_cast(mem_base + (logic_dst - logic_mem_base)); + + count_ = memcpy_async_def.count(); + kind_ = memcpy_async_def.kind(); + GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", + reinterpret_cast(reinterpret_cast(logic_src)), + reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_); - davinci_model->DisableZeroCopy(src_); - davinci_model->DisableZeroCopy(dst_); return SUCCESS; } Status MemcpyAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAsyncTaskInfo Distribute Start. dst_max:%lu, count:%lu, kind:%u", dst_max_, count_, kind_); + GELOGI("MemcpyAsyncTaskInfo Distribute Start. 
dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); rtError_t rt_ret = rtMemcpyAsync(dst_, dst_max_, src_, count_, static_cast(kind_), stream_); if (rt_ret != RT_ERROR_NONE) { @@ -71,41 +68,40 @@ Status MemcpyAsyncTaskInfo::Distribute() { return RT_FAILED; } - GELOGI("MemcpyAsyncTaskInfo Distribute Success"); - return SUCCESS; -} - -Status MemcpyAsyncTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - // the num of src and dst size is 2 - uint32_t args_size = sizeof(void *) * 2; - args_offset_ = davinci_model->GetTotalArgsSize(); - davinci_model->SetTotalArgsSize(args_size); - davinci_model_ = davinci_model; - GELOGI("MemcpyAsyncTaskInfo kernel args_size %u, args_offset %u", args_size, args_offset_); + GELOGI("MemcpyAsyncTaskInfo Distribute Success."); return SUCCESS; } -Status MemcpyAsyncTaskInfo::UpdateArgs() { - GELOGI("MemcpyAsyncTaskInfo::UpdateArgs in."); - GE_CHECK_NOTNULL(davinci_model_); - Status ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.src(), src_); - if (ret != SUCCESS) { - return ret; - } - - ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.dst(), dst_); - if (ret != SUCCESS) { - return ret; +Status MemcpyAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr) { + GE_CHECK_NOTNULL(davinci_model); + uint64_t data_base_addr = + reinterpret_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); + uint64_t weight_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - + davinci_model->GetRtWeightAddr(); + uint64_t var_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->VarMemBase())) - + davinci_model->GetRtVarAddr(); + + uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); + uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); + uint64_t wight_base_addr_start = 
davinci_model->GetRtWeightAddr(); + uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); + uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); + uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); + + if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { + base_addr = data_base_addr; + GELOGI("The update_addr is data address."); + } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { + base_addr = weight_base_addr; + GELOGI("The update_addr is weight address."); + } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { + base_addr = var_base_addr; + GELOGI("The update_addr is variable address."); + } else if (update_addr != 0) { + base_addr = 0; + GELOGE(PARAM_INVALID, "The update_addr is abnormal."); + return PARAM_INVALID; } - - vector io_addrs; - io_addrs.emplace_back(reinterpret_cast(src_)); - io_addrs.emplace_back(reinterpret_cast(dst_)); - - GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), - "update memcpyasync in known node zero copy addr failed."); - - GELOGI("MemcpyAsyncTaskInfo::UpdateArgs success."); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h index c3daa862..02872f34 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h @@ -16,7 +16,6 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ - #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { @@ -33,19 +32,14 @@ class MemcpyAsyncTaskInfo : public TaskInfo { Status Distribute() override; - Status 
UpdateArgs() override; - - Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; - private: - uint8_t *dst_; + Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); + + void *dst_; uint64_t dst_max_; - uint8_t *src_; + void *src_; uint64_t count_; uint32_t kind_; - DavinciModel *davinci_model_ = nullptr; - uint32_t args_offset_ = 0; - domi::MemcpyAsyncDef memcpy_async; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc index 0ebaf573..a1d2f143 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc @@ -42,11 +42,16 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d auto stream_switch_def = task_def.stream_switch(); uint32_t op_index = stream_switch_def.op_index(); + // get StreamSwitch op OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - SetInputAndValuePtr(davinci_model, input_data_addr); + if (!input_data_addr.empty() && input_data_addr.size() >= STREAM_SWITCH_INPUT_NUM) { + input_ptr_ = input_data_addr[0]; + value_ptr_ = input_data_addr[1]; + } + uint32_t cond = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, cond)) { GELOGE(INTERNAL_ERROR, "StreamSwitchOp get attr STREAM_SWITCH_COND fail."); @@ -110,42 +115,6 @@ Status StreamSwitchTaskInfo::Distribute() { GELOGI("StreamSwitchTaskInfo Distribute Success. 
cond:%d, stream:%p, datatype:%d.", cond_, true_stream_, data_type_); return SUCCESS; } -Status StreamSwitchTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GE_CHECK_NOTNULL(davinci_model); - auto stream_switch_def = task_def.stream_switch(); - uint32_t op_index = stream_switch_def.op_index(); - GELOGI("Begin to calculate args, op_index is: %u", op_index); - auto op_desc = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc); - GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - if (op_desc->GetInputsSize() != STREAM_SWITCH_INPUT_NUM) { - GELOGE(FAILED, "Stream switch op only have one data input. Now input size is %zu", op_desc->GetInputsSize()); - return FAILED; - } - for (uint32_t i = 0; i < STREAM_SWITCH_INPUT_NUM; ++i) { - string input_tensor_name = op_desc->GetInputNameByIndex(i); - int64_t fixed_addr_offset = davinci_model->GetFixedAddrsSize(input_tensor_name); - fixed_addr_offset_.emplace_back(fixed_addr_offset); - auto tensor_desc = op_desc->GetInputDesc(i); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); - davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); - GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr[%u] offset %ld", tensor_size, i, - fixed_addr_offset); - } - return SUCCESS; -} -void StreamSwitchTaskInfo::SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs) { - if (davinci_model->IsKnownNode() && fixed_addr_offset_.size() == STREAM_SWITCH_INPUT_NUM) { - input_ptr_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_[0]); - value_ptr_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_[1]); - } else { - if (!input_data_addrs.empty() && input_data_addrs.size() >= STREAM_SWITCH_INPUT_NUM) { - input_ptr_ = input_data_addrs[0]; - value_ptr_ = input_data_addrs[1]; - } - } -} 
REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH, StreamSwitchTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h index e6e8339a..07509c7c 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h @@ -39,18 +39,13 @@ class StreamSwitchTaskInfo : public TaskInfo { Status Distribute() override; - Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; - private: - void SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs); void *input_ptr_; rtCondition_t cond_; void *value_ptr_; rtStream_t true_stream_; uint32_t true_stream_id_; rtSwitchDataType_t data_type_; - static const uint32_t kInputNum = 2; - vector fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc index 01371af7..29b107bd 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc @@ -24,15 +24,18 @@ namespace { const uint32_t kDynamicBtachParamNum = 1; const uint32_t kDynamicResolutionParamNum = 2; -const uint8_t kStreamSwitchnInputNum = 1; } // namespace namespace ge { Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("StreamSwitchNTaskInfo Init Start."); - GE_CHECK_NOTNULL(davinci_model); + if (davinci_model == nullptr) { + GELOGE(PARAM_INVALID, "davinci_model is null!"); + return PARAM_INVALID; + } - if (SetStream(task_def.stream_id(), davinci_model->GetStreamList()) != SUCCESS) { + Status ret = 
SetStream(task_def.stream_id(), davinci_model->GetStreamList()); + if (ret != SUCCESS) { return FAILED; } @@ -72,16 +75,14 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * GELOGE(FAILED, "Get true stream ptr of switchN op failed."); return FAILED; } - if (davinci_model->IsKnownNode()) { - input_ptr_ = davinci_model->GetCurrentFixedAddr(args_offset_); - } else { - auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (input_data_addr.empty()) { - GELOGE(FAILED, "Input data addr is nullptr."); - return FAILED; - } - input_ptr_ = input_data_addr[0]; + + // set input_ptr_ + auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + if (input_data_addr.empty()) { + GELOGE(FAILED, "Input data addr is nullptr."); + return FAILED; } + input_ptr_ = input_data_addr[0]; davinci_model->DisableZeroCopy(input_ptr_); GELOGI("StreamSwitchNTaskInfo Init Success, inputSize:%u, elementSize:%d, trueStreamID:%ld.", input_size_, element_size_, op_desc->GetStreamId()); @@ -139,26 +140,5 @@ Status StreamSwitchNTaskInfo::GetTrueStreamPtr(const OpDescPtr &op_desc, Davinci return SUCCESS; } -Status StreamSwitchNTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GE_CHECK_NOTNULL(davinci_model); - auto stream_switchn_def = task_def.stream_switch_n(); - uint32_t op_index = stream_switchn_def.op_index(); - GELOGI("Begin to calculate args, op_index is: %u", op_index); - auto op_desc = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc); - GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - if (op_desc->GetInputsSize() != kStreamSwitchnInputNum) { - GELOGE(FAILED, "Stream switchn op only have one data input. 
Now input size is %zu", op_desc->GetInputsSize()); - return FAILED; - } - string input_tensor_name = op_desc->GetInputNameByIndex(0); - args_offset_ = davinci_model->GetFixedAddrsSize(input_tensor_name); - auto tensor_desc = op_desc->GetInputDesc(0); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); - davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); - GELOGI("Calculate stream switchn task args , tensor_size %ld, args_offset %ld", tensor_size, args_offset_); - return SUCCESS; -} REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH_N, StreamSwitchNTaskInfo); } // namespace ge \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h index 1a96243a..d1002da7 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h @@ -29,8 +29,7 @@ class StreamSwitchNTaskInfo : public TaskInfo { value_ptr_(nullptr), true_stream_ptr_(nullptr), element_size_(0), - data_type_(RT_SWITCH_INT64), - args_offset_(0) {} + data_type_(RT_SWITCH_INT64) {} ~StreamSwitchNTaskInfo() override {} @@ -38,8 +37,6 @@ class StreamSwitchNTaskInfo : public TaskInfo { Status Distribute() override; - Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; - private: Status GetTrueStreamPtr(const OpDescPtr &op_desc, DavinciModel *davinci_model); void *input_ptr_; @@ -50,7 +47,6 @@ class StreamSwitchNTaskInfo : public TaskInfo { rtSwitchDataType_t data_type_; vector true_stream_list_; vector value_list_; - int64_t args_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCHN_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h 
b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h index b7e76af0..1c31acd1 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h @@ -34,13 +34,22 @@ class SuperKernel { public: SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} - ~SuperKernel() = default; + ~SuperKernel() { + // free memory when all releasing + if (device_args_addr_ != nullptr) { + GE_CHK_RT(rtFree(device_args_addr_)); + GELOGI("SKT: super_kernel args addr free."); + } + if (dev_nav_table_ != nullptr) { + GE_CHK_RT(rtFree(dev_nav_table_)); + GELOGI("SKT: super_kernel nav table addr free."); + } + } Status Launch(rtStream_t stream, uint32_t dump_flag); const void *GetFuncStub() const { return func_stub_; } + const void *GetNavTablePtr() const { return dev_nav_table_; } uint64_t GetNavTableSize() const { return nav_table_size_; } uint32_t GetBlockDim() const { return block_dim_; } - void *GetNavTablePtr() const { return dev_nav_table_; } - void *GetDeviceArgsPtr() const { return device_args_addr_; } }; } // namespace skt } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc index 397c7d98..d2ad474a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -42,10 +42,21 @@ Status SuperKernelFactory::Init() { rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. 
error: 0x%X", rt_ret); return FAILED;) - GELOGD( - "SKT: fuseKernels super_kernel_template subFunc %p, device func " - "address %p", - this->func_stub_, this->func_ptr_); + if (this->use_physical_address_ != nullptr) { + void *skt_func = nullptr; + rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + return FAILED;) + GELOGD( + "SKT: fuseKernels super_kernel_template subFunc %p, device func " + "address %p, device physic PC %p", + this->func_stub_, this->func_ptr_, skt_func); + } else { + GELOGD( + "SKT: fuseKernels super_kernel_template subFunc %p, device func " + "address %p", + this->func_stub_, this->func_ptr_); + } } is_init_ = true; @@ -60,8 +71,7 @@ Status SuperKernelFactory::Uninitialize() { } Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list, - const std::vector &args_addr_list, uint32_t block_dim, - std::unique_ptr &h) { + const std::vector &args_addr_list, uint32_t block_dim, SuperKernel *&h) { // Iterate through the ops to be fused // Each subkernel to be fused contains 2 fields: fn address offset, args // address. @@ -91,28 +101,70 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list rtError_t rt_ret; void *hbm_nav_table_addr = nullptr; - for (unsigned i = 0; i < stub_func_list.size(); i++) { - void *sub_device_func = nullptr; - rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. 
error: 0x%X", rt_ret); - return FAILED;) - GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; - GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); - GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); + if (this->use_physical_address_ != nullptr) { + for (unsigned i = 0; i < stub_func_list.size(); i++) { + void *sub_device_func = nullptr; + rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + return FAILED;) + void *sub_device_func_pys = nullptr; + void *args_addr_pys = nullptr; + rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + return FAILED;) + rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. 
error: 0x%X", rt_ret); + return FAILED;) + GELOGD( + "SKT: fuseKernels subFunc %p, device func address %p, device " + "physic func address %p", + stub_func_list[i], sub_device_func, sub_device_func_pys); + // store two uint64_t address + // address divided by 4 because of 32bits encoding, call offset will *4 when calculating + nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func_pys)) / 4; + GELOGD("SKT: CALL offset %lu", nav_table[i * 2]); + nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_pys)); + + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); + } + + void *hbm_nav_table_addr_pys = nullptr; + rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) + rt_ret = + rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); + GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) + rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); + GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) + + GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); + // Create the necessary metadata for the super kernel + h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); + } else { + for (unsigned i = 0; i < stub_func_list.size(); i++) { + void *sub_device_func = nullptr; + rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. 
error: 0x%X", rt_ret); + return FAILED;) + GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); + // store two uint64_t address + // address divided by 4 because of 32bits encoding, call offset will *4 when calculating + nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; + GELOGD("SKT: CALL offset %lu", nav_table[i * 2]); + nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); + } + rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) + rt_ret = + rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); + GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) + // Create the necessary metadata for the super kernel + h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); } - rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) - rt_ret = - rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. 
error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - // Create the necessary metadata for the super kernel - h = - std::unique_ptr(new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); return SUCCESS; } } // namespace skt diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h index 7db44eec..d8b7ff26 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h @@ -29,6 +29,7 @@ class SuperKernelFactory { void *func_ptr_ = nullptr; void *handle_ = nullptr; std::string sk_stub_name_ = "_Z21super_kernel_templatePmm"; + const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); bool is_init_ = false; SuperKernelFactory(){}; ~SuperKernelFactory() { @@ -47,7 +48,7 @@ class SuperKernelFactory { Status Init(); Status Uninitialize(); Status FuseKernels(const std::vector &stub_func_list, const std::vector &args_addr_list, - uint32_t block_dim, std::unique_ptr &h); + uint32_t block_dim, SuperKernel *&h); }; } // namespace skt } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info.h b/src/ge/graph/load/new_model_manager/task_info/task_info.h index f69511e6..5d2c89eb 100644 --- a/src/ge/graph/load/new_model_manager/task_info/task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info.h @@ -72,8 +72,6 @@ class TaskInfo { virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } - virtual bool CallSaveDumpInfo() { return false; } - virtual uint32_t GetStreamId() { return 0xFFFFFFFF; } virtual uintptr_t GetDumpArgs() { return 0; } diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h index 5b220960..b6954016 100644 --- 
a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h @@ -86,5 +86,5 @@ class TaskInfoFactory { return ptr; \ } \ TaskInfoFactory::Registerar g_##type##_Task_Info_Creator(type, Creator_##type##_Task_Info); -} // namespace ge +}; // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_TASK_INFO_FACTORY_H_ diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.cc b/src/ge/graph/load/new_model_manager/zero_copy_task.cc index be75322d..42734a87 100644 --- a/src/ge/graph/load/new_model_manager/zero_copy_task.cc +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -129,6 +129,12 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, const DataBuffer &data, } auto dst_addr = static_cast(data.data); + auto dst_size = static_cast(data.length); + if (ModelUtils::ConvertVirtualAddressToPhysical(dst_addr, dst_size, dst_addr) != SUCCESS) { + GELOGE(FAILED, "[ZCPY] Convert virtual address to physical for dst_addr failed."); + return FAILED; + } + GELOGI("[ZCPY] %s update task, args: %p, size: %zu, offset: %zu, addr: 0x%lx, length: %u", name_.c_str(), args_addr_, args_size_, offset, addr, data.length); *(uintptr_t *)(args_info + offset) = reinterpret_cast(dst_addr); diff --git a/src/ge/graph/load/output/output.cc b/src/ge/graph/load/output/output.cc new file mode 100644 index 00000000..d922ce7c --- /dev/null +++ b/src/ge/graph/load/output/output.cc @@ -0,0 +1,175 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/load/output/output.h" + +#include + +#include "common/properties_manager.h" +#include "graph/load/new_model_manager/davinci_model.h" +#include "graph/manager/graph_var_manager.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/utils/tensor_utils.h" + +namespace ge { +Output::Output(const OpDescPtr &op_desc, DavinciModel *model) + : base_(nullptr), + var_base_(nullptr), + logic_base_(0), + logic_var_base_(0), + model_(model), + op_desc_(op_desc), + input_num_(0) {} + +Output::~Output() { + var_base_ = nullptr; + base_ = nullptr; + model_ = nullptr; +} + +/// +/// @ingroup domi +/// @brief Initialize input/output params +/// @return Status +/// +Status Output::Init() { + if (op_desc_ == nullptr || model_ == nullptr) { + GELOGE(INTERNAL_ERROR, "The op_desc_ or model_ is nullptr."); + return INTERNAL_ERROR; + } + + base_ = model_->MemBase(); + var_base_ = model_->VarMemBase(); + logic_base_ = model_->GetRtBaseAddr(); + logic_var_base_ = model_->GetRtVarAddr(); + + input_num_ = op_desc_->GetInputsSize(); + v_input_size_.clear(); + v_input_data_addr_.clear(); + + auto input_vector = op_desc_->GetInputOffset(); + if (input_num_ != input_vector.size()) { + GELOGE(INTERNAL_ERROR, "input desc size: %zu != input offset size: %zu.", input_num_, input_vector.size()); + return INTERNAL_ERROR; + } + + for (size_t i = 0; i < input_num_; i++) { + int64_t tensor_size = 0; + auto input_desc = op_desc_->GetInputDescPtr(i); + GE_CHECK_NOTNULL(input_desc); + Status ret = TensorUtils::GetSize(*input_desc, tensor_size); + if (ret != 
GRAPH_SUCCESS) { + GELOGE(ret, "Get size from TensorDesc failed, op : %s, input index : %zu", op_desc_->GetName().c_str(), i); + return ret; + } + v_input_size_.push_back(tensor_size); + + if (VarManager::Instance(model_->SessionId())->IsVarAddr(input_vector[i])) { + v_input_data_addr_.push_back(static_cast(var_base_ + input_vector[i] - logic_var_base_)); + } else { + v_input_data_addr_.push_back(static_cast(base_ + input_vector[i])); + } + } + + GELOGI("Init output:%lu, %lu, %lu", input_num_, v_input_size_.size(), v_input_data_addr_.size()); + + return SUCCESS; +} + +/// +/// @ingroup domi +/// @brief Copy Op Output to user space. +/// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. +/// @return Status +/// +Status Output::CopyResult(OutputData &rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share) { + uint32_t data_count = 0; + if (input_num_ > rslt.blobs.size() - data_begin) { + GELOGE(FAILED, "Tensor num %zu, data_buf num: %zu.", input_num_, rslt.blobs.size() - data_begin); + return FAILED; + } else if (input_num_ < rslt.blobs.size() - data_begin) { + GELOGW("Tensor num %zu, data_buf num: %zu.", input_num_, rslt.blobs.size() - data_begin); + } + + for (size_t i = 0; i < input_num_; i++) { + DataBuffer data_buf = rslt.blobs[data_begin + data_count]; + Status ret = SetDataBuf(data_buf, data_count, i, support_mem_share); + if (ret != SUCCESS) { + GELOGE(ret, "Copy data to host error. index: %zu", i); + return ret; + } + data_index = data_begin + data_count; + } + + return SUCCESS; +} + +Status Output::SetDataBuf(DataBuffer &data_buf, uint32_t &data_count, size_t i, bool support_mem_share) { + if (data_buf.length == 0) { + ++data_count; + GELOGD("Length of data_buffer is zero, No need to copy. 
output op : %s, output tensor index : %zu!", + op_desc_->GetName().c_str(), i); + return SUCCESS; + } + + auto tensor_desc = op_desc_->GetInputDescPtr(static_cast(i)); + if (tensor_desc == nullptr) { + GELOGE(FAILED, "tensor_desc is null"); + return FAILED; + } + + if (data_buf.isDataSupportMemShare && support_mem_share) { + GELOGI("No need to copy input data, user's output data buffer can be shared."); + } else { + // Copy result to Databuf + int64_t size = v_input_size_[i]; + GELOGI("Tensor data size before: %ld", size); + + graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(*tensor_desc, size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(graph_status, "GetTensorSizeInBytes failed!"); + return FAILED; + } + + if (data_buf.length < size) { + GELOGE(FAILED, "Tensor data size: %ld data_buf length: %ld", size, data_buf.length); + return FAILED; + } else if (data_buf.length > size) { + GELOGW("Tensor data size: %ld data_buf length: %ld", size, data_buf.length); + } + + rtError_t rt_ret = rtMemcpy(data_buf.data, size, v_input_data_addr_[i], size, RT_MEMCPY_DEVICE_TO_HOST); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(rt_ret, "rtmemcpy error"); + return FAILED; + } + GELOGI("Tensor data size: %ld data_buf length: %ld", size, data_buf.length); + } + + ++data_count; + GELOGD("Successfully copy the output tensor memory to buffer, output op : %s, output tensor index : %zu!", + op_desc_->GetName().c_str(), i); + + return SUCCESS; +} + +void Output::GetOutputData(vector &v_data_addr, vector &v_data_size) { + for (size_t i = 0; i < input_num_; ++i) { + v_data_addr.push_back(v_input_data_addr_[i]); + v_data_size.push_back(v_input_size_[i]); + } +} +} // namespace ge diff --git a/src/ge/graph/load/output/output.h b/src/ge/graph/load/output/output.h new file mode 100644 index 00000000..d93b8de9 --- /dev/null +++ b/src/ge/graph/load/output/output.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, 
Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ +#define GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ + +#include +#include + +#include "common/debug/log.h" +#include "common/op/attr_value_util.h" +#include "common/op/ge_op_utils.h" +#include "common/types.h" +#include "common/util.h" +#include "common/ge_types.h" +#include "graph/load/new_model_manager/davinci_model.h" +#include "graph/op_desc.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +using std::string; +using std::vector; + +// The base class for all op +class Output { + public: + Output(const OpDescPtr &op_desc, DavinciModel *model); + virtual ~Output(); + + /// + /// @ingroup domi + /// @brief Initialize input/output params + /// @return Status + /// + virtual Status Init(); + + /// + /// @ingroup domi + /// @brief Copy Op Output to user space. + /// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. + /// @return Status + /// + virtual Status CopyResult(OutputData &rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); + + /// + /// @ingroup domi + /// @brief Trans Output data to fp16 + /// @return Status + /// + Status SetDataBuf(DataBuffer &data_buf, uint32_t &data_count, size_t i, bool support_mem_share); + + /// + /// @ingroup domi + /// @brief Get Output data and size. 
+ /// @return void + /// + void GetOutputData(vector &v_data_addr, vector &v_data_size); + + // Copy assignment operator and copy constructor are deleted + Output &operator=(const Output &output) = delete; + Output(const Output &output) = delete; + + protected: + // Model's base address + uint8_t *base_; + uint8_t *var_base_; + uint64_t logic_base_; + uint64_t logic_var_base_; + // The DavinciModel which ops belong to + DavinciModel *model_; + + ConstOpDescPtr op_desc_; + + // Input descriptions + size_t input_num_; + vector v_input_data_addr_; // init as:buf_base + op_def_->input(i)); + vector v_input_size_; +}; +} // namespace ge + +#endif // GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ diff --git a/src/ge/graph/manager/graph_caching_allocator.cc b/src/ge/graph/manager/graph_caching_allocator.cc index cbeafa3f..5df6769b 100644 --- a/src/ge/graph/manager/graph_caching_allocator.cc +++ b/src/ge/graph/manager/graph_caching_allocator.cc @@ -34,6 +34,9 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, 26 * kGByteSize}; static bool BlockComparator(const Block *left, const Block *right) { + if (left->device_id != right->device_id) { + return left->device_id < right->device_id; + } if (left->size != right->size) { return left->size < right->size; } @@ -264,20 +267,20 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { return ge::FAILED; } } - if (AddToBlockBin(memory_addr, memory_size, device_id) != ge::SUCCESS) { + if (AddToBlockBin(memory_addr, memory_size) != ge::SUCCESS) { (void)memory_allocator_->FreeMemory(memory_addr); return ge::FAILED; } return ge::SUCCESS; } -Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id) { +Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size) { BlockBin *bin = GetBlockBin(size); if (bin == nullptr) { GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); return ge::FAILED; } - Block *block = new (std::nothrow) Block(device_id, size, bin, nullptr); + Block 
*block = new (std::nothrow) Block(0, size, bin, nullptr); if (block == nullptr) { GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); return ge::FAILED; @@ -336,4 +339,5 @@ void CachingAllocator::FreeBlockBins() { } } } + } // namespace ge diff --git a/src/ge/graph/manager/graph_caching_allocator.h b/src/ge/graph/manager/graph_caching_allocator.h index 94a5066a..75864ce7 100644 --- a/src/ge/graph/manager/graph_caching_allocator.h +++ b/src/ge/graph/manager/graph_caching_allocator.h @@ -32,6 +32,7 @@ #include "runtime/mem.h" namespace ge { + constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold constexpr size_t kKByteSize = 1024; @@ -68,10 +69,6 @@ class CachingAllocator { public: explicit CachingAllocator(rtMemType_t memory_type); - CachingAllocator(const CachingAllocator &) = delete; - - CachingAllocator &operator=(const CachingAllocator &) = delete; - virtual ~CachingAllocator() = default; /// @@ -140,10 +137,9 @@ class CachingAllocator { /// @brief add memory to right bin based on size /// @param [in] memory ptr /// @param [in] memory size - /// @param [in] device_id device id /// @return Status result of function /// - Status AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id); + Status AddToBlockBin(uint8_t *ptr, size_t size); /// /// @ingroup ge_graph @@ -210,5 +206,7 @@ class CachingAllocator { // block bins by different block size BlockBin *free_block_bins_[kNumBins]; }; -} // namespace ge + +}; // namespace ge + #endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index bfd09c72..dd4855b6 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -57,6 +57,7 @@ #include "graph/passes/flow_ctrl_pass.h" #include "graph/passes/hccl_group_pass.h" #include 
"graph/passes/hccl_memcpy_pass.h" +#include "graph/passes/identify_reference_pass.h" #include "graph/passes/identity_pass.h" #include "graph/passes/iterator_op_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" @@ -73,9 +74,7 @@ #include "graph/passes/switch_data_edges_bypass.h" #include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_logic_remove_pass.h" -#include "graph/passes/merge_to_stream_merge_pass.h" -#include "graph/passes/switch_to_stream_switch_pass.h" -#include "graph/passes/attach_stream_label_pass.h" +#include "graph/passes/switch_op_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_depth_fusion_pass.h" #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" @@ -84,7 +83,6 @@ #include "graph/passes/transpose_transdata_pass.h" #include "graph/passes/variable_op_pass.h" #include "graph/passes/variable_prepare_op_pass.h" -#include "graph/passes/ref_identity_delete_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" #include "graph/utils/tensor_adapter.h" @@ -349,13 +347,12 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr return SUCCESS; } -#define GM_RUN_AND_DUMP_PERF(name, func, ...) \ +#define GM_RUN_AND_DUMP(name, func, ...) 
\ do { \ - GE_RUN_PERF(GraphManager, func, __VA_ARGS__); \ + GE_RUN(GraphManager, func, __VA_ARGS__); \ GE_DUMP(compute_graph, "PreRunAfter" name); \ GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ } while (0) - Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id) { GE_CHECK_NOTNULL(graph_node); @@ -368,30 +365,30 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetName().c_str()); GE_DUMP(compute_graph, "PreRunBegin"); - GM_RUN_AND_DUMP_PERF("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); - GM_RUN_AND_DUMP_PERF("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); - GM_RUN_AND_DUMP_PERF("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, - session_id); - GM_RUN_AND_DUMP_PERF("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); + GM_RUN_AND_DUMP("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); + GM_RUN_AND_DUMP("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); + GM_RUN_AND_DUMP("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, + session_id); + GM_RUN_AND_DUMP("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); - GM_RUN_AND_DUMP_PERF("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); - GM_RUN_AND_DUMP_PERF("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); + GM_RUN_AND_DUMP("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); + GM_RUN_AND_DUMP("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); GE_RUN(GraphManager, graph_preparer_.RecordAIPPInfo, compute_graph); if (IsTailingOptimization()) { - 
GM_RUN_AND_DUMP_PERF("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); + GM_RUN_AND_DUMP("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); } - GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph); - GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed); + GM_RUN_AND_DUMP("Optimize1", OptimizeStage1, compute_graph); + GM_RUN_AND_DUMP("InferShape2", compute_graph->InferShapeInNeed); const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); if (unknown_shape_skip != nullptr) { PassManager graph_pass; GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::CtrlEdgeTransferPass", new (std::nothrow) CtrlEdgeTransferPass)) GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); } - GE_CHK_STATUS_RET(graph_optimize_.IdentifyReference(compute_graph), "Identify reference failed."); - GM_RUN_AND_DUMP_PERF("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); - GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); - GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); + + GM_RUN_AND_DUMP("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); + GM_RUN_AND_DUMP("Optimize2", OptimizeStage2, compute_graph); + GM_RUN_AND_DUMP("Build", Build, graph_node, compute_graph, ge_root_model, session_id); // when set incre build, save om model and var manager GeModelPtr ge_model = nullptr; @@ -400,7 +397,7 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorSetRunFlag(false); @@ -637,7 +634,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vectorgraph_run_async_listener_); Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); - GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); + GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraphAsync"); if (ret != SUCCESS) { GELOGE(ret, "[LoadGraphAsync] 
LoadGraphAsync Failed"); graph_node->SetRunFlag(false); @@ -2331,21 +2309,21 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra GELOGE(FAILED, "failed get dynamic shape partitioned flag on partitioned graph."); return FAILED; } - GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); + GE_TIMESTAMP_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); GE_TIMESTAMP_START(GraphPartition); ret = graph_partitioner_.Partition(compute_graph, GraphPartitioner::kPartitioning); if (ret != SUCCESS) { GELOGE(ret, "Graph partition Failed"); return ret; } - GE_TIMESTAMP_EVENT_END(GraphPartition, "OptimizeSubgraph::Partition1"); + GE_TIMESTAMP_END(GraphPartition, "OptimizeSubgraph::Partition1"); GE_TIMESTAMP_START(SetSubgraph); ret = SetSubgraph(session_id, compute_graph); if (ret != SUCCESS) { GELOGE(ret, "Graph set subgraph Failed"); return ret; } - GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); + GE_TIMESTAMP_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); ComputeGraphPtr merged_compute_graph = nullptr; std::vector merged_sub_graph_list; @@ -2364,7 +2342,7 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra sub_graph->SetSessionID(session_id); sub_graph->SetGraphID(graph_node->GetGraphId()); } - GE_TIMESTAMP_EVENT_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); + GE_TIMESTAMP_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); GE_DUMP(merged_compute_graph, "mergedComputeGraph"); compute_graph = merged_compute_graph; if (!AttrUtils::SetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { @@ -2390,7 +2368,8 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp } bool is_always_dump = false; - if (!PropertiesManager::Instance().GetDumpProperties(session_id).GetDumpPath().empty()) { + PropertiesManager &properties_manager = 
PropertiesManager::Instance(); + if (!properties_manager.GetDumpOutputPath().empty()) { is_always_dump = true; } diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index fd9542e8..8ab28316 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -327,6 +327,6 @@ class GraphManager { std::mutex run_mutex_; }; -} // namespace ge +}; // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_H_ diff --git a/src/ge/graph/manager/graph_mem_allocator.h b/src/ge/graph/manager/graph_mem_allocator.h index e4eeded3..7bf82897 100644 --- a/src/ge/graph/manager/graph_mem_allocator.h +++ b/src/ge/graph/manager/graph_mem_allocator.h @@ -190,6 +190,6 @@ class MemManager { std::map caching_allocator_map_; std::recursive_mutex allocator_mutex_; }; -} // namespace ge +}; // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MEM_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc index 7ca0224b..2982eb89 100644 --- a/src/ge/graph/manager/graph_var_manager.cc +++ b/src/ge/graph/manager/graph_var_manager.cc @@ -91,7 +91,7 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen std::string var_key = VarKey(var_name, tensor_desc); GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str()); if (var_addr_mgr_map_.count(var_key) == 0) { - uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() + + uint64_t logic_address = VarManager::Instance(0)->GetVarMemLogicBase() + reinterpret_cast(reinterpret_cast(address)); GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(), TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), @@ -274,7 +274,7 @@ MemResource::MemResource() : total_size_(0), var_mem_size_(0) {} Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) { size = (size + kSessionMemAlignSize - 1) / 
kSessionMemAlignSize * kSessionMemAlignSize; uint64_t real_size = size; - total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize(); + total_size_ = VarManager::Instance(0)->GetVarMemMaxSize(); if (total_size_ < var_mem_size_) { GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_); return PARAM_INVALID; @@ -684,8 +684,7 @@ uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_ty if (mem_base == nullptr) { return nullptr; } - uint8_t *mem_addr = - logic_addr + reinterpret_cast(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase(); + uint8_t *mem_addr = logic_addr + reinterpret_cast(mem_base) - VarManager::Instance(0)->GetVarMemLogicBase(); return mem_addr; } diff --git a/src/ge/graph/manager/graph_var_manager.h b/src/ge/graph/manager/graph_var_manager.h index 2142d906..be839eee 100644 --- a/src/ge/graph/manager/graph_var_manager.h +++ b/src/ge/graph/manager/graph_var_manager.h @@ -309,5 +309,5 @@ class VarManagerPool { std::mutex var_manager_mutex_; map var_manager_map_; }; -} // namespace ge +}; // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_VAR_MANAGER_H_ diff --git a/src/ge/graph/manager/model_manager/event_manager.h b/src/ge/graph/manager/model_manager/event_manager.h index a20afead..bdf0633a 100644 --- a/src/ge/graph/manager/model_manager/event_manager.h +++ b/src/ge/graph/manager/model_manager/event_manager.h @@ -92,6 +92,6 @@ class EventManager { std::vector event_list_; bool inited_; uint32_t current_idx_; -}; // EventManager -} // namespace ge +}; // EventManager +}; // namespace ge #endif // GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ diff --git a/src/ge/graph/manager/trans_var_data_utils.cc b/src/ge/graph/manager/trans_var_data_utils.cc index 3f346c91..e8444c53 100644 --- a/src/ge/graph/manager/trans_var_data_utils.cc +++ b/src/ge/graph/manager/trans_var_data_utils.cc @@ -397,11 +397,10 @@ Status TransVarDataUtils::SyncTensorToHost(const string 
&var_name, const ge::GeT uint8_t *src_addr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); - uint8_t *mem_addr = - src_addr - - static_cast(reinterpret_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); + uint8_t *mem_addr = src_addr - + static_cast(reinterpret_cast(VarManager::Instance(0)->GetVarMemLogicBase())) + + static_cast( + reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(host_addr), src_tensor_size)); GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); @@ -414,11 +413,10 @@ Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8 const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { uint8_t *dst_addr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); - uint8_t *mem_addr = - dst_addr - - static_cast(reinterpret_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); + uint8_t *mem_addr = dst_addr - + static_cast(reinterpret_cast(VarManager::Instance(0)->GetVarMemLogicBase())) + + static_cast( + reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); diff --git a/src/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc index 5f31c982..4f6fe591 100644 --- a/src/ge/graph/manager/util/hcom_util.cc +++ b/src/ge/graph/manager/util/hcom_util.cc @@ -24,6 +24,7 @@ #include "graph/utils/type_utils.h" namespace ge { + Status 
HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); @@ -100,12 +101,6 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - // dynamic shape hccl op get size from output tensor desc - if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { - GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); - GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size), - "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - } GE_IF_BOOL_EXEC( op_desc->GetType() == HCOMREDUCESCATTER, int32_t rank_size = 0; @@ -119,8 +114,6 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType total_size = total_size + block_size; continue;); int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); - GELOGD("hcom util node %s inputsize %ld, shapesize %ld, datasize %d.", op_desc->GetName().c_str(), input_size, - shape_size, size); GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), "Product of shape size and size beyond INT64_MAX"); GE_IF_BOOL_EXEC(is_allgather, block_size = shape_size * size;); diff --git a/src/ge/graph/manager/util/hcom_util.h b/src/ge/graph/manager/util/hcom_util.h index e31e3ef0..40aac3e5 100644 --- a/src/ge/graph/manager/util/hcom_util.h +++ b/src/ge/graph/manager/util/hcom_util.h @@ -144,6 +144,8 @@ class HcomOmeUtil { /// static Status GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos); + + private: /// /// @ingroup domi_ome /// @brief GetHcomCount @@ -152,8 +154,6 @@ class HcomOmeUtil { /// static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool 
is_allgather, int &count); - - private: /// /// @ingroup domi_ome /// @brief GetHorovodCount diff --git a/src/ge/graph/manager/util/rt_context_util.cc b/src/ge/graph/manager/util/rt_context_util.cc index e6344539..05120f6a 100644 --- a/src/ge/graph/manager/util/rt_context_util.cc +++ b/src/ge/graph/manager/util/rt_context_util.cc @@ -19,30 +19,13 @@ #include "framework/common/debug/ge_log.h" namespace ge { -void RtContextUtil::AddRtContext(uint64_t session_id, rtContext_t context) { - std::lock_guard lock(ctx_mutex_); - rt_contexts_[session_id].emplace_back(context); -} - -void RtContextUtil::DestroyRtContexts(uint64_t session_id) { - std::lock_guard lock(ctx_mutex_); - auto &contexts = rt_contexts_[session_id]; - DestroyRtContexts(session_id, contexts); -} - -void RtContextUtil::DestroyAllRtContexts() { - std::lock_guard lock(ctx_mutex_); - for (auto &ctx_pair : rt_contexts_) { - DestroyRtContexts(ctx_pair.first, ctx_pair.second); - } - rt_contexts_.clear(); -} +void RtContextUtil::AddrtContext(rtContext_t context) { rtContexts_.emplace_back(context); } -void RtContextUtil::DestroyRtContexts(uint64_t session_id, std::vector &contexts) { - GELOGI("Runtime context handle number of session %lu is %zu.", session_id, contexts.size()); - for (auto &rtContext : contexts) { +void RtContextUtil::DestroyrtContexts() { + GELOGI("The size of runtime context handle is %zu.", rtContexts_.size()); + for (auto &rtContext : rtContexts_) { (void)rtCtxDestroy(rtContext); } - contexts.clear(); + rtContexts_.clear(); } } // namespace ge diff --git a/src/ge/graph/manager/util/rt_context_util.h b/src/ge/graph/manager/util/rt_context_util.h index 58cc0803..93db9882 100644 --- a/src/ge/graph/manager/util/rt_context_util.h +++ b/src/ge/graph/manager/util/rt_context_util.h @@ -18,8 +18,6 @@ #define GE_GRAPH_MANAGER_UTIL_RT_CONTEXT_UTIL_H_ #include -#include -#include #include "runtime/context.h" @@ -31,14 +29,13 @@ class RtContextUtil { return instance; } - void AddRtContext(uint64_t 
session_id, rtContext_t context); + void AddrtContext(rtContext_t context); const rtContext_t GetNormalModeContext() const { return before_prerun_ctx_; } void SetNormalModeContext(rtContext_t context) { before_prerun_ctx_ = context; } - void DestroyRtContexts(uint64_t session_id); - void DestroyAllRtContexts(); + void DestroyrtContexts(); RtContextUtil &operator=(const RtContextUtil &) = delete; RtContextUtil(const RtContextUtil &RtContextUtil) = delete; @@ -47,12 +44,8 @@ class RtContextUtil { RtContextUtil() = default; ~RtContextUtil() {} - void DestroyRtContexts(uint64_t session_id, std::vector &contexts); - - std::map> rt_contexts_; + std::vector rtContexts_; rtContext_t before_prerun_ctx_ = nullptr; - - std::mutex ctx_mutex_; }; } // namespace ge diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index 09acae33..b42c2e01 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ b/src/ge/graph/optimize/graph_optimize.cc @@ -299,36 +299,4 @@ void GraphOptimize::TranFrameOp(ComputeGraphPtr &compute_graph) { } } } - -Status GraphOptimize::IdentifyReference(ComputeGraphPtr &compute_graph) { - for (auto &node : compute_graph->GetAllNodes()) { - GE_CHECK_NOTNULL(node); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto input_name_index = op_desc->GetAllInputName(); - bool is_ref = false; - for (const auto &name_index : input_name_index) { - const int out_index = op_desc->GetOutputIndexByName(name_index.first); - if (out_index != -1) { - auto input_desc = op_desc->GetInputDesc(name_index.second); - input_desc.SetRefPortByIndex({name_index.second}); - op_desc->UpdateInputDesc(name_index.second, input_desc); - GELOGI("SetRefPort: set op[%s] input desc[%u-%s] ref.", op_desc->GetName().c_str(), name_index.second, - name_index.first.c_str()); - auto output_desc = op_desc->GetOutputDesc(static_cast(out_index)); - output_desc.SetRefPortByIndex({name_index.second}); - 
op_desc->UpdateOutputDesc(static_cast(out_index), output_desc); - GELOGI("SetRefPort: set op[%s] output desc[%u-%s] ref.", op_desc->GetName().c_str(), out_index, - name_index.first.c_str()); - is_ref = true; - } - } - if (is_ref) { - AttrUtils::SetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); - GELOGI("param [node] %s is reference node, set attribute %s to be true.", node->GetName().c_str(), - ATTR_NAME_REFERENCE.c_str()); - } - } - return SUCCESS; -} } // namespace ge diff --git a/src/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h index 9741814d..72709932 100644 --- a/src/ge/graph/optimize/graph_optimize.h +++ b/src/ge/graph/optimize/graph_optimize.h @@ -67,9 +67,6 @@ class GraphOptimize { // handle summary node before preRun graph Status HandleSummaryOp(ComputeGraphPtr &compute_graph); - // Identify reference node before optimize subgraph - Status IdentifyReference(ComputeGraphPtr &compute_graph); - void TranFrameOp(ComputeGraphPtr &compute_graph); private: @@ -88,5 +85,5 @@ class GraphOptimize { std::map> summary_output_indexes_ = {}; std::string func_bin_path_; }; -} // namespace ge +}; // namespace ge #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZE_H_ diff --git a/src/ge/graph/optimize/summary_optimize.cc b/src/ge/graph/optimize/summary_optimize.cc index a8325da3..8b38d602 100644 --- a/src/ge/graph/optimize/summary_optimize.cc +++ b/src/ge/graph/optimize/summary_optimize.cc @@ -80,8 +80,7 @@ Status GraphOptimize::HandleSummaryOp(ComputeGraphPtr &compute_graph) { del_nodes.emplace_back(node_ptr); } } - GE_IF_BOOL_EXEC(!summary_output_indexes.empty(), - summary_output_indexes_.insert({compute_graph->GetGraphID(), summary_output_indexes})); + summary_output_indexes_.insert({compute_graph->GetGraphID(), summary_output_indexes}); // add output nodes for summary std::vector> out_nodes_info; diff --git a/src/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc index 324129c4..6a396eef 100644 --- 
a/src/ge/graph/partition/dynamic_shape_partition.cc +++ b/src/ge/graph/partition/dynamic_shape_partition.cc @@ -62,16 +62,15 @@ Status DynamicShapePartitioner::Partition() { } GELOGD("Start dynamic shape partition graph %s.", root_graph_->GetName().c_str()); - REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes, root grah name:%s.", - root_graph_->GetName().c_str()); + REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes."); if (unknown_shape_nodes_.empty()) { GELOGD("Skip dynamic shape partition of graph %s as all nodes are known shape.", root_graph_->GetName().c_str()); REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), - "Failed set dynamic shape partitioned flag on root graph %s.", root_graph_->GetName().c_str()); + "Failed set dynamic shape partitioned flag on root graph."); return SUCCESS; } REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true), - "Failed set dynamic shape partitioned flag on root graph %s.", root_graph_->GetName().c_str()); + "Failed set dynamic shape partitioned flag on root graph."); DumpGraph("_Before_DSP"); auto status = PartitionImpl(); @@ -108,21 +107,21 @@ void DynamicShapePartitioner::PruneUniqueClusters() { } Status DynamicShapePartitioner::BuildPartitionFrame() { - for (const auto &cluster : unique_clusters_) { + for (auto cluster : unique_clusters_) { REQUIRE_SUCCESS(cluster->BuildFrame(), "Failed build frame of cluster[%lu].", cluster->Id()); } return SUCCESS; } Status DynamicShapePartitioner::CombinePartitionFrame() { - for (const auto &cluster : unique_clusters_) { + for (auto cluster : unique_clusters_) { REQUIRE_SUCCESS(cluster->CombinePartitionFrame(), "Failed combine frame of cluster[%lu].", cluster->Id()); } return SUCCESS; } Status DynamicShapePartitioner::BuildPartitionSubgraph() { - for (const auto &cluster : unique_clusters_) { + for (auto cluster : unique_clusters_) { 
REQUIRE_SUCCESS(cluster->BuildPartitionSubgraph(), "Failed build subgraph of cluster[%lu].", cluster->Id()); } return SUCCESS; @@ -135,10 +134,10 @@ std::string DynamicShapePartitioner::DebugString() const { size_t netoutput = 0; std::stringstream ss; ss << "All unknown shape nodes:" << std::endl; - for (const auto &node : unknown_shape_nodes_) { + for (auto node : unknown_shape_nodes_) { ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; } - for (const auto &cluster : unique_clusters_) { + for (auto cluster : unique_clusters_) { if (cluster->IsUnknownShape()) { unknown++; } else if (cluster->IsKnownShape()) { @@ -151,7 +150,7 @@ std::string DynamicShapePartitioner::DebugString() const { } ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known << ", unknown:" << unknown << ", netoutput:" << netoutput << std::endl; - for (const auto &cluster : unique_clusters_) { + for (auto cluster : unique_clusters_) { ss << " " << cluster->DebugString() << std::endl; } return ss.str(); @@ -159,13 +158,13 @@ std::string DynamicShapePartitioner::DebugString() const { void DynamicShapePartitioner::DumpGraph(const std::string &suffix) { GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); - for (const auto &sub_graph : root_graph_->GetAllSubgraphs()) { + for (auto sub_graph : root_graph_->GetAllSubgraphs()) { GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); } } void DynamicShapePartitioner::ClearResource() { - for (const auto &cluster : unique_clusters_) { + for (auto cluster : unique_clusters_) { cluster->Clear(); } node_2_cluster_.clear(); @@ -176,7 +175,8 @@ void DynamicShapePartitioner::ClearResource() { } Status DynamicShapePartitioner::MarkUnknownShapeNodes() { - for (auto &node : root_graph_->GetDirectNode()) { + auto graph = root_graph_; + for (auto &node : graph->GetDirectNode()) { REQUIRE_SUCCESS(CollectSpreadUnknownShapeNodes(node), "Failed collect spread 
unknown shape nodes %s.", node->GetName().c_str()); } @@ -186,7 +186,7 @@ Status DynamicShapePartitioner::MarkUnknownShapeNodes() { Status DynamicShapePartitioner::InitClusters() { auto graph = root_graph_; size_t rank = 0; - for (const auto &node : graph->GetDirectNode()) { + for (const auto node : graph->GetDirectNode()) { Cluster::Type type = Cluster::DATA; if (node->GetType() == DATA) { type = Cluster::DATA; @@ -208,7 +208,7 @@ Status DynamicShapePartitioner::InitClusters() { cluster->AddInput(node_2_cluster_[parent]); } } - for (const auto &node : graph->GetDirectNode()) { + for (const auto node : graph->GetDirectNode()) { GELOGD("Make cluster for node %s : %s.", node->GetName().c_str(), node_2_cluster_[node]->DebugString().c_str()); } return SUCCESS; @@ -220,8 +220,8 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { std::queue ready_clusters; std::unordered_map cluster_pending_count; std::unordered_set seen_clusters; - for (auto &iter : node_2_cluster_) { - auto cluster = iter.second; + for (auto iter = node_2_cluster_.begin(); iter != node_2_cluster_.end(); iter++) { + auto cluster = iter->second; if (seen_clusters.count(cluster) != 0) { continue; } @@ -242,7 +242,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { if (cluster->IsKnownShape()) { ordered_cluster_.push_back(cluster); } - for (const auto &out_cluster : cluster->Outputs()) { + for (auto out_cluster : cluster->Outputs()) { if (cluster_pending_count[out_cluster] > 0 && --cluster_pending_count[out_cluster] == 0) { ready_clusters.push(out_cluster); } @@ -273,16 +273,16 @@ static std::string ToString(const std::vector &clusters) { Status DynamicShapePartitioner::MergeClusters() { // Merge unknown shape clusters - for (const auto &cluster : ordered_cluster_) { - for (const auto &in_cluster : cluster->Inputs()) { + for (auto cluster : ordered_cluster_) { + for (auto in_cluster : cluster->Inputs()) { if (!in_cluster->IsUnknownShape()) { continue; } auto merged_clusters = 
cluster->MergeAllPathFrom(in_cluster); GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), ToString(merged_clusters).c_str()); - for (const auto &merged_cluster : merged_clusters) { - for (const auto &node : merged_cluster->Nodes()) { + for (auto merged_cluster : merged_clusters) { + for (auto node : merged_cluster->Nodes()) { node_2_cluster_[node] = cluster; } } @@ -291,7 +291,7 @@ Status DynamicShapePartitioner::MergeClusters() { REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); // Merge known shape clusters - for (const auto &cluster : ordered_cluster_) { + for (auto cluster : ordered_cluster_) { if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) { auto in_cluster = *(cluster->Inputs().begin()); in_cluster->Merge(cluster); @@ -299,13 +299,13 @@ Status DynamicShapePartitioner::MergeClusters() { continue; } - for (const auto &in_cluster : cluster->Inputs()) { + for (auto in_cluster : cluster->Inputs()) { if (!in_cluster->IsKnownShape()) { continue; } if (cluster->TryMerge(in_cluster)) { GELOGD("Success merge known shape cluster from %lu to %lu.", in_cluster->Id(), cluster->Id()); - for (const auto &node : in_cluster->Nodes()) { + for (auto node : in_cluster->Nodes()) { node_2_cluster_[node] = cluster; } } @@ -333,7 +333,7 @@ Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (IsUnknownShapeTensor(out_tensor)) { GELOGD("Collect node %s as unknown as output %lu is unknown.", node->GetName().c_str(), anchor_index); is_unknown = true; - auto anchor = node->GetOutDataAnchor(static_cast(anchor_index)); + auto anchor = node->GetOutDataAnchor(anchor_index); for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { if (peer_anchor != nullptr) { GELOGD("Collect node %s as has unknown input from %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), @@ -349,7 +349,7 @@ Status 
DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (IsUnknownShapeTensor(in_tensor)) { GELOGD("Collect node %s as unknown as input %lu is unknown.", node->GetName().c_str(), anchor_index); is_unknown = true; - auto anchor = node->GetInDataAnchor(static_cast(anchor_index)); + auto anchor = node->GetInDataAnchor(anchor_index); const auto peer_anchor = anchor->GetPeerOutAnchor(); if (peer_anchor != nullptr) { GELOGD("Collect node %s as has unknown output to %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), @@ -453,15 +453,15 @@ std::string Cluster::DebugString() const { } ss << "[" << id_ << "](size:" << nodes_.size() << ")"; ss << "(" << min_ << "," << max_ << ")("; - for (const auto &cluster : in_clusters_) { + for (auto cluster : in_clusters_) { ss << cluster->id_ << ","; } ss << ")->("; - for (const auto &cluster : out_clusters_) { + for (auto cluster : out_clusters_) { ss << cluster->id_ << ","; } ss << ")|"; - for (const auto &node : nodes_) { + for (auto node : nodes_) { ss << (node->GetName() + "|"); } return ss.str(); @@ -507,12 +507,12 @@ void Cluster::Merge(ClusterPtr other) { in_clusters_.erase(other); out_clusters_.erase(other); auto in_clusters = other->in_clusters_; - for (const auto &cluster : in_clusters) { + for (auto cluster : in_clusters) { cluster->RemoveOutput(other); cluster->AddOutput(shared_from_this()); } auto out_clusters = other->out_clusters_; - for (const auto &cluster : out_clusters) { + for (auto cluster : out_clusters) { cluster->RemoveInput(other); cluster->AddInput(shared_from_this()); } @@ -529,7 +529,7 @@ bool Cluster::TryMerge(ClusterPtr other) { while (!forward_reached.empty()) { auto current_cluster = forward_reached.front(); forward_reached.pop(); - for (const auto &cluster : current_cluster->out_clusters_) { + for (auto cluster : current_cluster->out_clusters_) { if (cluster->max_ == max_ && current_cluster != other) { return false; } else if (cluster->min_ < max_) { @@ -557,7 +557,7 @@ 
std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { while (!forward_reached_queue.empty()) { auto current_cluster = forward_reached_queue.front(); forward_reached_queue.pop(); - for (const auto &cluster : current_cluster->out_clusters_) { + for (auto cluster : current_cluster->out_clusters_) { if (cluster->min_ < max_ && cluster->max_ != max_ && forward_reached_clusters.count(cluster) == 0) { forward_reached_clusters.insert(cluster); forward_reached_queue.push(cluster); @@ -567,7 +567,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { while (!backward_reached_queue.empty()) { auto current_cluster = backward_reached_queue.front(); backward_reached_queue.pop(); - for (const auto &cluster : current_cluster->in_clusters_) { + for (auto cluster : current_cluster->in_clusters_) { if (cluster->max_ > other->min_ && cluster->max_ != other->max_ && backward_reached_clusters.count(cluster) == 0) { backward_reached_clusters.insert(cluster); @@ -578,7 +578,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { } } } - for (const auto &cluster : path_clusters) { + for (auto cluster : path_clusters) { Merge(cluster); } return path_clusters; @@ -598,11 +598,11 @@ void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) { }; InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) { - return partition_node_->GetInDataAnchor(static_cast(inputs_index_[anchor])); + return partition_node_->GetInDataAnchor(inputs_index_[anchor]); }; OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) { - return partition_node_->GetOutDataAnchor(static_cast(outputs_index_[anchor])); + return partition_node_->GetOutDataAnchor(outputs_index_[anchor]); }; InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); }; @@ -616,25 +616,22 @@ Status Cluster::BuildFrame() { auto node = nodes_.front(); auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != nullptr) { - for (const auto 
&peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; if (src_cluster->id_ != id_) { - REQUIRE_GRAPH_SUCCESS( - GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor), - "Failed remove edge from node %s index %d to node %s index %d.", - peer_out_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(peer_out_control_anchor), - in_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(in_control_anchor)); + auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; + GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor); control_inputs_.insert(src_cluster); src_cluster->control_outputs_.insert(peer_out_control_anchor); } } } if (IsData()) { - for (const auto &anchor : node->GetAllOutDataAnchors()) { + for (auto anchor : node->GetAllOutDataAnchors()) { AddFrameOutput(anchor); } } else { - for (const auto &anchor : node->GetAllInDataAnchors()) { + for (auto anchor : node->GetAllInDataAnchors()) { AddFrameInput(anchor); } } @@ -663,7 +660,7 @@ Status Cluster::BuildPartitionFrame() { "Failed set shape flag."); REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(graph, node), "Failed remove root graph node."); REQUIRE_GRAPH_SUCCESS(node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph."); - for (const auto &anchor : node->GetAllInDataAnchors()) { + for (auto anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { continue; // Skip overhang input. 
@@ -677,7 +674,7 @@ Status Cluster::BuildPartitionFrame() { } auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != nullptr) { - for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { if (peer_out_control_anchor == nullptr) { continue; } @@ -692,9 +689,9 @@ Status Cluster::BuildPartitionFrame() { } } } - for (const auto &anchor : node->GetAllOutDataAnchors()) { + for (auto anchor : node->GetAllOutDataAnchors()) { auto peer_in_anchors = anchor->GetPeerInDataAnchors(); - for (const auto &peer_in_anchor : peer_in_anchors) { + for (auto peer_in_anchor : peer_in_anchors) { auto src_cluster = partitioner_->node_2_cluster_[peer_in_anchor->GetOwnerNode()]; if (src_cluster->id_ != id_) { AddFrameOutput(anchor); @@ -720,7 +717,7 @@ Status Cluster::BuildPartitionFrame() { } Status Cluster::CombinePartitionFrame() { - for (const auto &anchor : inputs_) { + for (auto anchor : inputs_) { auto peer_out_anchor = anchor->GetPeerOutAnchor(); auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()]; auto src_anchor = src_cluster->GetFrameOutDataAnchor(peer_out_anchor); @@ -732,7 +729,7 @@ Status Cluster::CombinePartitionFrame() { src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(), dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx()); } - for (const auto &src_cluster : control_inputs_) { + for (auto src_cluster : control_inputs_) { auto src_anchor = src_cluster->GetFrameOutControlAnchor(); auto dst_anchor = GetFrameInControlAnchor(); REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "Failed add edge from %s:%d to %s:%d.", @@ -777,8 +774,8 @@ Status Cluster::BuildPartitionSubgraph() { REQUIRE_NOT_NULL(net_output_node, "Failed add netoutput node to subgraph."); REQUIRE_GRAPH_SUCCESS(net_output_node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph of netoutput 
node."); parent_node_index = 0; - for (const auto &anchor : outputs_) { - auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(static_cast(anchor->GetIdx())); + for (auto anchor : outputs_) { + auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()); REQUIRE(AttrUtils::SetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), "Failed set parent_node_index on subgraph netoutput's input."); REQUIRE_GRAPH_SUCCESS(net_output_op->UpdateInputDesc(parent_node_index, output_desc), @@ -789,7 +786,7 @@ Status Cluster::BuildPartitionSubgraph() { anchor->GetIdx()); parent_node_index++; } - for (const auto &anchor : control_outputs_) { + for (auto anchor : control_outputs_) { REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInControlAnchor()), "Faile add control edge from %s:%d to netoutput node.", anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx()); diff --git a/src/ge/graph/partition/engine_place.cc b/src/ge/graph/partition/engine_place.cc index 2d1a7f13..74da0326 100644 --- a/src/ge/graph/partition/engine_place.cc +++ b/src/ge/graph/partition/engine_place.cc @@ -38,7 +38,6 @@ Status EnginePlacer::Run() { return FAILED; } // Assign engine for each node in the graph - instance_ptr->DNNEngineManagerObj().InitPerformanceStaistic(); for (const auto &node_ptr : compute_graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node_ptr); GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); @@ -61,15 +60,12 @@ Status EnginePlacer::Run() { return FAILED; } } - for (auto &it : instance_ptr->DNNEngineManagerObj().GetCheckSupportCost()) { - GEEVENT("The time cost of %s::CheckSupported is [%lu] micro second.", it.first.c_str(), it.second); - } GELOGI("Engine placer ends."); return SUCCESS; } Status EnginePlacer::AssignEngineAndLog(ge::ConstNodePtr node_ptr, const std::string &engine_name) { - if ((node_ptr == nullptr) || (node_ptr->GetOpDesc() == nullptr)) { + if (node_ptr == nullptr || node_ptr->GetOpDesc() == nullptr) { 
GELOGE(FAILED, "node_ptr is null."); return FAILED; } diff --git a/src/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc index 907d672d..50cd7e81 100644 --- a/src/ge/graph/partition/graph_partition.cc +++ b/src/ge/graph/partition/graph_partition.cc @@ -25,7 +25,6 @@ #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" -#include "graph/common/ge_call_wrapper.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -232,33 +231,33 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co ComputeGraphPtr new_sub_graph = MakeShared(original_compute_graph->GetName()); GE_CHECK_NOTNULL(new_sub_graph); output_merged_compute_graph = new_sub_graph; - GE_TIMESTAMP_START(MergeSubGraphRemoveNode); + GE_TIMESTAMP_START(MergeGraphRemoveNode); if (RemoveNodeAndEdgeBetweenEndPld(output_merged_compute_graph, sub_graph_list) != ge::SUCCESS) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: merging sub-graphs failed"); return FAILED; } - GE_TIMESTAMP_END(MergeSubGraphRemoveNode, "GraphPartitioner::MergeGraphRemoveNodeAndEdge"); - GE_TIMESTAMP_START(MergeSubGraphTopologicalSorting); + GE_TIMESTAMP_END(MergeGraphRemoveNode, "GraphPartitioner::MergeGraphRemoveNodeAndEdge"); + GE_TIMESTAMP_START(MergeGraphTopologicalSorting); Status ret = output_merged_compute_graph->TopologicalSorting(); if (ret != SUCCESS) { GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[GraphPartitioner]: output_merged_compute_graph->TopologicalSorting failed"); return FAILED; } - GE_TIMESTAMP_END(MergeSubGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); + GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); // flush all nodes' engine of merged graph - GE_TIMESTAMP_START(MergeSubGraphEnginePlacerRun); + GE_TIMESTAMP_START(MergeGraphEnginePlacerRun); 
graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); if (graph_info_.engine_placer_.Run() != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); return FAILED; } - GE_TIMESTAMP_END(MergeSubGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); + GE_TIMESTAMP_END(MergeGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); GELOGI("Graph merge ends."); return SUCCESS; } Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &pld_op_desc) { - if ((dst_node == nullptr) || (pld_op_desc == nullptr) || (dst_node->GetOpDesc() == nullptr)) { + if (dst_node == nullptr || pld_op_desc == nullptr || dst_node->GetOpDesc() == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -276,7 +275,7 @@ Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_ } Status ge::GraphPartitioner::UpdateEndOpDesc(const NodePtr &src_node, int output_index, OpDescPtr &end_op_desc) { - if ((src_node == nullptr) || (end_op_desc == nullptr) || (src_node->GetOpDesc() == nullptr)) { + if (src_node == nullptr || end_op_desc == nullptr || src_node->GetOpDesc() == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -297,9 +296,9 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr const AnchorPtr &peer_in_anchor, const ge::ComputeGraphPtr &pld_graph, const ge::ComputeGraphPtr &end_graph) { + GE_CHECK_NOTNULL(out_anchor); GE_CHECK_NOTNULL(peer_in_anchor); GE_CHECK_NOTNULL(pld_graph); - GE_CHECK_NOTNULL(out_anchor); GE_CHECK_NOTNULL(end_graph); const auto &src_node = out_anchor->GetOwnerNode(); const auto &dst_node = peer_in_anchor->GetOwnerNode(); @@ -314,7 +313,6 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GELOGW("SetInt peerIndex failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()), GELOGW("SetStr 
parentOpType failed");) - GE_IF_BOOL_EXEC(!end_op_desc->SetExtAttr("parentNode", dst_node), GELOGW("SetEndExtAttr parentNode failed");) // replace input_desc of end with owner node's desc int output_index = ge::AnchorUtils::GetIdx(out_anchor); bool is_need_update_desc = (output_index >= 0) && (graph_info_.mode_ == kPartitioning); @@ -363,7 +361,6 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GELOGW("SetStr parentId failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "anchorIndex", AnchorUtils::GetIdx(out_anchor)), GELOGW("SetInt anchorIndex failed");) - GE_IF_BOOL_EXEC(!pld_op_desc->SetExtAttr("parentNode", src_node), GELOGW("SetPldExtAttr parentNode failed");) // do not care over flow graph_info_.num_of_pld_end_++; // replace output_desc of pld with input node's output desc @@ -398,14 +395,14 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr return FAILED; } graph_info_.index_2_end_[graph_info_.num_of_pld_end_] = new_end_node; - graph_info_.pld_2_end_[new_pld_node] = new_end_node; graph_info_.end_2_pld_[new_end_node] = new_pld_node; + graph_info_.pld_2_end_[new_pld_node] = new_end_node; return SUCCESS; } Status ge::GraphPartitioner::LinkInput2EndRemoveOrginalLink(ge::NodePtr input_node, ge::ComputeGraphPtr src_graph, ge::ComputeGraphPtr dst_graph) { - if ((input_node == nullptr) || (src_graph == nullptr) || (dst_graph == nullptr)) { + if (input_node == nullptr || src_graph == nullptr || dst_graph == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -445,7 +442,7 @@ Status ge::GraphPartitioner::LinkInput2EndRemoveOrginalLink(ge::NodePtr input_no Status ge::GraphPartitioner::PutInputNodesInSubGraph(const ge::ComputeGraphPtr &src_graph, const ge::ComputeGraphPtr &dst_graph) { - if ((src_graph == nullptr) || (dst_graph == nullptr)) { + if (src_graph == nullptr || dst_graph == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -852,34 +849,34 @@ 
Status ge::GraphPartitioner::PartitionSubGraph(ge::ComputeGraphPtr compute_graph GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[GraphPartitioner]: subGraphPtr->TopologicalSorting failed"); return FAILED; } - GE_TIMESTAMP_START(PartitionSubGraphInitialize); + GE_TIMESTAMP_START(GraphPartitionInitialize); if (Initialize(compute_graph) != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: initialize failed"); return FAILED; } - GE_TIMESTAMP_END(PartitionSubGraphInitialize, "GraphPartitioner::PartitionInitialize"); - GE_TIMESTAMP_START(PartitionSubGraphMarkClusters); + GE_TIMESTAMP_END(GraphPartitionInitialize, "GraphPartitioner::PartitionInitialize"); + GE_TIMESTAMP_START(GraphPartitionMarkClusters); MarkClusters(); - GE_TIMESTAMP_END(PartitionSubGraphMarkClusters, "GraphPartitioner::PartitionMarkClusters"); - GE_TIMESTAMP_START(PartitionSubGraphSplitSubGraphs); + GE_TIMESTAMP_END(GraphPartitionMarkClusters, "GraphPartitioner::PartitionMarkClusters"); + GE_TIMESTAMP_START(GraphPartitionSplitSubGraphs); if (SplitSubGraphs(compute_graph) != SUCCESS) { GELOGE(FAILED, "[GraphPartitioner]: SplitSubGraphs failed"); return FAILED; } - GE_TIMESTAMP_END(PartitionSubGraphSplitSubGraphs, "GraphPartitioner::PartitionSplitSubGraphs"); - GE_TIMESTAMP_START(PartitionSubGraphSortSubGraphs); + GE_TIMESTAMP_END(GraphPartitionSplitSubGraphs, "GraphPartitioner::PartitionSplitSubGraphs"); + GE_TIMESTAMP_START(GraphPartitionSortSubGraphs); if (SortSubGraphs(compute_graph) != ge::SUCCESS) { GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "Graph Partition SortSubGraphs failed."); return ge::FAILED; } - GE_TIMESTAMP_END(PartitionSubGraphSortSubGraphs, "GraphPartitioner::PartitionSortSubGraphs"); - GE_TIMESTAMP_START(PartitionSubGraphAddPartitionsToGraphNode); + GE_TIMESTAMP_END(GraphPartitionSortSubGraphs, "GraphPartitioner::PartitionSortSubGraphs"); + GE_TIMESTAMP_START(GraphPartitionAddPartitionsToGraphNode); vector output_subgraphs; if (AddPartitionsToGraphNode(output_subgraphs, compute_graph) != 
ge::SUCCESS) { GELOGE(GE_GRAPH_EMPTY_PARTITION, "Graph Partition AddPartitionsToGraphNode failed."); return ge::FAILED; } - GE_TIMESTAMP_END(PartitionSubGraphAddPartitionsToGraphNode, "GraphPartitioner::PartitionAddPartitionsToGraphNode"); + GE_TIMESTAMP_END(GraphPartitionAddPartitionsToGraphNode, "GraphPartitioner::PartitionAddPartitionsToGraphNode"); GELOGI("Graph Partition ends. Adding partitions to SubGraphInfo, got %zu sub graphs", output_subgraphs.size()); graph_info_.mode_ = kMerging; // do not care over flow @@ -926,7 +923,7 @@ Status ge::GraphPartitioner::AddPlaceHolderEnd(const AnchorPtr &out_anchor, cons Status ge::GraphPartitioner::SortSubGraphs(const ge::ComputeGraphPtr &compute_graph) { uint32_t rank = kRankOne; // rank 0 for data graph ComputeGraphPtr new_input_nodes_sub_graph = MakeShared("inputNodeGraph"); - if ((new_input_nodes_sub_graph == nullptr) || (compute_graph == nullptr)) { + if (new_input_nodes_sub_graph == nullptr || compute_graph == nullptr) { GELOGE(FAILED, "[GraphPartitioner]: new_input_nodes_sub_graph or compute_graph is null."); return FAILED; } @@ -968,7 +965,7 @@ Status ge::GraphPartitioner::SortSubGraphs(const ge::ComputeGraphPtr &compute_gr } AnchorPtr ge::GraphPartitioner::GetEndInAnchor(const AnchorPtr &src_anchor, const NodePtr &end_node) { - if ((src_anchor == nullptr) || (end_node == nullptr)) { + if (src_anchor == nullptr || end_node == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return nullptr; } @@ -982,7 +979,7 @@ AnchorPtr ge::GraphPartitioner::GetEndInAnchor(const AnchorPtr &src_anchor, cons } AnchorPtr ge::GraphPartitioner::GetPldOutAnchor(const NodePtr &pld_node, const AnchorPtr &dst_anchor) { - if ((pld_node == nullptr) || (dst_anchor == nullptr)) { + if (pld_node == nullptr || dst_anchor == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return nullptr; } @@ -995,16 +992,16 @@ AnchorPtr ge::GraphPartitioner::GetPldOutAnchor(const NodePtr &pld_node, const A return pld_out_anchor; } -void 
ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPtr &subgraph_info) { - if (subgraph_info == nullptr) { +void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPtr &sub_graph_info) { + if (sub_graph_info == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return; } - auto subgraph = subgraph_info->GetSubGraph(); - GE_CHECK_NOTNULL_JUST_RETURN(subgraph); + auto sub_graph = sub_graph_info->GetSubGraph(); + GE_CHECK_NOTNULL_JUST_RETURN(sub_graph); NodetoNodeMap end_map; NodetoNodeMap pld_map; - for (const auto &node : subgraph->GetDirectNode()) { + for (const auto &node : sub_graph->GetDirectNode()) { if (node->GetType() == kEndType) { end_map[node] = graph_info_.end_2_pld_.at(node); } @@ -1012,8 +1009,8 @@ void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPt pld_map[node] = graph_info_.pld_2_end_.at(node); } } - subgraph_info->SetEnd2PldMap(end_map); - subgraph_info->SetPld2EndMap(pld_map); + sub_graph_info->SetEnd2PldMap(end_map); + sub_graph_info->SetPld2EndMap(pld_map); } const Graph2SubGraphInfoList &ge::GraphPartitioner::GetSubGraphMap() { return graph_2_subgraph_list_; } diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc index ae69fd93..7d9b8dec 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/src/ge/graph/passes/atomic_addr_clean_pass.cc @@ -22,12 +22,16 @@ #include #include +#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/node_utils.h" #include "init/gelib.h" +namespace { +bool is_loop_graph = false; +} namespace ge { namespace { bool GraphShouldBeSkip(const ge::ComputeGraphPtr &graph) { @@ -40,6 +44,7 @@ bool GraphShouldBeSkip(const ge::ComputeGraphPtr &graph) { } // namespace Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { + GE_TIMESTAMP_START(AtomicAddrCleanPass); if 
(graph == nullptr) { GELOGE(PARAM_INVALID, "param [graph] must not be null."); return PARAM_INVALID; @@ -66,10 +71,10 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } atomic_node_vec.push_back(node); } - if (!is_loop_graph_ && node->GetType() == LOOPCOND) { + if (!is_loop_graph && node->GetType() == LOOPCOND) { // there is loop in this graph GELOGD("There is no loop node. It will insert clean node follow atomic node."); - is_loop_graph_ = true; + is_loop_graph = true; } } if (atomic_node_vec.empty()) { @@ -78,7 +83,7 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } // 2.Insert clean node and link to atomic node Status ret; - if (is_loop_graph_) { + if (is_loop_graph) { ret = HandleLoopGraph(graph, atomic_node_vec); if (ret != SUCCESS) { return ret; @@ -90,6 +95,7 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } } GELOGD("AtomicAddrCleanPass end."); + GE_TIMESTAMP_END(AtomicAddrCleanPass, "GraphManager::AtomicAddrCleanPass"); return SUCCESS; } @@ -166,14 +172,12 @@ NodePtr AtomicAddrCleanPass::InsertAtomicAddrCleanNode(ComputeGraphPtr &graph) { if (!session_graph_id.empty()) { (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); } - string node_name = op_desc->GetName(); // Only flush subgraph name - if (graph->GetParentGraph() != nullptr) { - node_name = graph->GetName() + "_" + node_name; - } + string node_name = (graph->GetParentGraph() != nullptr) + ? 
(graph->GetName() + "_" + op_desc->GetName() + session_graph_id) + : (op_desc->GetName() + session_graph_id); - string name = node_name + session_graph_id; - op_desc->SetName(name); + op_desc->SetName(node_name); GELOGI("Create cleanAddr op:%s.", op_desc->GetName().c_str()); // To avoid same name between graphs, set session graph id to this node NodePtr clean_addr_node = graph->AddNodeFront(op_desc); @@ -199,7 +203,7 @@ Status AtomicAddrCleanPass::LinkToAtomicNode(const NodePtr &atomic_node, NodePtr } GELOGD("Graph add cleanAddrNode op out ctrl edge, dst node: %s.", atomic_node->GetName().c_str()); std::string stream_label; - if (is_loop_graph_ && AttrUtils::GetStr(atomic_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { + if (is_loop_graph && AttrUtils::GetStr(atomic_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { if (!AttrUtils::SetStr(atomic_clean_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { GELOGW("LinkToAtomicNode: SetStr failed"); return INTERNAL_ERROR; @@ -258,7 +262,7 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { return true; } /// -/// @brief Clear Status, used for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status AtomicAddrCleanPass::ClearStatus() { diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.h b/src/ge/graph/passes/atomic_addr_clean_pass.h index 3640beef..d2d8f2ce 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.h +++ b/src/ge/graph/passes/atomic_addr_clean_pass.h @@ -75,7 +75,6 @@ class AtomicAddrCleanPass : public GraphPass { bool IsAtomicOp(const NodePtr &node); vector hcom_node_vec_; - bool is_loop_graph_ = false; }; } // namespace ge diff --git a/src/ge/graph/passes/attach_stream_label_pass.cc b/src/ge/graph/passes/attach_stream_label_pass.cc deleted file mode 100644 index 0c342d8c..00000000 --- a/src/ge/graph/passes/attach_stream_label_pass.cc +++ /dev/null @@ -1,319 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd -
* - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/attach_stream_label_pass.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" - -namespace ge { -Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { - GELOGD("AttachStreamLabelPass Enter."); - - FindNodes(graph); - for (const auto &node : need_label_nodes_) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { - GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); - } - } - GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); - - GELOGD("AttachStreamLabelPass Leave."); - return SUCCESS; -} - -/// -/// @brief Clear Status, used for subgraph pass -/// @return -/// -Status AttachStreamLabelPass::ClearStatus() { - stream_switch_nodes_.clear(); - need_label_nodes_.clear(); - enter_nodes_.clear(); - branch_head_nodes_.clear(); - return SUCCESS; -} - -/// -/// @brief Find StreamSwitch / StreamMerge / Enter node -/// @param [in] graph -/// @return void -/// -void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { - for (const NodePtr &node : graph->GetDirectNode()) { - const std::string &type = node->GetType(); - if (type == STREAMSWITCH) { - stream_switch_nodes_.emplace_back(node); - } else if (type == STREAMMERGE) { - if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - 
need_label_nodes_.emplace_back(node); - } - } else if ((type == ENTER) || (type == REFENTER)) { - enter_nodes_.emplace_back(node); - } - } - - for (const auto &node : stream_switch_nodes_) { - for (const auto &out_ctrl_node : node->GetOutControlNodes()) { - MarkHeadNodes(out_ctrl_node, node); - } - need_label_nodes_.emplace_back(node); - } -} - -/// -/// @brief Mark node as head_node of stream_switch -/// @param [in] node -/// @param [in] stream_switch -/// @return void -/// -void AttachStreamLabelPass::MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch) { - static const std::set bypass_type_set = {IDENTITY, IDENTITYN, CAST, TRANSDATA, - TRANSPOSE, TRANSPOSED, RESHAPE}; - std::stack nodes; - nodes.push(node); - std::set visited; - while (!nodes.empty()) { - NodePtr cur_node = nodes.top(); - nodes.pop(); - if (visited.count(cur_node) > 0) { - continue; - } - GELOGD("branch_head_node %s of stream_switch %s.", cur_node->GetName().c_str(), stream_switch->GetName().c_str()); - branch_head_nodes_[cur_node] = stream_switch; - if (bypass_type_set.count(cur_node->GetType()) > 0) { - for (const auto &out_node : cur_node->GetOutAllNodes()) { - nodes.push(out_node); - } - } - visited.insert(cur_node); - } -} - -/// -/// @brief update cond branch -/// @param [in] node -/// @return Status -/// -Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { - std::string stream_label; - std::unordered_set branch_nodes; - std::unordered_set visited; - std::stack nodes; - nodes.push(node); - - static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; - bool merge_flag = false; - bool exit_flag = false; - bool net_output_flag = false; - while (!nodes.empty()) { - NodePtr cur_node = nodes.top(); - nodes.pop(); - if (visited.count(cur_node) > 0) { - continue; - } - if (AttachFlag(cur_node, stream_label, merge_flag, exit_flag, net_output_flag) != SUCCESS) { - GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); - return 
FAILED; - } - - const std::string &type = cur_node->GetType(); - for (const auto &out_node : cur_node->GetOutAllNodes()) { - const std::string &out_type = out_node->GetType(); - bool stop_flag = (end_type_set.count(out_type) > 0) || - ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || - (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); - if (!stop_flag) { - nodes.push(out_node); - GELOGD("Insert branch node %s.", out_node->GetName().c_str()); - branch_nodes.insert(out_node); - } - } - visited.insert(cur_node); - } - - if (node->GetType() == STREAMSWITCH) { - GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); - } - - bool attach_flag = (merge_flag || exit_flag) && net_output_flag; - if (attach_flag) { - GELOGI("No need to keep on attaching label."); - return SUCCESS; - } - - for (const NodePtr &tmp_node : branch_nodes) { - GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); - GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); - } - - return SUCCESS; -} - -/// -/// @brief attach flag -/// @param [in] node -/// @param [out] stream_label -/// @param [out] merge_flag -/// @param [out] exit_flag -/// @param [out] net_output_flag -/// @return Status -/// -Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, - bool &exit_flag, bool &net_output_flag) { - const std::string &type = node->GetType(); - if (type == STREAMSWITCH) { - if (node->GetInDataNodes().empty()) { - GELOGE(INTERNAL_ERROR, "node %s has no input_data_node.", node->GetName().c_str()); - return INTERNAL_ERROR; - } - stream_label = node->GetInDataNodes().at(0)->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); - bool value = false; - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - 
GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, - "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); - stream_label += (value ? "_t" : "_f"); - } else if (type == STREAMMERGE) { - stream_label = node->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); - merge_flag = true; - } else if ((type == EXIT) || (type == REFEXIT)) { - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); - exit_flag = true; - } else if (type == NETOUTPUT) { - net_output_flag = true; - } - - return SUCCESS; -} - -/// -/// @brief Update stream_label start with enter nodes -/// @return Status -/// -Status AttachStreamLabelPass::UpdateEnterNode() { - std::unordered_map> enter_active_map; - for (const auto &enter_node : enter_nodes_) { - for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { - if (out_ctrl_node->GetType() != STREAMACTIVE) { - continue; - } - auto iter = enter_active_map.find(out_ctrl_node); - if (iter == enter_active_map.end()) { - enter_active_map[out_ctrl_node] = {enter_node}; - } else { - iter->second.emplace_back(enter_node); - } - } - } - - for (const auto &pair : enter_active_map) { - if (SetEnterLabel(pair.second, pair.first) != SUCCESS) { - GELOGE(FAILED, "Set stream_label for enter_nodes failed."); - return FAILED; - } - - NodePtr active_node = pair.first; - GE_CHECK_NOTNULL(active_node); - std::vector active_label_list; - if (!AttrUtils::GetListStr(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) || - (active_label_list.size() != 1) || active_label_list[0].empty()) { - GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST failed, node: %s.", active_node->GetName().c_str()); - return INTERNAL_ERROR; - } - - std::stack enter_nodes; - for (const auto &enter_node : pair.second) { - enter_nodes.emplace(enter_node); - } - if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { - 
GELOGE(FAILED, "Update stream_label for loop_branch failed."); - return FAILED; - } - } - - return SUCCESS; -} - -/// -/// @brief Set stream_label for enter_nodes -/// @param [in] enter_nodes -/// @param [in] active_node -/// @return Status -/// -Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node) { - std::string stream_label; - GE_CHECK_NOTNULL(active_node); - (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); - - bool same_flag = true; - for (const auto &enter_node : enter_nodes) { - std::string tmp_label; - (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, tmp_label); - if (tmp_label.empty() || (stream_label == tmp_label)) { - continue; - } - same_flag = false; - break; - } - - if (stream_label.empty()) { - if (same_flag) { - stream_label = active_node->GetName(); - } else { - GELOGW("stream_label of enter_active is empty while stream_label of some enter_node is not."); - return SUCCESS; - } - } - - for (const auto &enter_node : enter_nodes) { - GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); - } - GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); - return SUCCESS; -} - -/// -/// @brief Update stream_label for loop_branch -/// @param [in] enter_nodes -/// @param [in] stream_label -/// @return Status -/// -Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, - const std::string &stream_label) { - std::stack nodes(enter_nodes); - NodePtr cur_node = nullptr; - while (!nodes.empty()) { - cur_node = nodes.top(); - nodes.pop(); - for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { - OpDescPtr out_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(out_desc); - std::string out_type = out_desc->GetType(); - if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER)) { - continue; - } - GELOGD("Attach label %s to 
node: %s.", stream_label.c_str(), out_node->GetName().c_str()); - GE_CHK_STATUS_RET(SetStreamLabel(out_node, stream_label), "Set stream label failed."); - nodes.push(out_node); - } - } - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/attach_stream_label_pass.h b/src/ge/graph/passes/attach_stream_label_pass.h deleted file mode 100644 index 743ce36e..00000000 --- a/src/ge/graph/passes/attach_stream_label_pass.h +++ /dev/null @@ -1,97 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ -#define GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ - -#include -#include "inc/graph_pass.h" - -namespace ge { -class AttachStreamLabelPass : public GraphPass { - public: - Status Run(ComputeGraphPtr graph); - - /// - /// @brief Clear Status, used for subgraph pass - /// @return - /// - Status ClearStatus() override; - - private: - /// - /// @brief Find StreamSwitch / StreamMerge / Enter node - /// @param [in] graph - /// @return void - /// - void FindNodes(const ComputeGraphPtr &graph); - - /// - /// @brief Mark node as head_node of stream_switch - /// @param [in] node - /// @param [in] stream_switch - /// @return void - /// - void MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch); - - /// - /// @brief update cond branch - /// @param [in] node - /// @return Status - /// - Status UpdateCondBranch(const NodePtr &node); - - /// - /// @brief attach flag - /// @param [in] node - /// @param [out] stream_label - /// @param [out] merge_flag - /// @param [out] exit_flag - /// @param [out] net_output_flag - /// @return Status - /// - static Status AttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, - bool &net_output_flag); - - /// - /// @brief Update stream_label for loop_branch - /// @param [in] enter_nodes - /// @param [in] stream_label - /// @return Status - /// - static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); - - /// - /// @brief Update stream_label start with enter nodes - /// @return Status - /// - Status UpdateEnterNode(); - - /// - /// @brief Set stream_label for enter_nodes - /// @param [in] enter_nodes - /// @param [in] active_node - /// @return Status - /// - static Status SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node); - - std::vector stream_switch_nodes_; - std::vector need_label_nodes_; - std::vector enter_nodes_; - std::unordered_map branch_head_nodes_; -}; -} // 
namespace ge -#endif // GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index f7ff941c..d18c4b4e 100644 --- a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -69,6 +69,7 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op auto begin_out_desc = begin_op_desc->MutableOutputDesc(0); DataType begin_out_datatype = begin_out_desc->GetDataType(); + if (begin_out_datatype == end_out_datatype && (begin_out_datatype == DT_FLOAT16 || begin_out_datatype == DT_FLOAT)) { type = begin_out_datatype; return true; diff --git a/src/ge/graph/passes/common_subexpression_elimination_pass.cc b/src/ge/graph/passes/common_subexpression_elimination_pass.cc index 18f2e857..a52535c1 100644 --- a/src/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/src/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -83,7 +83,6 @@ Status CommonSubexpressionEliminationPass::Run(ComputeGraphPtr graph) { continue; } auto key = GetCseKey(node); - GELOGD("The node %s cse key %s", node->GetName().c_str(), key.c_str()); auto iter = keys_to_node.find(key); if (iter == keys_to_node.end()) { keys_to_node[key] = node; diff --git a/src/ge/graph/passes/compile_nodes_pass.cc b/src/ge/graph/passes/compile_nodes_pass.cc index 330569a2..def7655e 100644 --- a/src/ge/graph/passes/compile_nodes_pass.cc +++ b/src/ge/graph/passes/compile_nodes_pass.cc @@ -23,7 +23,6 @@ #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" -#include "graph/common/ge_call_wrapper.h" #include "graph/op_desc.h" using domi::ImplyType; @@ -79,7 +78,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { return result; } GELOGI("[CompileNodesPass]: Optimize success."); - GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); + 
GE_TIMESTAMP_END(CompileNodesPass, "GraphManager::CompileNodesPass"); return GRAPH_SUCCESS; } @@ -102,6 +101,7 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: } } OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); + if (kernel_info == nullptr) { GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node->GetName().c_str()); return ge::GE_GRAPH_PARAM_NULLPTR; diff --git a/src/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc index 2f3f9333..651cf98b 100644 --- a/src/ge/graph/passes/cond_pass.cc +++ b/src/ge/graph/passes/cond_pass.cc @@ -226,7 +226,7 @@ Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnc return FAILED; } - if (GraphUtils::InsertNodeAfter(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", cast_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; @@ -271,7 +271,7 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr } AddRePassNode(new_node); - if (GraphUtils::InsertNodeAfter(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), new_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; diff --git a/src/ge/graph/passes/cond_remove_pass.cc b/src/ge/graph/passes/cond_remove_pass.cc index 1650be92..8bc34fbc 100644 --- a/src/ge/graph/passes/cond_remove_pass.cc +++ b/src/ge/graph/passes/cond_remove_pass.cc @@ -225,40 +225,41 @@ bool CondRemovePass::CheckIfCondConstInput(const OutDataAnchorPtr 
&cond_out_anch Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch) { // Add compute graph to new node - const auto &input_desc_size = node->GetOpDesc()->GetInputsSize(); - const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize(); + const auto &input_anchors = node->GetAllInAnchors(); + const auto &output_anchors = node->GetAllOutAnchors(); // Create subgraph opdesc & node auto partitioncall_opdesc = - CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size); + CreateSubgraphOpDesc(save_branch->GetName(), input_anchors.size() - kConditionIndexNum, output_anchors.size()); auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc); // Link node's peerout anchors to new node's inanchors - for (const auto &input_anchor : node->GetAllInAnchors()) { + for (const auto &input_anchor : input_anchors) { for (const auto &peerout_anchor : input_anchor->GetPeerAnchors()) { if (GraphUtils::AddEdge(peerout_anchor, partitioncall_node->GetInAnchor( input_anchor->GetIdx() - kConditionIndexNum)) != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d", peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(), - partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_desc_size, output_desc_size); + partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_anchors.size(), + output_anchors.size()); return FAILED; } } } // Remove If / Case anchor and peer in anchor // Link new node's out anchors to node's peer inanchors - for (const auto &output_anchor : node->GetAllOutAnchors()) { + for (const auto &output_anchor : output_anchors) { for (const auto &peerin_anchor : output_anchor->GetPeerAnchors()) { if (GraphUtils::RemoveEdge(node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Remove edge 
failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d", node->GetName().c_str(), output_anchor->GetIdx(), peerin_anchor->GetOwnerNode()->GetName().c_str(), - peerin_anchor->GetIdx(), input_desc_size, output_desc_size); + peerin_anchor->GetIdx(), input_anchors.size(), output_anchors.size()); return FAILED; } if (GraphUtils::AddEdge(partitioncall_node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d", partitioncall_node->GetName().c_str(), output_anchor->GetIdx(), - peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_desc_size, - output_desc_size); + peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_anchors.size(), + output_anchors.size()); return FAILED; } } diff --git a/src/ge/graph/passes/constant_folding_pass.cc b/src/ge/graph/passes/constant_folding_pass.cc index 80bf7867..3ac7feb6 100644 --- a/src/ge/graph/passes/constant_folding_pass.cc +++ b/src/ge/graph/passes/constant_folding_pass.cc @@ -29,18 +29,6 @@ #include "inc/kernel.h" namespace ge { -const int64_t kStartCallNum = 1; - -const std::unordered_map> - &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { - return statistic_of_ge_constant_folding_; -} - -const std::unordered_map> - &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { - return statistic_of_op_constant_folding_; -} - Status ConstantFoldingPass::Run(ge::NodePtr &node) { GE_CHECK_NOTNULL(node); GELOGD("Begin to run constant folding on node %s", node->GetName().c_str()); @@ -62,8 +50,6 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { auto inputs = OpDescUtils::GetInputData(input_nodes); vector outputs; - // Statistic of ge constant folding kernel - uint64_t start_time = GetCurrentTimestap(); auto ret = RunOpKernel(node, inputs, outputs); if (ret != SUCCESS) { auto op_kernel = 
folding_pass::GetKernelByType(node); @@ -73,18 +59,7 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { return SUCCESS; } - // Statistic of op and fe constant folding kernel - start_time = GetCurrentTimestap(); ret = op_kernel->Compute(node_desc, inputs, outputs); - uint64_t cost_time = GetCurrentTimestap() - start_time; - if (statistic_of_ge_constant_folding_.find(node->GetType()) != statistic_of_ge_constant_folding_.end()) { - uint64_t &cnt = statistic_of_ge_constant_folding_[node->GetType()].first; - uint64_t &cur_cost_time = statistic_of_ge_constant_folding_[node->GetType()].second; - cnt++; - cur_cost_time += cost_time; - } else { - statistic_of_ge_constant_folding_[node->GetType()] = std::pair(kStartCallNum, cost_time); - } if (ret != SUCCESS) { if (ret == NOT_CHANGED) { GELOGD("Node %s type %s, compute terminates and exits the constant folding.", node->GetName().c_str(), @@ -95,16 +70,6 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { return ret; } GELOGI("Node %s type %s, constant folding compute success.", node->GetName().c_str(), node->GetType().c_str()); - } else { - if (statistic_of_op_constant_folding_.find(node->GetType()) != statistic_of_op_constant_folding_.end()) { - uint64_t &cnt = statistic_of_op_constant_folding_[node->GetType()].first; - uint64_t &cost_time = statistic_of_op_constant_folding_[node->GetType()].second; - cnt++; - cost_time += GetCurrentTimestap() - start_time; - } else { - statistic_of_op_constant_folding_[node->GetType()] = - std::pair(kStartCallNum, GetCurrentTimestap() - start_time); - } } if (outputs.empty()) { diff --git a/src/ge/graph/passes/constant_folding_pass.h b/src/ge/graph/passes/constant_folding_pass.h index 683b66f1..1dcbcdc3 100644 --- a/src/ge/graph/passes/constant_folding_pass.h +++ b/src/ge/graph/passes/constant_folding_pass.h @@ -26,12 +26,6 @@ namespace ge { class ConstantFoldingPass : public FoldingPass { public: Status Run(ge::NodePtr &node) override; - const std::unordered_map> 
&GetGeConstantFoldingPerfStatistic() const; - const std::unordered_map> &GetOpConstantFoldingPerfStatistic() const; - - private: - std::unordered_map> statistic_of_op_constant_folding_; - std::unordered_map> statistic_of_ge_constant_folding_; }; } // namespace ge diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc index 0c00d553..77fcbd69 100644 --- a/src/ge/graph/passes/control_trigger_pass.cc +++ b/src/ge/graph/passes/control_trigger_pass.cc @@ -15,9 +15,16 @@ */ #include "graph/passes/control_trigger_pass.h" + #include + #include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" #include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" namespace ge { @@ -437,7 +444,7 @@ Status ControlTriggerPass::FindPredInput(const NodePtr &switch_node) { return SUCCESS; } /// -/// @brief Clear Status, used for subgraph pass +/// @brief Clear Status, uesd for subgraph pass /// @return SUCCESS /// Status ControlTriggerPass::ClearStatus() { diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc index a9b3484b..5325f56e 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -28,7 +28,6 @@ namespace { const int32_t kAnchorSize = 1; const int kAnchorNum = 0; -const char *const kInputMutable = "_input_mutable"; } // namespace namespace ge { Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { @@ -36,16 +35,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - - bool node_input_mutable = false; - if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { - continue; - } - - GE_IF_BOOL_EXEC(!AttrUtils::GetBool(op_desc, 
kInputMutable, node_input_mutable), - GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); - return FAILED); - if (!node_input_mutable) { + if (!NeedInsertMemcpyOp(op_desc)) { continue; } @@ -63,7 +53,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { NodePtr src_node = src_out_anchor->GetOwnerNode(); std::string src_type = src_node->GetType(); bool check_src_type = (src_type == CONSTANTOP) || (src_type == DATA) || (src_type == CONSTANT); - if (check_src_type) { + if (check_src_type && node->GetType() == HCOMALLREDUCE) { Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to modify the connection."); @@ -145,6 +135,16 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { return tmp_name; } +/// +/// @brief Check hcom op +/// @param [in] ge::ConstOpDescPtr op_desc +/// @return bool +/// +bool HcclMemcpyPass::NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const { + return (op_desc->GetType() == HCOMALLGATHER || op_desc->GetType() == HCOMALLREDUCE || + op_desc->GetType() == HCOMREDUCESCATTER); +} + /// /// @brief Modify edge connection /// @param [in] ComputeGraphPtr graph @@ -182,7 +182,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const return SUCCESS; } /// -/// @brief Clear Status, used for subgraph pass +/// @brief Clear Status, uesd for subgraph pass /// @return SUCCESS /// Status HcclMemcpyPass::ClearStatus() { diff --git a/src/ge/graph/passes/hccl_memcpy_pass.h b/src/ge/graph/passes/hccl_memcpy_pass.h index 13863bd6..9de96fbf 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.h +++ b/src/ge/graph/passes/hccl_memcpy_pass.h @@ -34,6 +34,8 @@ class HcclMemcpyPass : public GraphPass { std::string CheckDuplicateName(const std::string &node_name); + bool NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const; + Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const 
OutDataAnchorPtr &src_out_anchor, const InDataAnchorPtr &hccl_in_anchor); diff --git a/src/ge/graph/passes/identify_reference_pass.cc b/src/ge/graph/passes/identify_reference_pass.cc new file mode 100644 index 00000000..92f7e7b6 --- /dev/null +++ b/src/ge/graph/passes/identify_reference_pass.cc @@ -0,0 +1,52 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/identify_reference_pass.h" + +#include +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" + +namespace ge { +Status IdentifyReferencePass::Run(NodePtr &node) { + if (node == nullptr) { + GELOGE(PARAM_INVALID, "param [node] must not be null."); + return PARAM_INVALID; + } + auto op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(PARAM_INVALID, "OpDesc of param [node] must not be null."); + return PARAM_INVALID; + } + + auto input_names = op_desc->GetAllInputNames(); + auto outputs = op_desc->GetAllOutputName(); + for (auto &output : outputs) { + for (auto &input_name : input_names) { + if (input_name == output.first) { + bool is_ref = true; + if (AttrUtils::SetBool(op_desc, ATTR_NAME_REFERENCE, is_ref)) { + GELOGI("param [node] %s is reference node, set attribute %s to be true.", node->GetName().c_str(), + ATTR_NAME_REFERENCE.c_str()); + return SUCCESS; + } + } + } + } + + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/mark_same_addr_pass.h 
b/src/ge/graph/passes/identify_reference_pass.h similarity index 62% rename from src/ge/graph/passes/mark_same_addr_pass.h rename to src/ge/graph/passes/identify_reference_pass.h index ebfcf6b2..5f284b4c 100644 --- a/src/ge/graph/passes/mark_same_addr_pass.h +++ b/src/ge/graph/passes/identify_reference_pass.h @@ -14,19 +14,16 @@ * limitations under the License. */ -#include "graph/graph.h" -#include "inc/graph_pass.h" +#ifndef GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ +#define GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ -#ifndef GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ -#define GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ +#include "graph/passes/base_pass.h" namespace ge { -class MarkSameAddrPass : public GraphPass { +class IdentifyReferencePass : public BaseNodePass { public: - Status Run(ComputeGraphPtr graph); - - private: - bool IsNextNodeExpected(const ge::NodePtr &cur_node, const vector &next_nodes, int &out_anchor_idx); + Status Run(NodePtr &node) override; }; } // namespace ge -#endif // GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ + +#endif // GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ diff --git a/src/ge/graph/passes/infershape_pass.cc b/src/ge/graph/passes/infershape_pass.cc index 8b44d31b..18767cea 100644 --- a/src/ge/graph/passes/infershape_pass.cc +++ b/src/ge/graph/passes/infershape_pass.cc @@ -15,7 +15,7 @@ */ #include "graph/passes/infershape_pass.h" -#include "common/util/error_manager/error_manager.h" + #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/shape_refiner.h" @@ -24,8 +24,6 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E35003", {"opname", "err_msg"}, - {node->GetName(), "check your model!"}); GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. 
node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc index 1d11004d..e1d452b1 100644 --- a/src/ge/graph/passes/iterator_op_pass.cc +++ b/src/ge/graph/passes/iterator_op_pass.cc @@ -73,14 +73,14 @@ Status IteratorOpPass::Run(ge::ComputeGraphPtr graph) { GE_IF_BOOL_EXEC(status != SUCCESS, GELOGW("Fail to Get var_desc of NODE_NAME_FLOWCTRL_LOOP_PER_ITER failed."); continue); Status ret; - ret = SetRtContext(graph->GetSessionID(), rtContext_t(), RT_CTX_NORMAL_MODE); + ret = SetRtContext(rtContext_t(), RT_CTX_NORMAL_MODE); // EOS will not be considered if ret is not SUCCESS. - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGW("Set rt context RT_CTX_NORMAL_MODE failed."); continue); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGW("Set rt context RT_CTX_GEN_MODE failed."); continue); status = GetVariableValue(graph->GetSessionID(), ge_tensor_desc, NODE_NAME_FLOWCTRL_LOOP_PER_ITER, &loop_per_iter); - ret = SetRtContext(graph->GetSessionID(), rtContext_t(), RT_CTX_GEN_MODE); + ret = SetRtContext(rtContext_t(), RT_CTX_GEN_MODE); // The following process will be affected if ret is not SUCCESS. 
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Set rt context RT_CTX_GEN_MODE failed."); return ret); @@ -108,7 +108,7 @@ Status IteratorOpPass::GetVariableValue(uint64_t session_id, const ge::GeTensorD // base_addr uint8_t *var_mem_base = VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM); GE_CHECK_NOTNULL(var_mem_base); - // offset + logic_base + // offset uint8_t *dev_ptr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &dev_ptr), "Get variable %s address failed.", var_name.c_str()); @@ -279,11 +279,11 @@ ge::OpDescPtr IteratorOpPass::CreateMemcpyAsyncOp(const ge::NodePtr &pre_node) { return op_desc; } -Status IteratorOpPass::SetRtContext(uint64_t session_id, rtContext_t rt_context, rtCtxMode_t mode) { +Status IteratorOpPass::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddRtContext(session_id, rt_context); + RtContextUtil::GetInstance().AddrtContext(rt_context); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/passes/iterator_op_pass.h b/src/ge/graph/passes/iterator_op_pass.h index 78b951e6..e403020c 100644 --- a/src/ge/graph/passes/iterator_op_pass.h +++ b/src/ge/graph/passes/iterator_op_pass.h @@ -64,7 +64,7 @@ class IteratorOpPass : public GraphPass { /// ge::OpDescPtr CreateMemcpyAsyncOp(const ge::NodePtr &pre_node); - Status SetRtContext(uint64_t session_id, rtContext_t rt_context, rtCtxMode_t mode); + Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); }; } // namespace ge #endif // GE_GRAPH_PASSES_ITERATOR_OP_PASS_H_ diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc index 63ca68a2..ff150a54 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ 
b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -97,16 +97,9 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetOpDesc() == nullptr) || (node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { - continue; - } - auto in_data_nodes = node->GetInDataNodes(); if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); - if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { - continue; - } if (AreAllInputsConst(gen_mask) && nodes_set.count(gen_mask) == 0) { gen_mask_nodes.emplace_back(gen_mask); nodes_set.emplace(gen_mask); diff --git a/src/ge/graph/passes/mark_same_addr_pass.cc b/src/ge/graph/passes/mark_same_addr_pass.cc deleted file mode 100644 index 06d63393..00000000 --- a/src/ge/graph/passes/mark_same_addr_pass.cc +++ /dev/null @@ -1,81 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/passes/mark_same_addr_pass.h" - -namespace ge { -bool MarkSameAddrPass::IsNextNodeExpected(const ge::NodePtr &cur_node, const vector &next_nodes, - int &out_anchor_idx) { - for (auto out_anchor : cur_node->GetAllOutDataAnchors()) { - if (out_anchor == nullptr) { - continue; - } - for (auto in_anchor : out_anchor->GetPeerInDataAnchors()) { - if (in_anchor == nullptr) { - continue; - } - auto dst_node = in_anchor->GetOwnerNode(); - if (dst_node == nullptr) { - continue; - } - if (std::count(next_nodes.begin(), next_nodes.end(), dst_node->GetType()) > 0) { - out_anchor_idx = out_anchor->GetIdx(); - GELOGD("Current node is %s, next node is %s.", cur_node->GetName().c_str(), dst_node->GetName().c_str()); - return true; - } - } - } - return false; -} - -Status MarkSameAddrPass::Run(ComputeGraphPtr graph) { - GELOGD("MarkSameAddrPass begin."); - GE_CHECK_NOTNULL(graph); - auto parent_node = graph->GetParentNode(); - if (parent_node == nullptr) { - return SUCCESS; - } - auto parent_op_desc = parent_node->GetOpDesc(); - GE_CHECK_NOTNULL(parent_op_desc); - if (!parent_op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { - GELOGD("Graph[%s] do not have unknown shape attr. Parent node is %s", graph->GetName().c_str(), - parent_op_desc->GetName().c_str()); - return SUCCESS; - } - - bool is_unknown_shape = false; - (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); - if (is_unknown_shape) { - GELOGD("Graph[%s] is unknown shape, do not need to set fixed addr attr. 
Parent node is %s", - graph->GetName().c_str(), parent_op_desc->GetName().c_str()); - return SUCCESS; - } - - int out_anchor_idx = 0; - for (const ge::NodePtr &node : graph->GetDirectNode()) { - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - vector next_nodes = {STREAMSWITCH, STREAMSWITCHN, LABELSWITCHBYINDEX}; - if (IsNextNodeExpected(node, next_nodes, out_anchor_idx)) { - string tensor_name = op_desc->GetOutputNameByIndex(out_anchor_idx); - (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); - (void)ge::AttrUtils::SetInt(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, out_anchor_idx); - } - } - GELOGD("MarkSameAddrPass end."); - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.cc b/src/ge/graph/passes/merge_to_stream_merge_pass.cc deleted file mode 100644 index b785ddfa..00000000 --- a/src/ge/graph/passes/merge_to_stream_merge_pass.cc +++ /dev/null @@ -1,234 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/passes/merge_to_stream_merge_pass.h" -#include "common/ge/ge_util.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" - -namespace ge { -Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { - GELOGD("MergeToStreamMergePass Enter"); - - bypass_nodes_.clear(); - for (const auto &node : graph->GetDirectNode()) { - if ((node->GetType() != MERGE) && (node->GetType() != REFMERGE)) { - continue; - } - - OpDescPtr merge_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(merge_op_desc); - if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, true), "Merge add memcpy node failed."); - GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); - } else { - GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); - } - } - - for (const auto &node : bypass_nodes_) { - GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS, return FAILED, - "Remove merge node failed."); - } - - GELOGD("MergeToStreamMergePass Leave"); - return SUCCESS; -} - -/// -/// @brief Replace Merge Op -/// @param [in] graph -/// @param [in] merge_node -/// @return Status -/// -Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node) { - OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - GE_CHECK_NOTNULL(merge_op_desc); - - const std::string &node_name = merge_node->GetName(); - GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMMERGE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, StreamMerge:%s.", node_name.c_str()); - return FAILED; - } - - for (const InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) { - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS, - return FAILED, "Create StreamMerge op: add input desc failed."); - } - - for 
(const OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) { - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS, - return FAILED, "Create StreamMerge op: add output desc failed."); - } - - NodePtr stream_merge = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(stream_merge != nullptr, return FAILED, "Insert StreamMerge node failed."); - GE_CHK_STATUS_RET(MoveEdges(merge_node, stream_merge), "Move edges failed."); - bypass_nodes_.insert(merge_node); - - if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - std::string next_iteration_name; - GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name), - GELOGE(INTERNAL_ERROR, "Get ATTR_NAME_NEXT_ITERATION failed"); - return INTERNAL_ERROR); - GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); - } - - return AddMemcpyAsyncNodes(graph, stream_merge, false); -} - -/// -/// @brief Add MemcpyAsync Op as StreamMerge in_node -/// @param [in] graph -/// @param [in] node -/// @param [in] multi_batch_flag -/// @return Status -/// -Status MergeToStreamMergePass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, - bool multi_batch_flag) { - GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); - for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - NodePtr in_node = peer_out_anchor->GetOwnerNode(); - const std::string &type = in_node->GetType(); - // For WhileLoop no need memcpy & active for merge. 
- GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), - continue); - - const std::string &memcpy_name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()); - NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, memcpy_name, peer_out_anchor, multi_batch_flag); - GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create MemcpyAsync node failed."); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, memcpy_node->GetInDataAnchor(0)), - "MemcpyAsync node add edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), in_data_anchor), - "MemcpyAsync node add edge failed."); - - NodePtr active_node = CreateActiveNode(graph, memcpy_node); - GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "StreamActive add ctrl edge failed."); - if (SetActiveLabelList(active_node, {node->GetName()}) != SUCCESS) { - GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); - return FAILED; - } - } - - return SUCCESS; -} - -/// -/// @brief Add MemcpyAsync Node -/// @param [in] graph -/// @param [in] name -/// @param [in] out_data_anchor -/// @param [in] multi_batch_flag -/// @return ge::NodePtr -/// -NodePtr MergeToStreamMergePass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, - const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { - GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); - OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); - - const std::string &memcpy_type = multi_batch_flag ? 
MEMCPYADDRASYNC : MEMCPYASYNC; - const std::string &node_name = name + "_" + memcpy_type; - GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, memcpy_type); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, MemcpyAsync:%s.", node_name.c_str()); - return nullptr; - } - - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, - return nullptr, "Create MemcpyAsync op: add input desc failed."); - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, - return nullptr, "Create MemcpyAsync op: add output desc failed."); - - return graph->AddNode(op_desc); -} - -/// -/// @brief Create Active Op -/// @param [in] graph -/// @param [in] node -/// @return ge::NodePtr -/// -NodePtr MergeToStreamMergePass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) { - const std::string &node_name = node->GetName() + "_" + STREAMACTIVE; - GELOGI("Create StreamActive op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, StreamActive:%s.", node_name.c_str()); - return nullptr; - } - - NodePtr active_node = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node failed."); - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, - GELOGE(INTERNAL_ERROR, "add edge failed"); - return nullptr); - GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, - GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); - return nullptr); - - return active_node; -} - -/// -/// @brief move edges from old_node to new_node -/// @param [in] old_node -/// @param [in] new_node -/// @return Status -/// -Status MergeToStreamMergePass::MoveEdges(const NodePtr &old_node, const 
NodePtr &new_node) { - for (const InDataAnchorPtr &in_data_anchor : old_node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Merge remove in data edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, new_node->GetInDataAnchor(in_data_anchor->GetIdx())), - "StreamMerge add in data edge failed."); - } - - for (const OutDataAnchorPtr &out_data_anchor : old_node->GetAllOutDataAnchors()) { - for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Merge remove out data edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), - "StreamMerge add out data edge failed."); - } - } - - for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()), - "Merge remove in ctrl edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()), - "StreamMerge add in ctrl edge failed."); - } - - for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), - "Merge remove out ctrl edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), - "StreamMerge add out ctrl edge failed."); - } - - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.h b/src/ge/graph/passes/merge_to_stream_merge_pass.h deleted file mode 100644 index 9f713989..00000000 --- a/src/ge/graph/passes/merge_to_stream_merge_pass.h +++ /dev/null @@ -1,75 +0,0 @@ -/** 
- * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ -#define GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ - -#include "inc/graph_pass.h" - -namespace ge { -class MergeToStreamMergePass : public GraphPass { - public: - Status Run(ComputeGraphPtr graph); - - private: - /// - /// @brief Replace Merge Op - /// @param [in] graph - /// @param [in] merge_node - /// @return Status - /// - Status ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node); - - /// - /// @brief Add MemcpyAsync Op as StreamMerge in_node - /// @param [in] graph - /// @param [in] node - /// @param [in] multi_batch_flag - /// @return Status - /// - Status AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); - - /// - /// @brief Add MemcpyAsync Node - /// @param [in] graph - /// @param [in] name - /// @param [in] out_data_anchor - /// @param [in] multi_batch_flag - /// @return ge::NodePtr - /// - NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, - const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); - - /// - /// @brief Create Active Op - /// @param [in] graph - /// @param [in] node - /// @return ge::NodePtr - /// - NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); - - /// - /// @brief move edges from old_node to new_node - /// @param [in] old_node - 
/// @param [in] new_node - /// @return Status - /// - Status MoveEdges(const NodePtr &old_node, const NodePtr &new_node); - - std::set bypass_nodes_; -}; -} // namespace ge -#endif // GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index 26190168..bb0050be 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -32,7 +32,7 @@ namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); - + GE_CHECK_NOTNULL(graph); if (graph->GetParentGraph() != nullptr) { GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str()); return SUCCESS; @@ -44,26 +44,26 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return SUCCESS; } if (ret != SUCCESS) { - GELOGE(FAILED, "FindPredValue failed."); + GELOGE(FAILED, "FindPredValue fail."); return FAILED; } std::vector> batch_shape; if (!CheckSwitchN(batch_shape)) { - GELOGE(FAILED, "CheckSwitchN failed."); + GELOGE(FAILED, "CheckSwitchN fail."); return FAILED; } FindSwitchOutNodes(batch_shape.size()); if (ReplaceSwitchN(graph, pred_value, batch_shape) != SUCCESS) { - GELOGE(FAILED, "Replace SwitchN nodes failed."); + GELOGE(FAILED, "Replace SwitchN nodes fail."); return FAILED; } - for (const NodePtr &node : bypass_nodes_) { - if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str()); + for (NodePtr &node : bypass_nodes_) { + if (graph->RemoveNode(node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove SwitchN nodes %s fail.", node->GetName().c_str()); return FAILED; } } @@ -79,19 +79,19 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { /// @return Status /// Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) { - for (const NodePtr &node : graph->GetDirectNode()) { + for (NodePtr &node : 
graph->GetDirectNode()) { if (node->GetType() != SWITCHN) { continue; } InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); if (in_data_anchor == nullptr) { - GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); + GELOGE(FAILED, "FindPredInput fail, in_data_anchor is null, node:%s.", node->GetName().c_str()); return FAILED; } OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); if (pred_input == nullptr) { - GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); + GELOGE(FAILED, "FindPredInput fail, pred_input is null, node:%s.", node->GetName().c_str()); return FAILED; } @@ -110,7 +110,7 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor } if (pred_value == nullptr) { - GELOGE(FAILED, "FindPredInput failed, pred_value is null."); + GELOGE(FAILED, "FindPredInput fail, pred_value is null."); return FAILED; } @@ -126,7 +126,7 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape) { // Check if output_num of different SwitchN is same uint32_t batch_num = 0; - for (const NodePtr &node : switch_n_nodes_) { + for (NodePtr &node : switch_n_nodes_) { uint32_t tmp_num = node->GetAllOutDataAnchorsSize(); if (batch_num == 0) { batch_num = tmp_num; @@ -140,21 +140,21 @@ bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape std::vector> idx_batch_shape; for (uint32_t i = 0; i < batch_num; i++) { idx_batch_shape.clear(); - for (const NodePtr &node : switch_n_nodes_) { + for (NodePtr &node : switch_n_nodes_) { std::vector output_dims; OpDescPtr op_desc = node->GetOpDesc(); if (op_desc == nullptr) { - GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); + GELOGE(FAILED, "CheckDims fail, get op_desc fail, node: %s.", node->GetName().c_str()); return false; } if 
(!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { - GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); + GELOGE(FAILED, "CheckDims fail, get attr ATTR_NAME_SWITCHN_PRED_VALUE fail, batch_index=%u.", i); return false; } idx_batch_shape.emplace_back(output_dims); } if (!CheckDims(idx_batch_shape)) { - GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i); + GELOGE(FAILED, "CheckDims fail, batch_index=%u.", i); return false; } @@ -187,11 +187,11 @@ void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { std::vector output_nodes; for (uint32_t i = 0; i < batch_num; i++) { output_nodes.clear(); - for (const NodePtr &node : switch_n_nodes_) { + for (NodePtr &node : switch_n_nodes_) { // idx is promised to be valid OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i); GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor); - for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { output_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); } } @@ -208,33 +208,33 @@ void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { /// @param [in] batch_shape /// @return Status /// -Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, +Status MultiBatchPass::ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value, const std::vector> &batch_shape) { NodePtr pred_value_node = pred_value->GetOwnerNode(); // Create SwitchCase node - const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; + const std::string switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape); if (switch_case == nullptr) { - GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str()); 
+ GELOGE(FAILED, "CreateSwitchCaseNode %s fail.", switch_case_name.c_str()); return FAILED; } - for (const NodePtr &switch_n_node : switch_n_nodes_) { + for (NodePtr &switch_n_node : switch_n_nodes_) { if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) { - GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_case_name.c_str()); + GELOGE(FAILED, "Bypass SwitchN %s fail.", switch_case_name.c_str()); return FAILED; } } // Add switchCase input edge if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(), + GELOGE(FAILED, "Add SwitchCase in_data_edge fail, %s->%s.", pred_value_node->GetName().c_str(), switch_case->GetName().c_str()); return FAILED; } if (AttachLabel(switch_case) != SUCCESS) { - GELOGE(FAILED, "AttachLabel failed."); + GELOGE(FAILED, "AttachLabel fail."); return FAILED; } @@ -248,7 +248,7 @@ Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDat /// bool MultiBatchPass::CheckDims(const std::vector> &output_shape) const { if (output_shape.empty()) { - GELOGE(FAILED, "CheckDims failed: output_shape is empty."); + GELOGE(FAILED, "CheckDims fail: output_shape is empty."); return false; } @@ -257,7 +257,7 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s for (size_t i = 1; i < num; i++) { size_t tmp_dim_num = output_shape[i].size(); if (dim_num != tmp_dim_num) { - GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); + GELOGE(FAILED, "CheckDims fail: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); return false; } } @@ -271,7 +271,7 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s for (size_t j = 1; j < num; j++) { int64_t tmp_dim_value = output_shape[j][i]; if (dim_value != tmp_dim_value) { - GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, 
dim_value_%zu:%ld.", i, + GELOGE(FAILED, "CheckDims fail: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, dim_value, j, tmp_dim_value); return false; } @@ -289,41 +289,41 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s /// @param [in] batch_shape /// @return ge::NodePtr /// -NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, +NodePtr MultiBatchPass::CreateSwitchCaseNode(ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape) { OpDescPtr op_desc = MakeShared(name, STREAMSWITCHN); if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Create op_desc fail, StreamSwitchN:%s.", name.c_str()); return nullptr; } GELOGI("Create StreamSwitchN op:%s.", name.c_str()); OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc(); if (pred_desc == nullptr) { - GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Get pred_desc fail, StreamSwitchN:%s.", name.c_str()); return nullptr; } if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) { - GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "AddInputDesc fail, StreamSwitchN:%s.", name.c_str()); return nullptr; } NodePtr switch_case_node = graph->AddNode(op_desc); if (switch_case_node == nullptr) { - GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Create node fail, StreamSwitchN:%s.", name.c_str()); return nullptr; } uint32_t batch_num = static_cast(batch_shape.size()); if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { - GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM fail, StreamSwitchN:%s.", name.c_str()); return nullptr; } for (uint32_t i = 0; i < 
batch_num; i++) { - const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) { - GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE fail, StreamSwitchN:%s.", name.c_str()); return nullptr; } } @@ -337,43 +337,43 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const /// @param [in] switch_case /// @return Status /// -Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) { +Status MultiBatchPass::BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_case) { InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT); if (in_data_anchor == nullptr) { - GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); + GELOGE(FAILED, "Check in_data_anchor fail, SwitchN:%s.", switch_n_node->GetName().c_str()); return FAILED; } OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_data_anchor == nullptr) { - GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); + GELOGE(FAILED, "Check peer_data_anchor fail, SwitchN:%s.", switch_n_node->GetName().c_str()); return FAILED; } NodePtr data_input = peer_data_anchor->GetOwnerNode(); // Remove SwitchN data input if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Remove SwitchN in_data_edge fail, %s->%s.", data_input->GetName().c_str(), switch_n_node->GetName().c_str()); return FAILED; } if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add StreamSwitchN in_control_edge 
failed, %s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Add StreamSwitchN in_control_edge fail, %s->%s.", data_input->GetName().c_str(), switch_case->GetName().c_str()); return FAILED; } // Add SwitchCase control output - for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) { - for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) { + for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { NodePtr data_output = peer_in_anchor->GetOwnerNode(); if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) || (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) { - GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Bypass SwitchN data_edge fail, %s->%s->%s.", data_input->GetName().c_str(), switch_n_node->GetName().c_str(), data_output->GetName().c_str()); return FAILED; } if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(), + GELOGE(FAILED, "Add SwitchCase out_control_edge fail, %s->%s.", switch_case->GetName().c_str(), data_output->GetName().c_str()); return FAILED; } @@ -390,17 +390,17 @@ Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr /// @param [in] switch_case_node /// @return Status /// -Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) { +Status MultiBatchPass::AttachLabel(NodePtr &switch_case_node) { std::vector stream_label_list; for (uint32_t i = 0; i < static_cast(batch_head_nodes_.size()); i++) { if (AttachBatchLabel(i) != SUCCESS) { - GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i); + GELOGE(FAILED, "AttachBatchLabel fail, batch_idx=%u", i); return FAILED; 
} - const std::string &stream_label = "stream_label_batch_" + std::to_string(i); + const std::string stream_label = "stream_label_batch_" + std::to_string(i); if (AttachStreamLabel(i, stream_label) != SUCCESS) { - GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str()); + GELOGE(FAILED, "AttachStreamLabel fail, stream_label=%s", stream_label.c_str()); return FAILED; } stream_label_list.emplace_back(stream_label); @@ -416,11 +416,11 @@ Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) { /// Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { std::stack nodes; - for (const auto &node : batch_head_nodes_[batch_idx]) { + for (auto &node : batch_head_nodes_[batch_idx]) { nodes.push(node); } - const std::string &batch_label = "Batch_" + std::to_string(batch_idx); + const std::string batch_label = "Batch_" + std::to_string(batch_idx); std::unordered_set handled_nodes; while (!nodes.empty()) { NodePtr cur_node = nodes.top(); @@ -434,7 +434,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { std::string tmp_label; if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) { - GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str()); + GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL fail, node: %s.", cur_desc->GetName().c_str()); return FAILED; } if (tmp_label != batch_label) { @@ -445,14 +445,14 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { } GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str()); if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL fail, node:%s.", cur_desc->GetName().c_str()); return FAILED; } - for (const auto &out_node : cur_node->GetOutAllNodes()) { + for (auto &out_node 
: cur_node->GetOutAllNodes()) { OpDescPtr op_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - const std::string &type = op_desc->GetType(); + const std::string type = op_desc->GetType(); if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { continue; } @@ -476,7 +476,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { /// Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) { std::stack nodes; - for (const auto &node : batch_head_nodes_[batch_idx]) { + for (auto &node : batch_head_nodes_[batch_idx]) { nodes.push(node); } @@ -493,11 +493,11 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { - GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); + GELOGE(FAILED, "SetStreamLabel fail, node:%s.", cur_node->GetName().c_str()); return FAILED; } - for (const auto &out_node : cur_node->GetOutAllNodes()) { + for (auto &out_node : cur_node->GetOutAllNodes()) { nodes.push(out_node); } diff --git a/src/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h index 2e83262c..6e3f5e46 100644 --- a/src/ge/graph/passes/multi_batch_pass.h +++ b/src/ge/graph/passes/multi_batch_pass.h @@ -31,15 +31,14 @@ class MultiBatchPass : public GraphPass { Status FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value); bool CheckSwitchN(std::vector> &batch_shape); void FindSwitchOutNodes(uint32_t batch_num); - Status ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, + Status ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value, const std::vector> &batch_shape); bool CheckDims(const std::vector> &output_shape) const; - NodePtr CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, - const 
OutDataAnchorPtr &pred_value, + NodePtr CreateSwitchCaseNode(ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape); - Status BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case_node); - Status AttachLabel(const NodePtr &switch_case_node); + Status BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_case_node); + Status AttachLabel(NodePtr &switch_case_node); Status AttachBatchLabel(uint32_t batch_idx); Status AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label); diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc index c664ac53..138ad86b 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -16,8 +16,19 @@ #include "graph/passes/next_iteration_pass.h" +#include +#include +#include +#include +#include + #include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" #include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { @@ -30,24 +41,24 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { if ((type != ENTER) && (type != REFENTER)) { continue; } - if (GroupEnterNode(node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str()); + if (HandleEnterNode(node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "HandleEnterNode for node %s fail.", node->GetName().c_str()); return INTERNAL_ERROR; } } if (FindWhileGroups() != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Find while groups failed."); + GELOGE(INTERNAL_ERROR, "FindWhileGroups fail"); return INTERNAL_ERROR; } if (!VerifyWhileGroup()) { - GELOGE(INTERNAL_ERROR, "Verify while groups failed."); + GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail"); return 
INTERNAL_ERROR; } if (HandleWhileGroup(graph) != SUCCESS) { - GELOGE(FAILED, "Handle while groups failed."); + GELOGE(FAILED, "HandleWhileGroup fail"); return FAILED; } @@ -56,16 +67,16 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { } /// -/// @brief Group Enter node +/// @brief Handle Enter node /// @param [in] enter_node /// @return Status /// -Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { +Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { OpDescPtr enter_desc = enter_node->GetOpDesc(); GE_CHECK_NOTNULL(enter_desc); std::string frame_name; if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) { - GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str()); + GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME fail, node: %s", enter_desc->GetName().c_str()); return FAILED; } @@ -73,7 +84,7 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { if (iter == loop_group_map_.end()) { LoopCondGroupPtr loop_group = MakeShared(); if (loop_group == nullptr) { - GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); + GELOGE(FAILED, "MakeShared for LoopCondGroup fail."); return FAILED; } loop_group->enter_nodes.emplace_back(enter_node); @@ -90,30 +101,30 @@ Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { /// @return Status /// Status NextIterationPass::FindWhileGroups() { - for (const auto &loop_group_iter : loop_group_map_) { - const std::string &frame_name = loop_group_iter.first; - for (const auto &enter_node : loop_group_iter.second->enter_nodes) { - for (const auto &out_node : enter_node->GetOutAllNodes()) { - const std::string &type = out_node->GetType(); + for (auto &loop_group_iter : loop_group_map_) { + const std::string frame_name = loop_group_iter.first; + for (auto &enter_node : loop_group_iter.second->enter_nodes) { + for (auto &out_node : enter_node->GetOutAllNodes()) { + const std::string 
type = out_node->GetType(); if ((type != MERGE) && (type != REFMERGE)) { continue; } NodePtr next_node = nullptr; if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get NextIteration node fail, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } NodePtr switch_node = nullptr; if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get Switch node fail, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } NodePtr loop_cond = nullptr; if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get LoopCond node fail, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } @@ -137,21 +148,21 @@ Status NextIterationPass::FindWhileGroups() { /// bool NextIterationPass::VerifyWhileGroup() { // map - for (const auto &loop_group_iter : loop_group_map_) { - const std::string &frame_name = loop_group_iter.first; + for (auto &loop_group_iter : loop_group_map_) { + const std::string frame_name = loop_group_iter.first; if (frame_name.empty()) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); + GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, frame_name is empty."); return false; } if (loop_group_iter.second->loop_cond == nullptr) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, LoopCond is null, frame_name: %s.", frame_name.c_str()); return false; } - for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { + for (auto &pair_iter : loop_group_iter.second->merge_next_pairs) { if ((pair_iter.first 
== nullptr) || (pair_iter.second == nullptr)) { - GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", + GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, merge_node/next_node is null, frame_name: %s.", frame_name.c_str()); return false; } @@ -167,51 +178,51 @@ bool NextIterationPass::VerifyWhileGroup() { /// @return Status /// Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { - for (const auto &loop_cond_iter : loop_group_map_) { - const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); - GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); + for (auto &loop_cond_iter : loop_group_map_) { + std::string cond_name = loop_cond_iter.second->loop_cond->GetName(); + GELOGI("HandleWhileGroup, LoopCond node: %s.", cond_name.c_str()); - // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge + // Create Active node, Enter->Active->Merge, NextItaration->Active->Merge NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); if ((enter_active == nullptr) || (next_active == nullptr)) { - GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); + GELOGE(INTERNAL_ERROR, "CreateActiveNode fail, cond_name: %s.", cond_name.c_str()); return INTERNAL_ERROR; } - for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { + for (auto &enter_node : loop_cond_iter.second->enter_nodes) { // Enter --> Active if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); + GELOGE(INTERNAL_ERROR, "Add control edge fail"); return INTERNAL_ERROR; } } - for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { + for (auto &pair : loop_cond_iter.second->merge_next_pairs) { NodePtr merge_node = pair.first; NodePtr 
next_node = pair.second; // Active --> Merge if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); + GELOGE(INTERNAL_ERROR, "Add control edge fail"); return INTERNAL_ERROR; } // NextIteration --> Active if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge failed."); + GELOGE(INTERNAL_ERROR, "Add control edge fail"); return INTERNAL_ERROR; } // break link between NextIteration and Merge if (BreakNextIteration(next_node, merge_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); + GELOGE(INTERNAL_ERROR, "BreakNextIteration failed"); return INTERNAL_ERROR; } } if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); + GELOGE(INTERNAL_ERROR, "SetActiveLabelList failed"); return INTERNAL_ERROR; } } @@ -234,12 +245,12 @@ NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::s GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); NodePtr active_node = graph->AddNode(op_desc); if (active_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str()); + GELOGE(INTERNAL_ERROR, "Create node[%s] fail.", name.c_str()); return nullptr; } if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "SetSwitchBranchNodeLabel for node: %s failed.", active_node->GetName().c_str()); return nullptr; } @@ -257,18 +268,18 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & GELOGE(PARAM_INVALID, "merge node or next node is null."); return PARAM_INVALID; } - for (const auto &in_anchor : 
merge_node->GetAllInDataAnchors()) { + for (auto &in_anchor : merge_node->GetAllInDataAnchors()) { OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor(); if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) { continue; } if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(), + GELOGE(INTERNAL_ERROR, "Remove data edge fail, %s->%s.", next_node->GetName().c_str(), merge_node->GetName().c_str()); return INTERNAL_ERROR; } if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "SetNextIteration for node %s fail.", merge_node->GetName().c_str()); return INTERNAL_ERROR; } } @@ -291,16 +302,16 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } std::vector nodes; if (is_input) { - for (const auto &tmp_node : node->GetInDataNodes()) { + for (auto &tmp_node : node->GetInDataNodes()) { nodes.emplace_back(tmp_node); } } else { - for (const auto &tmp_node : node->GetOutDataNodes()) { + for (auto &tmp_node : node->GetOutDataNodes()) { nodes.emplace_back(tmp_node); } } - for (const auto &tmp_node : nodes) { + for (auto &tmp_node : nodes) { const std::string type = tmp_node->GetType(); if ((target_type == LOOPCOND) && (type == target_type)) { target_node = tmp_node; @@ -312,14 +323,13 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } if (target_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str()); + GELOGE(INTERNAL_ERROR, "Find node %s fail", target_type.c_str()); return INTERNAL_ERROR; } return SUCCESS; } - /// -/// @brief Clear Status, used for subgraph pass +/// @brief Clear Status, uesd for subgraph pass /// @return SUCCESS /// Status NextIterationPass::ClearStatus() { diff --git 
a/src/ge/graph/passes/next_iteration_pass.h b/src/ge/graph/passes/next_iteration_pass.h index 4cdf4b51..4bbced4f 100644 --- a/src/ge/graph/passes/next_iteration_pass.h +++ b/src/ge/graph/passes/next_iteration_pass.h @@ -17,6 +17,12 @@ #ifndef GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ #define GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ +#include +#include +#include +#include +#include + #include "inc/graph_pass.h" struct LoopCondGroup { @@ -31,64 +37,15 @@ namespace ge { class NextIterationPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); - - /// - /// @brief Clear Status, used for subgraph pass - /// @return SUCCESS - /// Status ClearStatus() override; private: - /// - /// @brief Group Enter node - /// @param [in] enter_node - /// @return Status - /// - Status GroupEnterNode(const NodePtr &enter_node); - - /// - /// @brief Find while groups - /// @return Status - /// + Status HandleEnterNode(const NodePtr &enter_node); Status FindWhileGroups(); - - /// - /// @brief Verify if valid - /// @return bool - /// bool VerifyWhileGroup(); - - /// - /// @brief Handle while group - /// @param [in] graph - /// @return Status - /// Status HandleWhileGroup(ComputeGraphPtr &graph); - - /// - /// @brief Create Active Node - /// @param [in] graph - /// @param [in] name - /// @return ge::NodePtr - /// NodePtr CreateActiveNode(ComputeGraphPtr &graph, const std::string &name); - - /// - /// @brief Break NextIteration Link & add name to merge attr - /// @param [in] next_node - /// @param [in] merge_node - /// @return Status - /// Status BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node); - - /// - /// @brief find target node - /// @param [in] node - /// @param [in] target_type - /// @param [in] is_input - /// @param [out] target_node - /// @return Status - /// Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); // map diff --git a/src/ge/graph/passes/pass_manager.cc 
b/src/ge/graph/passes/pass_manager.cc index 5be54f0a..eec33eef 100644 --- a/src/ge/graph/passes/pass_manager.cc +++ b/src/ge/graph/passes/pass_manager.cc @@ -19,7 +19,6 @@ #include "common/types.h" #include "common/util.h" #include "graph/utils/node_utils.h" -#include "graph/common/ge_call_wrapper.h" #include "omg/omg_inner_types.h" namespace ge { diff --git a/src/ge/graph/passes/permute_pass.cc b/src/ge/graph/passes/permute_pass.cc index 3c0dfd4e..f5fd9dc5 100644 --- a/src/ge/graph/passes/permute_pass.cc +++ b/src/ge/graph/passes/permute_pass.cc @@ -33,6 +33,7 @@ using domi::TENSORFLOW; namespace ge { Status PermutePass::Run(ComputeGraphPtr graph) { + GE_TIMESTAMP_START(PermutePass); GE_CHECK_NOTNULL(graph); std::vector isolate_nodes; for (NodePtr &node : graph->GetDirectNode()) { @@ -115,6 +116,8 @@ Status PermutePass::Run(ComputeGraphPtr graph) { GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node), "[%s]:remove permute node failed", node->GetOpDesc()->GetName().c_str()); }); + + GE_TIMESTAMP_END(PermutePass, "GraphManager::PermutePass"); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/passes/print_op_pass.h b/src/ge/graph/passes/print_op_pass.h index 15b0badc..64bf6573 100644 --- a/src/ge/graph/passes/print_op_pass.h +++ b/src/ge/graph/passes/print_op_pass.h @@ -31,6 +31,6 @@ class PrintOpPass : public BaseNodePass { public: Status Run(ge::NodePtr &node) override; }; -} // namespace ge +}; // namespace ge #endif // GE_GRAPH_PASSES_PRINT_OP_PASS_H_ diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.cc b/src/ge/graph/passes/ref_identity_delete_op_pass.cc deleted file mode 100644 index 5bc0fad6..00000000 --- a/src/ge/graph/passes/ref_identity_delete_op_pass.cc +++ /dev/null @@ -1,225 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "ref_identity_delete_op_pass.h" -#include -#include -#include "graph/common/transop_util.h" - -namespace ge { -Status RefIdentityDeleteOpPass::Run(ComputeGraphPtr graph) { - GE_CHECK_NOTNULL(graph); - for (auto &node : graph->GetAllNodes()) { - if (node->GetType() != REFIDENTITY) { - continue; - } - int input_index = 0; - NodePtr ref_node = GetRefNode(node, input_index); - CHECK_FALSE_EXEC(GetRefNode(node, input_index) != nullptr, - GELOGE(FAILED, "Ref node of RefIdentity[%s] not found", node->GetName().c_str()); - return FAILED); - CHECK_FALSE_EXEC(DealNoOutputRef(ref_node, node, input_index, graph) == SUCCESS, - GELOGE(FAILED, "Ref identity [%s] delete failed", node->GetName().c_str()); - return FAILED); - } - return SUCCESS; -} - -NodePtr RefIdentityDeleteOpPass::GetRefNode(const NodePtr &node, int &input_index) { - OutDataAnchorPtr out_anchor = node->GetOutDataAnchor(0); - CHECK_FALSE_EXEC(out_anchor != nullptr, return nullptr); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - CHECK_FALSE_EXEC(peer_in_anchor != nullptr, continue); - auto peer_node = peer_in_anchor->GetOwnerNode(); - CHECK_FALSE_EXEC(peer_node != nullptr, continue); - const auto &peer_op_desc = peer_node->GetOpDesc(); - CHECK_FALSE_EXEC(peer_op_desc != nullptr, return nullptr); - const auto &peer_input_desc = peer_op_desc->GetInputDescPtr(static_cast(peer_in_anchor->GetIdx())); - if (!peer_input_desc->GetRefPortIndex().empty()) { - input_index = peer_in_anchor->GetIdx(); - return peer_node; - } - } - return nullptr; -} - -Status 
RefIdentityDeleteOpPass::DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index, - const ComputeGraphPtr &graph) { - NodePtr first_node = nullptr; - NodePtr variable_ref = GetVariableRef(node, ref_identity, first_node); - if (variable_ref == nullptr) { - GELOGE(FAILED, "[RefIdentityDeleteOpPass]Can not find variable ref for %s:%d", node->GetName().c_str(), - input_index); - return FAILED; - } - if (first_node->GetName() != variable_ref->GetName()) { - // Remove the control edge between ref node and variable ref - // Add a control edge between ref node and trans node - // +-----------+ +-----------+ - // +---------+RefIdentity| +-----------+RefIdentity| - // | +-----+-----+ | +-----+-----+ - // | | | | - // | v | v - // +-----v-----+ +----+----+ +-----v-----+ +----+----+ - // | TransNode | | RefNode | ==> | TransNode +<--C--+ RefNode | - // +-----+-----+ +----+----+ +-----+-----+ +---------+ - // | | | - // v C v - // +-----+-----+ | +-----+-----+ - // |VariableRef+<--------+ |VariableRef| - // +-----------+ +-----------+ - auto ret = ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), first_node->GetInControlAnchor()); - if (ret != SUCCESS) { - GELOGE(FAILED, "Add control edge between ref node and trans node failed"); - return FAILED; - } - ret = ge::GraphUtils::RemoveEdge(node->GetOutControlAnchor(), variable_ref->GetInControlAnchor()); - if (ret != SUCCESS) { - GELOGE(FAILED, "Remove control edge between ref node and its peer node failed"); - return FAILED; - } - } else { - // +-----------+ +-----------+ - // +-----------+RefIdentity| +-----------+RefIdentity| - // | +-----+-----+ | +-----+-----+ - // | | | | - // | v | v - // +-----v-----+ +----+----+ +-----v-----+ +----+----+ - // |VariableRef+<--C--+ RefNode | ==> |VariableRef+<--C--+ RefNode | - // +-----+-----+ +----+----+ +-----------+ +----+----+ - // | | | - // | v v - // | +---+----+ +---+----+ - // +-----C------>+ | | | - // +--------+ +--------+ - auto ret = 
RemoveUselessControlEdge(node, variable_ref); - if (ret != SUCCESS) { - GELOGE(FAILED, "Remove useless control edge failed."); - return FAILED; - } - } - // remove ref identity - if (GraphUtils::IsolateNode(ref_identity, {0}) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", ref_identity->GetName().c_str(), - variable_ref->GetType().c_str()); - return FAILED; - } - if (GraphUtils::RemoveNodeWithoutRelink(graph, ref_identity) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Remove node: %s, type: %s without relink failed", ref_identity->GetName().c_str(), - ref_identity->GetType().c_str()); - return FAILED; - } - return SUCCESS; -} - -ge::NodePtr RefIdentityDeleteOpPass::GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity, - NodePtr &first_node) { - const auto &ref_identity_out_anchor = ref_identity->GetOutDataAnchor(0); - if (ref_identity_out_anchor == nullptr) { - return nullptr; - } - for (auto &peer_in_anchor : ref_identity_out_anchor->GetPeerInDataAnchors()) { - const auto &peer_node = peer_in_anchor->GetOwnerNode(); - if (peer_node == nullptr || peer_node->GetName() == ref->GetName()) { - continue; - } - // DFS to find variable ref node. - std::stack nodes_to_check; - nodes_to_check.push(peer_node); - GELOGI("[RefIdentityDeleteOpPass]Start to search variable ref node from %s.", peer_node->GetName().c_str()); - NodePtr cur_node = nullptr; - while (!nodes_to_check.empty()) { - cur_node = nodes_to_check.top(); - nodes_to_check.pop(); - const auto &type = cur_node->GetType(); - if (type == VARIABLE && CheckControlEdge(ref, cur_node)) { - // Target variable ref node found. 
- GELOGI("[RefIdentityDeleteOpPass]variable ref node[%s] found.", cur_node->GetName().c_str()); - first_node = peer_node; - return cur_node; - } - - int data_index = TransOpUtil::GetTransOpDataIndex(type); - if (data_index < 0) { - GELOGI("[RefIdentityDeleteOpPass]Find node[%s] that is not trans op[%s], stop to search its output.", - cur_node->GetName().c_str(), type.c_str()); - continue; - } - const auto &cur_out_anchor = cur_node->GetOutDataAnchor(0); - if (cur_out_anchor == nullptr) { - GELOGI("[RefIdentityDeleteOpPass]Get out anchor of [%s] failed, stop to search its output.", - cur_node->GetName().c_str()); - continue; - } - for (const auto &cur_peer_in_anchor : cur_out_anchor->GetPeerInDataAnchors()) { - const auto &cur_peer_node = cur_peer_in_anchor->GetOwnerNode(); - if (cur_peer_node == nullptr) { - continue; - } - nodes_to_check.push(cur_peer_node); - } - } - GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node from %s.", peer_node->GetName().c_str()); - } - GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node, return nullptr."); - return nullptr; -} - -bool RefIdentityDeleteOpPass::CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref) { - const auto &control_out_anchor = ref->GetOutControlAnchor(); - if (control_out_anchor == nullptr) { - return false; - } - const string &variable_ref_name = variable_ref->GetName(); - for (const auto &peer_in_control_anchor : control_out_anchor->GetPeerInControlAnchors()) { - const auto &node = peer_in_control_anchor->GetOwnerNode(); - if (node != nullptr && node->GetName() == variable_ref_name) { - return true; - } - } - return false; -} - -Status RefIdentityDeleteOpPass::RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref) { - map out_nodes_map; - for (const auto &out_anchor : ref->GetAllOutDataAnchors()) { - for (const auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { - const auto &peer_node = peer_in_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - 
continue; - } - out_nodes_map[peer_node->GetName()] = peer_node; - } - } - const auto &out_control_anchor = variable_ref->GetOutControlAnchor(); - GE_CHECK_NOTNULL(out_control_anchor); - for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { - const auto &peer_node = peer_in_control_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - continue; - } - if (out_nodes_map.find(peer_node->GetName()) != out_nodes_map.end()) { - auto ret = ge::GraphUtils::RemoveEdge(out_control_anchor, peer_in_control_anchor); - if (ret != SUCCESS) { - GELOGE(FAILED, "Remove control edge between variable ref node[%s] and ref node's peer node[%s] failed", - variable_ref->GetName().c_str(), peer_node->GetName().c_str()); - return FAILED; - } - } - } - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.h b/src/ge/graph/passes/ref_identity_delete_op_pass.h deleted file mode 100644 index 3e42def4..00000000 --- a/src/ge/graph/passes/ref_identity_delete_op_pass.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ -#define GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ - -#include -#include -#include "framework/common/ge_inner_error_codes.h" -#include "inc/graph_pass.h" - -namespace ge { -class RefIdentityDeleteOpPass : public GraphPass { - public: - Status Run(ComputeGraphPtr graph); - - private: - Status DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index, - const ComputeGraphPtr &graph); - NodePtr GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity, NodePtr &first_node); - bool CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref); - Status RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref); - NodePtr GetRefNode(const NodePtr &node, int &input_index); -}; -} // namespace ge - -#endif // GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ diff --git a/src/ge/graph/passes/reshape_recovery_pass.cc b/src/ge/graph/passes/reshape_recovery_pass.cc index 07b08de9..787c8d83 100644 --- a/src/ge/graph/passes/reshape_recovery_pass.cc +++ b/src/ge/graph/passes/reshape_recovery_pass.cc @@ -30,10 +30,6 @@ NodePtr CreateReshape(const ConstGeTensorDescPtr &src, const ConstGeTensorDescPt if (ret != GRAPH_SUCCESS) { return nullptr; } - ret = reshape->AddInputDesc("shape", GeTensorDesc(GeShape(), Format(), DT_INT32)); - if (ret != GRAPH_SUCCESS) { - return nullptr; - } ret = reshape->AddOutputDesc("y", *dst); if (ret != GRAPH_SUCCESS) { return nullptr; @@ -53,10 +49,7 @@ Status InsertReshapeIfNeed(const NodePtr &node) { GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node->GetOpDesc()); auto dst_tensor = dst_node->GetOpDesc()->GetInputDescPtr(dst_anchor->GetIdx()); - bool is_need_insert_reshape = src_tensor->GetShape().GetDims() != UNKNOWN_RANK && - dst_tensor->GetShape().GetDims() != UNKNOWN_RANK && - src_tensor->GetShape().GetDims() != dst_tensor->GetShape().GetDims(); - if (is_need_insert_reshape) { + if (src_tensor->GetShape().GetDims() != 
dst_tensor->GetShape().GetDims()) { auto reshape = CreateReshape(src_tensor, dst_tensor, node->GetOwnerComputeGraph()); GE_CHECK_NOTNULL(reshape); auto ret = GraphUtils::InsertNodeBetweenDataAnchors(src_anchor, dst_anchor, reshape); diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index d51f52e1..3b4e4c19 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -22,6 +22,7 @@ #include #include "common/ge_inner_error_codes.h" #include "common/types.h" +#include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" @@ -116,44 +117,20 @@ void SameTransdataBreadthFusionPass::InsertSameTransdataNodeIndex(int anchors_in same_transdata_nodes.push_back(anchors_index); } -std::set SameTransdataBreadthFusionPass::GetInControlIdentityNodes(const NodePtr &node, - int subgraph_index) { - std::set in_node_names; - for (const auto &in_node : node->GetInControlNodes()) { - if (in_node->GetType() == IDENTITY) { - in_node_names.insert(in_node->GetName()); - } - } - for (const auto &subgraph_node : before_transdata_nodes_[subgraph_index]) { - for (const auto &in_node : subgraph_node->GetInControlNodes()) { - if (in_node->GetType() == IDENTITY) { - in_node_names.insert(in_node->GetName()); - } - } - } - GELOGD("control in nodes for %s(%d): %zu", node->GetName().c_str(), subgraph_index, in_node_names.size()); - return in_node_names; -} - void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector &same_transdata_nodes) { auto iter = all_transdata_nodes_.begin(); same_transdata_nodes.push_back(iter->first); - auto node_for_compare_in_anchor = iter->second; GE_CHECK_NOTNULL_JUST_RETURN(node_for_compare_in_anchor); auto node_for_compare = node_for_compare_in_anchor->GetOwnerNode(); - - // Get op-desc, input/output desc, 
in-control-edges-from-identity, as the compare-key auto op_desc_for_compare = node_for_compare->GetOpDesc(); GE_CHECK_NOTNULL_JUST_RETURN(op_desc_for_compare); string op_compare_stream_label; (void)AttrUtils::GetStr(op_desc_for_compare, ATTR_NAME_STREAM_LABEL, op_compare_stream_label); - auto op_compare_in_ctrl_nodes = GetInControlIdentityNodes(node_for_compare, iter->first); auto input_desc_for_compare = op_desc_for_compare->GetInputDescPtr(node_for_compare_in_anchor->GetIdx()); GE_CHECK_NOTNULL_JUST_RETURN(input_desc_for_compare); auto output_desc_for_compare = op_desc_for_compare->GetOutputDescPtr(0); GE_CHECK_NOTNULL_JUST_RETURN(output_desc_for_compare); - iter = all_transdata_nodes_.erase(iter); while (iter != all_transdata_nodes_.end()) { auto in_anchor = iter->second; @@ -172,14 +149,12 @@ void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector &same_tran auto output_desc_tmp = op_desc_tmp->GetOutputDescPtr(0); string op_tmp_stream_label; (void)AttrUtils::GetStr(op_desc_tmp, ATTR_NAME_STREAM_LABEL, op_tmp_stream_label); - auto op_tmp_in_ctrl_nodes = GetInControlIdentityNodes(node_tmp, iter->first); GE_CHECK_NOTNULL_JUST_RETURN(input_desc_tmp); GE_CHECK_NOTNULL_JUST_RETURN(output_desc_tmp); if ((op_compare_stream_label == op_tmp_stream_label) && (input_desc_tmp->GetFormat() == input_desc_for_compare->GetFormat()) && - (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat()) && - (op_compare_in_ctrl_nodes == op_tmp_in_ctrl_nodes)) { + (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat())) { GELOGD("same transdata node:%s, src node:%s", node_tmp->GetName().c_str(), node_for_compare->GetName().c_str()); InsertSameTransdataNodeIndex(iter->first, same_transdata_nodes); iter = all_transdata_nodes_.erase(iter); @@ -364,13 +339,14 @@ graphStatus SameTransdataBreadthFusionPass::ReLinkTransdataControlOutput2PreNode } graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { + 
GE_TIMESTAMP_START(SameTransdataBreadthFusionPass); GELOGI("[SameTransdataBreadthFusionPass]: optimize begin."); if (graph == nullptr) { return GRAPH_SUCCESS; } for (auto &node : graph->GetDirectNode()) { - if (IsTransOp(node) || node->GetOutDataNodesSize() <= 1) { + if (IsTransOp(node) || node->GetOutDataNodes().size() <= 1) { continue; } @@ -398,6 +374,7 @@ graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { } GELOGI("[SameTransdataBreadthFusionPass]: Optimize success."); + GE_TIMESTAMP_END(SameTransdataBreadthFusionPass, "GraphManager::SameTransdataBreadthFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h index a6a3bb26..f4b44a59 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h @@ -42,7 +42,7 @@ class SameTransdataBreadthFusionPass : public GraphPass { void GetSubGraphNodesInfo(); void EraseInvalidAnchorsPair(); - std::set GetInControlIdentityNodes(const NodePtr &node, int subgraph_index); + OpDescPtr GetCastOp(const GeTensorDesc &in_desc, const GeTensorDesc &out_desc); graphStatus AddCastNode(const ComputeGraphPtr &graph, int anchors_index, OutDataAnchorPtr &pre_out_anchor, diff --git a/src/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc index 80ce995a..d759aa12 100644 --- a/src/ge/graph/passes/subgraph_pass.cc +++ b/src/ge/graph/passes/subgraph_pass.cc @@ -15,6 +15,7 @@ */ #include "graph/passes/subgraph_pass.h" +#include #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" @@ -66,13 +67,13 @@ Status SubgraphPass::Run(ComputeGraphPtr graph) { /** * @ingroup ge - * @brief Check Subgraph Input node + * @brief Check Subgraph NetOutput node * @param [in] graph: ComputeGraph. - * @param [in] node: Data node in Subgraph. 
+ * @param [in] node: NetOutput node in Subgraph. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { - GELOGD("Handle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); + GELOGD("Hadle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); // Data has and only has one output bool input_continues_required_flag = false; OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); @@ -85,7 +86,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP // Data->InputContinuesRequiredOp in subgraph need memcpy. if (input_continues_required_flag) { GELOGD("Data %s output_node required continues input.", node->GetName().c_str()); - std::string name = node->GetName() + "_output_0_Memcpy"; + std::string name = node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy after %s failed.", node->GetName().c_str()); return FAILED; @@ -122,7 +123,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP GE_CHECK_NOTNULL(peer_out_anchor); GELOGD("Constant input %s links to While %s.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); - std::string name = parent_node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; + std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); if (InsertMemcpyNode(parent_graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); @@ -135,7 +136,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP /** * @ingroup ge - * @brief Check Subgraph Output node + 
* @brief Check Subgraph NetOutput node * @param [in] graph: ComputeGraph. * @param [in] node: NetOutput node in Subgraph. * @return: 0 for SUCCESS / others for FAILED @@ -152,14 +153,14 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node // 1. Const->NetOutput in subgraph // 2. AtomicOp->NetOutput in subgraph // 3. OutputContinuesRequiredOp->NetOutput in subgraph - // 4. Data->NetOutput in subgraph but parent_node is not while + // 4. Data->NetOutput in subgraph but not while body std::string op_type; bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || - ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)); + ((in_node->GetType() == DATA) && !IsWhileBodyOutput(in_data_anchor)); if (insert_flag) { - GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); - std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; + GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); + std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; @@ -185,8 +186,8 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr GE_CHECK_NOTNULL(in_node); // Input->While and Input link to other nodes need insert memcpy if (peer_out_anchor->GetPeerInDataAnchors().size() > 1) { - GELOGD("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); - std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; + 
GELOGI("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); + std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; @@ -205,121 +206,231 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::WhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { - // index of body_subgraph is 1 - ComputeGraphPtr while_body = NodeUtils::GetSubgraph(*node, 1); + ComputeGraphPtr while_body = GetWhileBodySubgraph(graph, node); if (while_body == nullptr) { GELOGE(FAILED, "while_body of %s is NULL.", node->GetName().c_str()); return FAILED; } - std::vector data_nodes; - std::set bypass_index; - NodePtr output_node = nullptr; - for (const auto &n : while_body->GetDirectNode()) { - const std::string &type = n->GetType(); - if (type == DATA) { - if (CheckInsertInputMemcpy(n, bypass_index)) { - data_nodes.emplace_back(n); - } - } else if (type == NETOUTPUT) { - if (output_node == nullptr) { - output_node = n; - } else { - GELOGE(FAILED, "while_body %s exists multi NetOutput nodes.", while_body->GetName().c_str()); - return FAILED; - } - } - } + NodePtr output_node = while_body->FindFirstNodeMatchType(NETOUTPUT); if (output_node == nullptr) { - GELOGE(FAILED, "while_body %s has no output.", while_body->GetName().c_str()); + GELOGE(FAILED, "net_output_node not exist in graph %s.", while_body->GetName().c_str()); return FAILED; } + OpDescPtr output_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(output_desc); + std::unordered_map> node_to_attr_index; + for (const InDataAnchorPtr &in_data_anchor : output_node->GetAllInDataAnchors()) { + uint32_t index = 0; + if 
(!AttrUtils::GetInt(output_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index)) { + GELOGE(FAILED, "Get attr PARENT_NODE_INDEX failed, node %s:%u.", output_node->GetName().c_str(), + in_data_anchor->GetIdx()); + return FAILED; + } + MarkOutputIndex(in_data_anchor->GetPeerOutAnchor(), index, node_to_attr_index); + } - if ((InsertInputMemcpy(while_body, data_nodes) != SUCCESS) || - (InsertOutputMemcpy(while_body, output_node, bypass_index) != SUCCESS)) { - GELOGE(FAILED, "Insert memcpy node in while_body %s failed.", while_body->GetName().c_str()); - return FAILED; + std::set data_nodes; + std::set netoutput_input_indexes; + GetExchangeInOut(node_to_attr_index, data_nodes, netoutput_input_indexes); + return InsertMemcpyInWhileBody(while_body, data_nodes, output_node, netoutput_input_indexes); +} + +/** + * @ingroup ge + * @brief Get body subgraph of While op + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. + * @return: body subgraph + */ +ComputeGraphPtr SubgraphPass::GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "op_desc is NULL."); + return nullptr; } - return SUCCESS; + const std::vector &subgraph_instance_names = op_desc->GetSubgraphInstanceNames(); + std::string body_instance_name; + for (const std::string &instance_name : subgraph_instance_names) { + std::string subgraph_name; + if (op_desc->GetSubgraphNameByInstanceName(instance_name, subgraph_name) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Get subgraph_name by instance_name %s failed, node:%s.", instance_name.c_str(), + node->GetName().c_str()); + return nullptr; + } + if (subgraph_name == ATTR_NAME_WHILE_BODY) { + body_instance_name = instance_name; + break; + } + } + + ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph); + if (root_graph == nullptr) { + GELOGE(FAILED, "root_graph is NULL."); + return nullptr; + } + + return 
root_graph->GetSubgraph(body_instance_name); } /** * @ingroup ge - * @brief Insert input memcpy node in while_body - * @param [in] graph: while_body - * @param [in] data_nodes: data_nodes - * @return: 0 for SUCCESS / others for FAILED + * @brief Mark output parent_node_index + * @param [in] peer_out_anchor: peer_out_anchor of NetOutput + * @param [in] index: parent_node_index of NetOutput + * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @return: void */ -Status SubgraphPass::InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes) { - if (data_nodes.empty()) { - GELOGD("No need to insert input memcpy node in while_body %s.", graph->GetName().c_str()); - return SUCCESS; +void SubgraphPass::MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, + std::unordered_map> &node_to_attr_index) { + if (peer_out_anchor == nullptr) { + return; + } + std::set visited_nodes; + std::stack nodes; + nodes.emplace(peer_out_anchor->GetOwnerNode()); + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (visited_nodes.count(cur_node) > 0) { + continue; + } + node_to_attr_index[cur_node].emplace_back(index); + for (const NodePtr &in_node : cur_node->GetInDataNodes()) { + nodes.emplace(in_node); + } + visited_nodes.emplace(cur_node); } +} + +/** + * @ingroup ge + * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy + * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @param [out] data_nodes: data_nodes need insert memcpy + * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: void + */ +void SubgraphPass::GetExchangeInOut(const std::unordered_map> &node_to_attr_index, + std::set &data_nodes, std::set &netoutput_input_indexes) { + for (const auto &item : node_to_attr_index) { + NodePtr node = item.first; + uint32_t input_index = 0; + if ((node->GetType() != DATA) || 
!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { + continue; + } + if (item.second.empty() || ((item.second.size() == 1) && (item.second[0] == input_index))) { + continue; + } + data_nodes.emplace(node); - std::string in_name = graph->GetName() + "_input_Memcpy"; - OpDescBuilder in_builder(in_name, MEMCPYASYNC); - for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output - in_builder.AddInput("x" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)) - .AddOutput("y" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)); + OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); + if (out_data_anchor == nullptr) { + continue; + } + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + NodePtr out_node = peer_in_anchor->GetOwnerNode(); + if ((out_node->GetType() != NETOUTPUT) || (out_node->GetOpDesc() == nullptr)) { + continue; + } + uint32_t output_index = 0; + GeTensorDesc input_tensor = out_node->GetOpDesc()->GetInputDesc(peer_in_anchor->GetIdx()); + if (!AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, output_index)) { + continue; + } + if (input_index != output_index) { + netoutput_input_indexes.emplace(peer_in_anchor->GetIdx()); + } + } } - GELOGD("Insert memcpy after data_nodes of while_body %s.", graph->GetName().c_str()); - NodePtr in_memcpy = graph->AddNode(in_builder.Build()); - GE_CHECK_NOTNULL(in_memcpy); - for (size_t i = 0; i < data_nodes.size(); i++) { +} + +/** + * @ingroup ge + * @brief Insert memcpy node in while_body + * @param [in] graph: while_body + * @param [in] data_nodes: data_nodes need insert memcpy + * @param [in] output_node: NetOutput in while_body + * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: 0 for SUCCESS / others for FAILED + */ +Status SubgraphPass::InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, + const 
NodePtr &output_node, + const std::set &netoutput_input_indexes) { + for (const NodePtr &data_node : data_nodes) { // Data node has and only has one output - OutDataAnchorPtr out_data_anchor = data_nodes[i]->GetOutDataAnchor(0); + OutDataAnchorPtr out_data_anchor = data_node->GetOutDataAnchor(0); std::vector in_anchors; for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { in_anchors.emplace_back(peer_in_anchor); } - if (InsertNodeBetween(out_data_anchor, in_anchors, in_memcpy, i, i) != SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", in_name.c_str(), graph->GetName().c_str()); + std::string name = data_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Insert memcpy after while_body %s input_node %s.", graph->GetName().c_str(), data_node->GetName().c_str()); + if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), data_node->GetName().c_str()); return FAILED; } } - return SUCCESS; + for (uint32_t index : netoutput_input_indexes) { + InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(index); + GE_CHECK_NOTNULL(in_data_anchor); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + std::string name = + peer_out_anchor->GetOwnerNode()->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Insert memcpy after while_body %s output %u.", graph->GetName().c_str(), index); + if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), + peer_out_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + } + + std::set memcpy_nodes; + std::set loop_body_nodes; + for (const NodePtr &data_node : data_nodes) { + // data_node has only one output node + NodePtr memcpy_node = 
data_node->GetOutDataNodes().at(0); + GE_CHECK_NOTNULL(memcpy_node); + memcpy_nodes.emplace(memcpy_node); + for (const NodePtr &out_node : memcpy_node->GetOutDataNodes()) { + loop_body_nodes.insert(out_node); + } + } + return InsertNoOp(graph, memcpy_nodes, loop_body_nodes); } /** * @ingroup ge - * @brief Insert output memcpy node in while_body + * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes * @param [in] graph: while_body - * @param [in] output_node: NetOutput - * @param [in] bypass_index + * @param [in] memcpy_nodes + * @param [in] loop_body_nodes * @return: 0 for SUCCESS / others for FAILED */ -Status SubgraphPass::InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, - const std::set &bypass_index) { - if (output_node->GetAllInDataAnchorsSize() == bypass_index.size()) { - GELOGD("No need to insert output memcpy node in while_body %s, output_size=%zu, bypass_num=%zu.", - graph->GetName().c_str(), output_node->GetAllInDataAnchorsSize(), bypass_index.size()); +Status SubgraphPass::InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, + const std::set &loop_body_nodes) { + if (memcpy_nodes.empty() || loop_body_nodes.empty()) { return SUCCESS; } - std::string out_name = graph->GetName() + "_output_Memcpy"; - OpDescBuilder out_builder(out_name, MEMCPYASYNC); - for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { - if (bypass_index.count(i) == 0) { - out_builder.AddInput("x" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)) - .AddOutput("y" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)); + OpDescBuilder noop_desc_builder("NoOp_for_Control", NOOP); + OpDescPtr noop_desc = noop_desc_builder.Build(); + NodePtr noop_node = graph->AddNode(noop_desc); + GE_CHECK_NOTNULL(noop_node); + for (const NodePtr &memcpy_node : memcpy_nodes) { + if (GraphUtils::AddEdge(memcpy_node->GetOutControlAnchor(), noop_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add 
ctrl edge %s->%s failed.", memcpy_node->GetName().c_str(), noop_node->GetName().c_str()); + return FAILED; } } - GELOGD("Insert memcpy before NetOutput of while_body %s.", graph->GetName().c_str()); - NodePtr out_memcpy = graph->AddNode(out_builder.Build()); - GE_CHECK_NOTNULL(out_memcpy); - size_t cnt = 0; - for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { - if (bypass_index.count(i) == 0) { - InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(i); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - if (InsertNodeBetween(peer_out_anchor, {in_data_anchor}, out_memcpy, cnt, cnt) != SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", out_name.c_str(), graph->GetName().c_str()); - return FAILED; - } - cnt++; + for (const NodePtr &loop_body_node : loop_body_nodes) { + if (GraphUtils::AddEdge(noop_node->GetOutControlAnchor(), loop_body_node->GetInControlAnchor()) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add ctrl edge %s->%s failed.", noop_node->GetName().c_str(), loop_body_node->GetName().c_str()); + return FAILED; } } @@ -328,39 +439,28 @@ Status SubgraphPass::InsertOutputMemcpy(const ComputeGraphPtr &graph, const Node /** * @ingroup ge - * @brief Check is data->netoutput without change in while body - * @param [in] node: data node - * @param [out] bypass_index - * @return: false for data->netoutput without change in while body / for true for others + * @brief Check is data->netoutput in while body + * @param [in] in_data_anchor + * @return: true for data->netoutput in while body / for false for others */ -bool SubgraphPass::CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index) { - uint32_t input_index = 0; - if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { - return true; - } - - // Data node has and only has one output - OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); - if ((out_data_anchor == nullptr) || 
(out_data_anchor->GetPeerInDataAnchors().size() != 1)) { - return true; - } - InDataAnchorPtr peer_in_anchor = out_data_anchor->GetPeerInDataAnchors().at(0); - if (peer_in_anchor->GetOwnerNode()->GetType() != NETOUTPUT) { - return true; +bool SubgraphPass::IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor) { + // Check is subgraph + NodePtr parent_node = in_data_anchor->GetOwnerNode()->GetOwnerComputeGraph()->GetParentNode(); + if (parent_node == nullptr) { + return false; } - OpDescPtr op_desc = peer_in_anchor->GetOwnerNode()->GetOpDesc(); - uint32_t output_index = 0; - if ((op_desc == nullptr) || - !AttrUtils::GetInt(op_desc->GetInputDesc(peer_in_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, output_index)) { - return true; + // Check if parent_node is While + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { + return false; } - if (input_index != output_index) { - return true; + // While cond / body + OpDescPtr op_desc = in_data_anchor->GetOwnerNode()->GetOpDesc(); + if (op_desc == nullptr) { + return false; } - bypass_index.insert(peer_in_anchor->GetIdx()); - return false; + return AttrUtils::HasAttr(op_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX); } /** @@ -442,7 +542,7 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) .Build(); - if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeBefore(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); return FAILED; } @@ -450,33 +550,4 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat return SUCCESS; } -/// -/// @brief Insert node: src->insert_node:input_index, 
insert_node:output_index->dst -/// @param [in] src -/// @param [in] dsts -/// @param [in] insert_node -/// @param [in] input_index -/// @param [in] output_index -/// @return Status -/// -Status SubgraphPass::InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { - if (GraphUtils::AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add data_edge %s:%d->%s:%u failed.", src->GetOwnerNode()->GetName().c_str(), src->GetIdx(), - insert_node->GetName().c_str(), input_index); - return FAILED; - } - for (const auto &dst : dsts) { - GELOGD("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src->GetOwnerNode()->GetName().c_str(), - dst->GetOwnerNode()->GetName().c_str()); - if ((GraphUtils::RemoveEdge(src, dst) != GRAPH_SUCCESS) || - (GraphUtils::AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { - GELOGE(FAILED, "Replace data_edge %s:%d->%s:%d by %s:%u->%s:%d failed.", src->GetOwnerNode()->GetName().c_str(), - src->GetIdx(), dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx(), insert_node->GetName().c_str(), - output_index, dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx()); - return FAILED; - } - } - return SUCCESS; -} } // namespace ge diff --git a/src/ge/graph/passes/subgraph_pass.h b/src/ge/graph/passes/subgraph_pass.h index 7ff2019f..2308b1bd 100644 --- a/src/ge/graph/passes/subgraph_pass.h +++ b/src/ge/graph/passes/subgraph_pass.h @@ -17,6 +17,12 @@ #ifndef GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ #define GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ +#include +#include +#include +#include + +#include "graph/types.h" #include "inc/graph_pass.h" namespace ge { @@ -69,32 +75,65 @@ class SubgraphPass : public GraphPass { /** * @ingroup ge - * @brief Insert input memcpy node in while_body + * @brief Get body subgraph of While op + * @param [in] graph: ComputeGraph. + * @param [in] node: While node. 
+ * @return: body subgraph + */ + ComputeGraphPtr GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node); + + /** + * @ingroup ge + * @brief Mark output parent_node_index + * @param [in] peer_out_anchor: peer_out_anchor of NetOutput + * @param [in] index: parent_node_index of NetOutput + * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @return: void + */ + void MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, + std::unordered_map> &node_to_attr_index); + + /** + * @ingroup ge + * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy + * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index + * @param [out] data_nodes: data_nodes need insert memcpy + * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @return: void + */ + void GetExchangeInOut(const std::unordered_map> &node_to_attr_index, + std::set &data_nodes, std::set &netoutput_input_indexes); + + /** + * @ingroup ge + * @brief Insert memcpy node in while_body * @param [in] graph: while_body - * @param [in] data_nodes: data_nodes + * @param [in] data_nodes: data_nodes need insert memcpy + * @param [in] output_node: NetOutput in while_body + * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy * @return: 0 for SUCCESS / others for FAILED */ - Status InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes); + Status InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, + const NodePtr &output_node, const std::set &netoutput_input_indexes); /** * @ingroup ge - * @brief Insert output memcpy node in while_body + * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes * @param [in] graph: while_body - * @param [in] output_node: NetOutput - * @param [in] bypass_index + * @param [in] memcpy_nodes + * @param [in] loop_body_nodes * @return: 0 for SUCCESS / 
others for FAILED */ - Status InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, - const std::set &bypass_index); + Status InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, + const std::set &loop_body_nodes); /** * @ingroup ge - * @brief Check is data->netoutput without change in while body - * @param [in] node: data node - * @param [out] bypass_index - * @return: false for data->netoutput without change in while body / for true for others + * @brief Check is Data->NetOutput in while body + * @param [in] in_data_anchor + * @return: true for Data->NetOutput in while body / false for others */ - bool CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index); + bool IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor); /** * @ingroup ge @@ -133,17 +172,8 @@ class SubgraphPass : public GraphPass { Status InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, const std::vector &in_anchors, const std::string &name); - /// - /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst - /// @param [in] src - /// @param [in] dsts - /// @param [in] insert_node - /// @param [in] input_index - /// @param [in] output_index - /// @return Status - /// - Status InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index); + // Append index for new memcpy node. + uint32_t memcpy_num_{0}; }; } // namespace ge #endif // GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ diff --git a/src/ge/graph/passes/switch_op_pass.cc b/src/ge/graph/passes/switch_op_pass.cc new file mode 100644 index 00000000..ed3e9b36 --- /dev/null +++ b/src/ge/graph/passes/switch_op_pass.cc @@ -0,0 +1,1227 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/switch_op_pass.h" +#include +#include +#include +#include +#include +#include +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "ge/ge_api_types.h" +#include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" +#include "graph/utils/type_utils.h" + +namespace ge { +Status SwitchOpPass::Run(ComputeGraphPtr graph) { + GELOGD("SwitchOpPass Enter"); + GE_CHK_STATUS_RET(CheckCycleDependence(graph), "CheckCycleDependence fail."); + + for (auto &switch_node : switch_nodes_) { + GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Add StreamSwitch node fail."); + } + + for (auto &merge_node : merge_nodes_) { + OpDescPtr merge_op_desc = merge_node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, merge_node, true), "Merge add memcpy node fail."); + GE_CHK_STATUS_RET(SetStreamLabel(merge_node, merge_node->GetName()), "Set stream label failed"); + } else { + GE_CHK_STATUS_RET(ReplaceMergeNode(graph, merge_node), "Add StreamMerge node fail."); + } + } + + GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes fail."); + + for (auto &node : bypass_nodes_) { + GE_CHK_BOOL_EXEC(graph->RemoveNode(node) == GRAPH_SUCCESS, return FAILED, "Remove switch node fail."); + } + + for (auto &node : 
stream_switch_nodes_) { + for (auto &out_ctrl_node : node->GetOutControlNodes()) { + MarkHeadNodes(out_ctrl_node, node); + } + } + + for (auto &node : need_label_nodes_) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { + GE_CHK_STATUS_RET(UpdateCondBranch(node), "Set cond branch fail, start node:%s", node->GetName().c_str()); + } + } + + GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode fail."); + + GELOGD("SwitchOpPass Leave"); + return SUCCESS; +} + +/// +/// @brief Replace Switch Op +/// @param [in] graph +/// @param [in] switch_node +/// @return Status +/// +Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node) { + std::string type; + GE_CHK_STATUS_RET(GetOriginalType(switch_node, type), "Get node type fail."); + GE_CHK_BOOL_EXEC((type == SWITCH) || (type == REFSWITCH), return FAILED, "Type of input node is not switch."); + + OutDataAnchorPtr peer_data_anchor = nullptr; + OutDataAnchorPtr peer_cond_anchor = nullptr; + GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED, + "Bypass switch node %s fail.", switch_node->GetName().c_str()); + GE_CHECK_NOTNULL(peer_data_anchor); + GE_CHECK_NOTNULL(peer_cond_anchor); + OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(cond_desc); + DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType(); + GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL, return FAILED, + "SwitchNode not support datatype %s, datatype of cond_input should be bool", + TypeUtils::DataTypeToSerialString(cond_data_type).c_str()); + + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG); + + std::set out_node_list; + for (OutDataAnchorPtr &out_data_anchor : switch_node->GetAllOutDataAnchors()) { + bool true_branch_flag = 
(static_cast(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT); + NodePtr stream_switch = nullptr; + out_node_list.clear(); + for (auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) { + GE_IF_BOOL_EXEC(stream_switch == nullptr, { + std::string suffix = (true_branch_flag ? "_t" : "_f"); + stream_switch = CreateStreamSwitchNode(graph, switch_node, suffix, peer_cond_anchor); + GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node fail."); + if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) { + GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s fail.", stream_switch->GetName().c_str()); + return FAILED; + } + if (MarkBranchs(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) { + GELOGE(FAILED, "MarkBranchs for stream_switch %s fail.", stream_switch->GetName().c_str()); + return FAILED; + } + + if (!cyclic_flag) { + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(), + stream_switch->GetInControlAnchor()), + "StreamSwitch node add ctl edge fail."); + } + }); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output fail."); + + NodePtr out_node = peer_in_anchor->GetOwnerNode(); + GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); + if ((type == MERGE) || (type == REFMERGE)) { + NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false); + GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), + "MemcpyAsync node add edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), peer_in_anchor), + "MemcpyAsync node add edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), memcpy_node->GetInControlAnchor()), + "MemcpyAsync node add ctl edge fail."); + out_node_list.insert(memcpy_node->GetName()); + } 
else { + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "StreamSwitch node add ctl edge fail."); + out_node_list.insert(out_node->GetName()); + } + } + GE_IF_BOOL_EXEC(stream_switch != nullptr, { + CopyControlEdges(switch_node, stream_switch, true); + switch_node_map_[stream_switch] = out_node_list; + if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) { + GELOGE(FAILED, "SetOriginalNodeName for node %s fail.", stream_switch->GetName().c_str()); + return FAILED; + } + }); + } + + RemoveControlEdges(switch_node); + (void)bypass_nodes_.insert(switch_node); + + return SUCCESS; +} + +/// +/// @brief Replace Merge Op +/// @param [in] graph +/// @param [in] merge_node +/// @return Status +/// +Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_node) { + std::string type; + GE_CHK_STATUS_RET(GetOriginalType(merge_node, type), "Get node type fail."); + GE_CHK_BOOL_EXEC((type == MERGE) || (type == REFMERGE), return FAILED, "Type of input node is not merge."); + + OpDescPtr merge_op_desc = merge_node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + + const std::string node_name = merge_node->GetName(); + GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMMERGE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc fail, StreamMerge:%s.", node_name.c_str()); + return FAILED; + } + + for (InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) { + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS, + return FAILED, "Create StreamMerge op: add input desc fail."); + } + + for (OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) { + 
GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS, + return FAILED, "Create StreamMerge op: add output desc fail."); + } + + NodePtr stream_merge = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(stream_merge != nullptr, return FAILED, "Insert StreamMerge node fail."); + + for (InDataAnchorPtr &in_data_anchor : merge_node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Remove Merge data input fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, stream_merge->GetInDataAnchor(in_data_anchor->GetIdx())), + "StreamMerge node add edge fail."); + } + + for (OutDataAnchorPtr &out_data_anchor : merge_node->GetAllOutDataAnchors()) { + for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Merge data output fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(stream_merge->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), + "StreamMerge node add edge fail."); + } + } + + ReplaceControlEdges(merge_node, stream_merge); + + if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + std::string next_iteration_name; + GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name), + GELOGE(INTERNAL_ERROR, "get ATTR_NAME_NEXT_ITERATION failed"); + return INTERNAL_ERROR); + + GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "set next iteration failed"); + } else { + need_label_nodes_.emplace_back(stream_merge); + } + + (void)bypass_nodes_.insert(merge_node); + + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge, false), "StreamMerge add memcpy node fail."); + + return SUCCESS; +} + +/// +/// @brief Create StreamSwitch Node +/// @param [in] graph +/// @param [in] 
switch_node +/// @param [in] suffix +/// @param [in] peer_cond_anchor +/// @return ge::NodePtr +/// +NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, + const std::string &suffix, OutDataAnchorPtr &peer_cond_anchor) { + GE_CHK_BOOL_EXEC(switch_node != nullptr, return nullptr, "Param of merge node is null."); + OpDescPtr switch_op_desc = switch_node->GetOpDesc(); + GE_CHK_BOOL_EXEC(switch_op_desc != nullptr, return nullptr, "OpDesc of Switch node is invalid."); + GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, { + GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u", switch_op_desc->GetInputsSize(), + SWITCH_INPUT_NUM); + return nullptr; + }); + + const std::string node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix; + GELOGI("Create StreamSwitch, name=%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMSWITCH); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc fail, StreamSwitch:%s.", node_name.c_str()); + return nullptr; + } + // mark hccl group id + std::string hccl_group_id; + if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); + GELOGI("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch%s, value is %s.", node_name.c_str(), + hccl_group_id.c_str()); + } else { + GELOGI("Can not find attr ATTR_NAME_HCCL_FUSED_GROUP for node %s.", switch_node->GetName().c_str()); + } + + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || + !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { + GELOGE(INTERNAL_ERROR, "set int failed"); + return nullptr; + } + + // Already checked, first input is Variable will passed, second is condition will checked. 
+ GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT); + GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32); + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, + "Create StreamSwitch node: add input desc fail."); + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, + "Create StreamSwitch node: add input desc fail."); + + NodePtr stream_switch = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(stream_switch != nullptr, return nullptr, "Insert StreamSwitch node fail."); + + GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), + "StreamSwitch node add cond edge fail."); + + return stream_switch; +} + +/// +/// @brief Add MemcpyAsync Node +/// @param [in] graph +/// @param [in] in_node +/// @param [in] multi_batch_flag +/// @return ge::NodePtr +/// +NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, + bool multi_batch_flag) { + GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); + OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); + + std::string memcpy_type = multi_batch_flag ? 
MEMCPYADDRASYNC : MEMCPYASYNC; + std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type; + node_name = CheckDuplicateName(node_name); + GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, memcpy_type); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str()); + return nullptr; + } + + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, + return nullptr, "Create MemcpyAsync op: add input desc fail."); + GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, + return nullptr, "Create MemcpyAsync op: add output desc fail."); + + NodePtr memcpy_node = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return nullptr, "Insert MemcpyAsync node fail."); + + return memcpy_node; +} + +/// +/// @brief Combine switch nodes link to same cond +/// @param [in] graph +/// @return Status +/// +Status SwitchOpPass::CombineSwitchNode(ComputeGraphPtr &graph) { + for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { + for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { + OutDataAnchorPtr peer_cond_anchor = iter->first; + GE_CHECK_NOTNULL(peer_cond_anchor); + std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; + std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; + std::set same_cond_switch; + same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); + same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); + + NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); + GELOGI("CombineSwitchNode: cond_node=%s", cond_node->GetName().c_str()); + + NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node fail."); + + NodePtr active_node = 
CreateActiveNode(graph, cond_node); + GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), + "StreamActive add ctl edge fail."); + if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) { + GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); + return FAILED; + } + + const std::string cond_group = cond_node->GetName(); + for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { + bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); + std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); + GE_IF_BOOL_EXEC(switch_list.empty(), continue); + + // select first stream_switch + NodePtr stream_switch = switch_list.front(); + OpDescPtr switch_desc = stream_switch->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + std::string node_name = cond_group + "/" + STREAMSWITCH + (true_branch_flag ? 
"_t" : "_f"); + node_name = CheckDuplicateName(node_name); + switch_desc->SetName(node_name); + stream_switch_nodes_.emplace_back(stream_switch); + need_label_nodes_.emplace_back(stream_switch); + + // 0_input: original pred input, 1_input: constant node + GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node fail"); + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), + "StreamSwitch remove data edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), + "Cast add data edge fail."); + + for (NodePtr &node : switch_list) { + GE_CHECK_NOTNULL(node); + GE_IF_BOOL_EXEC(node != stream_switch, { + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), + "StreamSwitch remove data edge fail."); + }); + GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges fail"); + GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges fail"); + } + + GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), + "StreamActive add ctl edge fail."); + } + } + } + return SUCCESS; +} + +/// +/// @brief Create Active Op +/// @param [in] graph +/// @param [in] cond_node +/// @return ge::NodePtr +/// +NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) { + GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Param of pre cond_node is null."); + std::string node_name = node->GetName() + "_" + STREAMACTIVE; + node_name = CheckDuplicateName(node_name); + GELOGI("Create StreamActive op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc fail, StreamActive:%s.", node_name.c_str()); + return nullptr; + } + + NodePtr active_node = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, 
"Create StreamActive node fail."); + + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, + GELOGE(INTERNAL_ERROR, "add edge failed"); + return nullptr); + + GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, + GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); + return nullptr); + + return active_node; +} + +/// +/// @brief Add MemcpyAsync Op as StreamMerge in_node +/// @param [in] graph +/// @param [in] node +/// @param [in] multi_batch_flag +/// @return Status +/// +Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) { + GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); + for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + + const std::string type = in_node->GetType(); + // For WhileLoop no need memcpy & active for merge. 
+ GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), + continue); + + GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { + in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag); + GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), + "MemcpyAsync node add edge fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), in_data_anchor), + "MemcpyAsync node add edge fail."); + }); + + NodePtr active_node = CreateActiveNode(graph, in_node); + GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "StreamActive add ctl edge fail."); + if (SetActiveLabelList(active_node, {node->GetName()}) != SUCCESS) { + GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @brief Bypass Switch Node +/// @param [in] switch_node +/// @param [out] peer_data_anchor +/// @param [out] peer_cond_anchor +/// @return Status +/// +Status SwitchOpPass::BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, + OutDataAnchorPtr &peer_cond_anchor) { + GE_CHK_BOOL_EXEC(switch_node != nullptr, return FAILED, "Switch_node is null."); + for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) { + InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); + GE_CHK_BOOL_EXEC(in_data_anchor != nullptr, return FAILED, "node[%s]Check Switch input anchor fail.", + switch_node->GetName().c_str()); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return FAILED, "node[%s]Check Pre 
node output anchor fail.", + switch_node->GetName().c_str()); + // Remove Switch data input. + GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "remove edge failed"); + + if (idx == SWITCH_DATA_INPUT) { + peer_data_anchor = peer_out_anchor; + } else { + if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { + GELOGE(FAILED, "FindSwitchCondInput fail, switch=%s", switch_node->GetName().c_str()); + return FAILED; + } + peer_cond_anchor = peer_out_anchor; + } + } + + return SUCCESS; +} + +/// +/// @brief Find Switch cond input +/// @param [in] pass_switch_flag +/// @param [out] peer_cond_anchor +/// @return Status +/// +Status SwitchOpPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { + NodePtr tmp_node = nullptr; + string type; + bool need_pass_type = true; + while (need_pass_type) { + if (tmp_node == nullptr) { + GE_CHECK_NOTNULL(peer_cond_anchor); + tmp_node = peer_cond_anchor->GetOwnerNode(); + } else { + InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT); + GE_CHECK_NOTNULL(in_data_anchor); + peer_cond_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_cond_anchor); + tmp_node = peer_cond_anchor->GetOwnerNode(); + } + + GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type fail"); + need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); + } + + return SUCCESS; +} + +int64_t SwitchOpPass::GetGroupId(const NodePtr &node) { + string tailing_optimization_option; + bool is_tailing_optimization = false; + auto ret = GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option); + if (ret == GRAPH_SUCCESS) { + // "1" means it's True from frontend option + is_tailing_optimization = (tailing_optimization_option == "1"); + GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str()); + } + if (!is_tailing_optimization) { + return 0; + } + + string hccl_group_id; + 
if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + GELOGI("Node is %s, can not find hccl group id", node->GetName().c_str()); + return 0; + } + auto key_index = hccl_group_id.find_last_of('_'); + auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index); + GELOGI("Node is %s,Hccl group id is %s, key_num is %s", node->GetName().c_str(), hccl_group_id.c_str(), + key_num.c_str()); + int64_t num = atoi(key_num.c_str()); + if (num == 0) { + return 0; + } + GELOGI("Hccl group id is %s, group id is %ld", hccl_group_id.c_str(), num); + return num; +} + +/// +/// @brief Mark Switch Branch +/// @param [in] peer_cond_anchor +/// @param [in] stream_switch +/// @param [in] true_branch_flag +/// @return Status +/// +Status SwitchOpPass::MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &stream_switch, bool true_branch_flag) { + uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT; + GE_CHECK_NOTNULL(stream_switch); + auto it = cond_node_map_.find(peer_cond_anchor); + if (it != cond_node_map_.end()) { + int64_t switch_group_id = GetGroupId(stream_switch); + auto switch_group_it = it->second.find(switch_group_id); + if (switch_group_it == it->second.end()) { + std::list false_node_list; + std::list true_node_list; + std::list &node_list = true_branch_flag ? 
true_node_list : false_node_list; + node_list.emplace_back(stream_switch); + std::vector> switch_list; + switch_list.emplace_back(false_node_list); + switch_list.emplace_back(true_node_list); + (void)it->second.emplace(switch_group_id, switch_list); + } else { + GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, { + GELOGE(INTERNAL_ERROR, "cond_node_map_ check size fail, node: %s", stream_switch->GetName().c_str()); + return FAILED; + }); + switch_group_it->second[index].emplace_back(stream_switch); + } + } else { + int64_t switch_group_id = GetGroupId(stream_switch); + map>> switch_group_map; + std::list false_node_list; + std::list true_node_list; + std::list &node_list = true_branch_flag ? true_node_list : false_node_list; + node_list.emplace_back(stream_switch); + std::vector> switch_list; + switch_list.emplace_back(false_node_list); + switch_list.emplace_back(true_node_list); + (void)switch_group_map.emplace(switch_group_id, switch_list); + auto result = cond_node_map_.insert( + std::pair>>>(peer_cond_anchor, switch_group_map)); + GE_IF_BOOL_EXEC(!result.second, { + GELOGE(INTERNAL_ERROR, "cond_node_map_ insert fail, node: %s", stream_switch->GetName().c_str()); + return FAILED; + }); + } + return SUCCESS; +} + +/// +/// @brief Create cast node +/// @param [in] graph +/// @param [in] peer_cond_anchor +/// @return NodePtr +/// +NodePtr SwitchOpPass::CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &peer_cond_anchor) { + GE_CHK_BOOL_EXEC(peer_cond_anchor != nullptr, return nullptr, "Param of pre cond_node is null."); + OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc fail."); + + std::string cast_name = cond_desc->GetName() + "_" + CAST; + cast_name = CheckDuplicateName(cast_name); + GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str()); + OpDescPtr cast_desc = MakeShared(cast_name, CAST); + if (cast_desc 
== nullptr) { + GELOGE(FAILED, "Create op_desc fail, Cast:%s.", cast_name.c_str()); + return nullptr; + } + if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) && + AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) && + AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) && + AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { + GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE fail, node: %s.", + cast_name.c_str()); + return nullptr; + } + GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()); + tensor_desc.SetDataType(DT_BOOL); + GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add input desc fail."); + tensor_desc.SetDataType(DT_INT32); + GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add output desc fail."); + + NodePtr cast_node = graph->AddNode(cast_desc); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node fail."); + + GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge fail."); + + return cast_node; +} + +/// +/// @brief Add const node as switch input1 +/// @param [in] graph +/// @param [in] stream_switch +/// @return Status +/// +Status SwitchOpPass::AddConstNode(ComputeGraphPtr &graph, NodePtr &stream_switch) { + GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "stream_switch is null."); + OpDescPtr op_desc = stream_switch->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + bool value = false; + GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, + "StreamSwitch get attr TRUE_BRANCH_STREAM fail."); + + const std::string const_node_name = op_desc->GetName() + "_Constant_" + (value ? 
"t" : "f"); + GELOGI("Create const op: %s", const_node_name.c_str()); + OpDescPtr const_op_desc = MakeShared(const_node_name, CONSTANT); + if (const_op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc fail, Constant:%s.", const_node_name.c_str()); + return FAILED; + } + + auto resize_value = (int32_t)value; + GeTensorDesc data_desc = op_desc->GetInputDesc(1); + GeTensorPtr const_value = + MakeShared(data_desc, reinterpret_cast(&resize_value), sizeof(int32_t)); + if (const_value == nullptr) { + GELOGE(FAILED, "Create tensor fail."); + return FAILED; + } + GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value), return FAILED); + GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS, return FAILED, + "Create Const op: add output desc fail."); + + NodePtr const_node = graph->AddNode(const_op_desc); + GE_CHK_BOOL_EXEC(const_node != nullptr, return FAILED, "Insert Const node fail."); + GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)), + "StreamSwitch node add ctl edge fail."); + + return SUCCESS; +} + +/// +/// @brief update cond branch +/// @param [in] node +/// @return Status +/// +Status SwitchOpPass::UpdateCondBranch(NodePtr &node) { + std::string stream_label; + std::unordered_set branch_nodes; + std::unordered_set handled_set; + std::stack nodes; + nodes.push(node); + + static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; + bool merge_flag = false; + bool exit_flag = false; + bool net_output_flag = false; + + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (handled_set.count(cur_node) > 0) { + continue; + } + GE_CHECK_NOTNULL(cur_node); + if (UpdateAttachFlag(cur_node, stream_label, merge_flag, exit_flag, net_output_flag) != SUCCESS) { + GELOGE(FAILED, "UpdateAttachFlag fail, cur_node: %s.", cur_node->GetName().c_str()); + return FAILED; + } + + const std::string type = cur_node->GetType(); + for (auto 
&out_node : cur_node->GetOutAllNodes()) { + const std::string out_type = out_node->GetType(); + bool stop_flag = (end_type_set.count(out_type) > 0) || + ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || + (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); + if (!stop_flag) { + nodes.push(out_node); + GELOGD("branch_nodes insert %s", out_node->GetName().c_str()); + branch_nodes.insert(out_node); + } + } + handled_set.insert(cur_node); + } + + if (node->GetType() == STREAMSWITCH) { + GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed"); + } + + bool attach_flag = (merge_flag || exit_flag) && net_output_flag; + if (attach_flag) { + GELOGI("No need to keep on attaching label."); + return SUCCESS; + } + + for (NodePtr tmp_node : branch_nodes) { + GELOGD("Attach label %s to node: %s", stream_label.c_str(), tmp_node->GetName().c_str()); + GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "set stream label failed"); + } + + return SUCCESS; +} + +/// +/// @brief update attach flag +/// @param [in] node +/// @param [out] stream_label +/// @param [out] merge_flag +/// @param [out] exit_flag +/// @param [out] net_output_flag +/// @return Status +/// +Status SwitchOpPass::UpdateAttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, + bool &net_output_flag) { + const std::string type = node->GetType(); + if (type == STREAMSWITCH) { + if (node->GetInDataNodes().empty()) { + GELOGE(INTERNAL_ERROR, "cur_node %s has no input_data_node", node->GetName().c_str()); + return INTERNAL_ERROR; + } + stream_label = node->GetInDataNodes().at(0)->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed"); + bool value = false; + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, + "StreamSwitch 
get attr TRUE_BRANCH_STREAM fail."); + stream_label += (value ? "_t" : "_f"); + } else if (type == STREAMMERGE) { + stream_label = node->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed"); + merge_flag = true; + } else if ((type == EXIT) || (type == REFEXIT)) { + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed"); + exit_flag = true; + } else if (type == NETOUTPUT) { + net_output_flag = true; + } + + return SUCCESS; +} + +/// +/// @brief update loop branch +/// @param [in] enter_nodes +/// @param [in] stream_label +/// @return Status +/// +Status SwitchOpPass::UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label) { + std::stack nodes(enter_nodes); + NodePtr cur_node = nullptr; + while (!nodes.empty()) { + cur_node = nodes.top(); + nodes.pop(); + for (NodePtr &out_node : cur_node->GetOutAllNodes()) { + OpDescPtr out_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(out_desc); + if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { + continue; + } + GELOGD("Attach label %s to node: %s", stream_label.c_str(), out_node->GetName().c_str()); + GE_CHK_STATUS_RET(SetStreamLabel(out_node, stream_label), "set stream label failed"); + nodes.push(out_node); + } + } + + return SUCCESS; +} + +/// +/// @brief update enter nodes +/// @return Status +/// +Status SwitchOpPass::UpdateEnterNode() { + std::unordered_map> enter_active_map; + for (auto &enter_node : enter_nodes_) { + for (auto &out_ctrl_node : enter_node->GetOutControlNodes()) { + if (out_ctrl_node->GetType() != STREAMACTIVE) { + continue; + } + auto iter = enter_active_map.find(out_ctrl_node); + if (iter == enter_active_map.end()) { + enter_active_map[out_ctrl_node] = {enter_node}; + } else { + iter->second.emplace_back(enter_node); + } + } + } + + for (auto &pair : enter_active_map) { + std::string stream_label; + NodePtr active_node = pair.first; + GE_CHECK_NOTNULL(active_node); + OpDescPtr active_desc = 
active_node->GetOpDesc(); + GE_CHECK_NOTNULL(active_desc); + (void)AttrUtils::GetStr(active_desc, ATTR_NAME_STREAM_LABEL, stream_label); + if (stream_label.empty()) { + stream_label = active_desc->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "set stream label failed"); + } + std::stack enter_nodes; + for (auto &enter_node : pair.second) { + GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "set stream label failed"); + enter_nodes.emplace(enter_node); + } + + std::vector active_label_list; + if (!AttrUtils::GetListStr(active_desc, ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) || + (active_label_list.size() != 1) || active_label_list[0].empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST fail, node: %s", active_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { + GELOGE(FAILED, "UpdateLoopBranch fail."); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @brief Check duplicate node_name +/// @param [in] node_name +/// @return std::string +/// +std::string SwitchOpPass::CheckDuplicateName(const std::string &node_name) { + std::string tmp_name = node_name; + auto iter = node_num_map_.find(tmp_name); + if (iter != node_num_map_.end()) { + tmp_name = tmp_name + "_" + std::to_string(iter->second); + (iter->second)++; + } else { + node_num_map_[tmp_name] = 1; + } + return tmp_name; +} + +/// +/// @brief Check cyclic dependence +/// @param [in] graph +/// @return Status +/// +Status SwitchOpPass::CheckCycleDependence(ComputeGraphPtr &graph) { + std::string type; + std::unordered_map> cond_switch_map; + for (NodePtr &node : graph->GetDirectNode()) { + GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type fail"); + if ((type == SWITCH) || (type == REFSWITCH)) { + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + GE_CHK_BOOL_EXEC(in_cond_anchor != nullptr, return INTERNAL_ERROR, "Check Switch 
in_cond_anchor fail."); + OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); + GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return INTERNAL_ERROR, "Check Switch peer_out_anchor fail."); + if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { + GELOGE(FAILED, "FindSwitchCondInput fail, switch=%s", node->GetName().c_str()); + return FAILED; + } + + NodePtr cond_node = peer_out_anchor->GetOwnerNode(); + auto iter = cond_switch_map.find(cond_node); + if (iter == cond_switch_map.end()) { + cond_switch_map[cond_node] = {node}; + } else { + iter->second.emplace_back(node); + } + + switch_nodes_.emplace_back(node); + } else if ((type == MERGE) || (type == REFMERGE)) { + merge_nodes_.emplace_back(node); + } else if ((type == ENTER) || (type == REFENTER)) { + enter_nodes_.emplace_back(node); + } + } + + MarkCycleDependence(cond_switch_map); + + return SUCCESS; +} + +/// +/// @brief Mark cyclic dependence +/// @param [in] graph +/// @param [in] cond_switch_map +/// @return void +/// +void SwitchOpPass::MarkCycleDependence(const std::unordered_map> &cond_switch_map) { + std::stack out_nodes; + NodePtr tmp_node = nullptr; + std::unordered_set handled_set; + for (auto &iter : cond_switch_map) { + std::set switch_nodes(iter.second.begin(), iter.second.end()); + for (auto &switch_node : switch_nodes) { + GE_CHECK_NOTNULL_JUST_RETURN(switch_node); + GELOGD("CheckCycleDependence: cond_node=%s, switch=%s", iter.first->GetName().c_str(), + switch_node->GetName().c_str()); + for (const NodePtr &node : switch_node->GetOutAllNodes()) { + out_nodes.push(node); + } + } + handled_set.clear(); + while (!out_nodes.empty()) { + tmp_node = out_nodes.top(); + GE_CHECK_NOTNULL_JUST_RETURN(tmp_node); + out_nodes.pop(); + if (handled_set.count(tmp_node) > 0) { + continue; + } + GELOGD("CheckCycleDependence: tmp_node=%s", tmp_node->GetName().c_str()); + for (NodePtr &out_node : tmp_node->GetOutAllNodes()) { + if (switch_nodes.find(out_node) == switch_nodes.end()) { + 
out_nodes.push(out_node); + continue; + } + GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GELOGW("set cyclic dependence failed"); return ); + auto map_iter = switch_cyclic_map_.find(out_node); + if (map_iter == switch_cyclic_map_.end()) { + switch_cyclic_map_[out_node] = {tmp_node->GetName()}; + } else { + map_iter->second.insert(tmp_node->GetName()); + } + } + handled_set.insert(tmp_node); + } + } + + return; +} + +/// +/// @brief Modify in ctl edge for switch_node +/// @param [in] switch_node +/// @param [in] cast_node +/// @param [in] same_cond_switch +/// @return Status +/// +Status SwitchOpPass::ModifySwitchInCtlEdges(NodePtr &switch_node, NodePtr &cast_node, + const std::set &same_cond_switch) { + GE_CHECK_NOTNULL(switch_node); + GE_CHECK_NOTNULL(cast_node); + GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), + cast_node->GetName().c_str()); + + std::string orig_switch_name = switch_node->GetName(); + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME fail, node: %s", switch_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + + for (NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), + "Remove ctl edge fail."); + GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + "Add ctl edge fail."); + }); + + GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); + if (same_cond_switch.count(in_ctl_node) > 0) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + 
"Remove ctl edge fail."); + continue; + } + auto find_res1 = switch_node_map_.find(in_ctl_node); + GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + auto find_res2 = find_res1->second.find(orig_switch_name); + auto find_res3 = find_res1->second.find(cast_node->GetName()); + GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), { + find_res1->second.erase(find_res2); + find_res1->second.insert(cast_node->GetName()); + continue; + }); + } + + return SUCCESS; +} + +/// +/// @brief Modify out ctl edge for switch_node +/// @param [in] switch_node +/// @param [in] stream_switch +/// @param [in] active_node +/// @return Status +/// +Status SwitchOpPass::ModifySwitchOutCtlEdges(NodePtr &switch_node, NodePtr &stream_switch, NodePtr &active_node) { + GE_CHECK_NOTNULL(switch_node); + GE_CHECK_NOTNULL(stream_switch); + GE_CHECK_NOTNULL(active_node); + GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), + stream_switch->GetName().c_str(), active_node->GetName().c_str()); + auto find_res = switch_node_map_.find(switch_node); + GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + GE_IF_BOOL_EXEC(find_res->second.empty(), { + GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + + for (NodePtr &node : switch_node->GetOutControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Remove ctl edge fail."); + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::string orig_name = op_desc->GetName(); + 
GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), { + if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME fail, node: %s.", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + }); + if (find_res->second.find(orig_name) == find_res->second.end()) { + auto active_out_control_anchor = active_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(active_out_control_anchor); + GE_IF_BOOL_EXEC(!active_out_control_anchor->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(active_out_control_anchor, node->GetInControlAnchor()), "Add ctl edge fail."); + }); + } else { + auto stream_switch_out_control_anchor = stream_switch->GetOutControlAnchor(); + GE_CHECK_NOTNULL(stream_switch_out_control_anchor); + GE_IF_BOOL_EXEC(!stream_switch_out_control_anchor->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch_out_control_anchor, node->GetInControlAnchor()), + "Add ctl edge fail."); + }); + } + } + + GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node)); + + return SUCCESS; +} + +/// +/// @brief Copy Control Edges +/// @param [in] old_node +/// @param [in] new_node +/// @param [in] input_check_flag +/// @return void +/// +void SwitchOpPass::CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag) { + GE_CHECK_NOTNULL_JUST_RETURN(old_node); + GE_CHECK_NOTNULL_JUST_RETURN(new_node); + GE_IF_BOOL_EXEC(old_node == new_node, return ); + auto iter = switch_cyclic_map_.find(old_node); + bool check_flag = input_check_flag && (iter != switch_cyclic_map_.end()); + for (NodePtr &node : old_node->GetInControlNodes()) { + if (check_flag && (iter->second.count(node->GetName()) > 0)) { + for (auto &out_node : old_node->GetOutAllNodes()) { + auto out_control_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); + 
GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(out_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, out_node->GetInControlAnchor()), "Add ctl edge fail."); + }); + } + } else { + auto out_control_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); + GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(new_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, new_node->GetInControlAnchor()), "Add in ctl edge fail."); + }); + } + } + + for (NodePtr &node : old_node->GetOutControlNodes()) { + GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Add out ctl edge fail."); + }); + } +} + +/// +/// @brief Remove Control Edges +/// @param [in] node +/// @return void +/// +void SwitchOpPass::RemoveControlEdges(NodePtr &node) { + GE_CHECK_NOTNULL_JUST_RETURN(node); + for (NodePtr &in_node : node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Remove in ctl edge fail."); + } + + for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (auto &in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, in_ctrl_anchor), "Remove in ctl edge fail."); + } + } + + auto out_control_anchor = node->GetOutControlAnchor(); + GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); + for (auto &peer_anchor : out_control_anchor->GetPeerAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_control_anchor, peer_anchor), "Remove out ctl edge fail."); + } +} + +/// +/// @brief Replace Control Edges +/// @param [in] old_node +/// @param [in] new_node +/// @return void +/// +void SwitchOpPass::ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node) { + GE_IF_BOOL_EXEC(old_node == new_node, return ); + 
CopyControlEdges(old_node, new_node); + RemoveControlEdges(old_node); +} + +/// +/// @brief Mark node as head_node of stream_switch +/// @param [in] node +/// @param [in] stream_switch +/// @return void +/// +void SwitchOpPass::MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch) { + std::stack nodes; + nodes.push(node); + std::set visited; + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (visited.count(cur_node) > 0) { + continue; + } + GELOGD("branch_head_node %s of stream_switch %s", cur_node->GetName().c_str(), stream_switch->GetName().c_str()); + branch_head_nodes_[cur_node] = stream_switch; + if ((cur_node->GetType() == IDENTITY) || (cur_node->GetType() == IDENTITYN)) { + for (auto &out_node : cur_node->GetOutAllNodes()) { + nodes.push(out_node); + } + } + visited.insert(cur_node); + } +} + +/// +/// @brief Clear Status, uesd for subgraph pass +/// @return +/// +Status SwitchOpPass::ClearStatus() { + switch_nodes_.clear(); + merge_nodes_.clear(); + enter_nodes_.clear(); + switch_cyclic_map_.clear(); + bypass_nodes_.clear(); + branch_head_nodes_.clear(); + stream_switch_nodes_.clear(); + need_label_nodes_.clear(); + cond_node_map_.clear(); + switch_node_map_.clear(); + node_num_map_.clear(); + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.h b/src/ge/graph/passes/switch_op_pass.h similarity index 61% rename from src/ge/graph/passes/switch_to_stream_switch_pass.h rename to src/ge/graph/passes/switch_op_pass.h index 15fe9dce..202b919c 100644 --- a/src/ge/graph/passes/switch_to_stream_switch_pass.h +++ b/src/ge/graph/passes/switch_op_pass.h @@ -14,9 +14,15 @@ * limitations under the License. 
*/ -#ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ -#define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ - +#ifndef GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ +#define GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ + +#include +#include +#include +#include +#include +#include #include "inc/graph_pass.h" namespace ge { @@ -85,158 +91,78 @@ namespace ge { +-----------+ +-----------+ +-----------+ +-----| Less |----+ +-----------+ */ -class SwitchToStreamSwitchPass : public GraphPass { +class SwitchOpPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); - - /// - /// @brief Clear Status, used for subgraph pass - /// @return - /// Status ClearStatus() override; private: - /// - /// @brief Check cyclic dependence - /// @param [in] graph - /// @return Status - /// - Status CheckCycleDependence(const ComputeGraphPtr &graph); - - /// - /// @brief Mark cyclic dependence - /// @param [in] graph - /// @param [in] cond_switch_map - /// @return void - /// - void MarkCycleDependence(const std::unordered_map> &cond_switch_map); + Status ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node); + + Status ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_node); + + NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, + OutDataAnchorPtr &peer_cond_anchor); + + NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); + + Status CombineSwitchNode(ComputeGraphPtr &graph); + + NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); + + Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag); + + Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); - /// - /// @brief Replace Switch Op - /// @param [in] graph - /// @param [in] switch_node - /// @return Status - /// - Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr 
&switch_node); - - /// - /// @brief Bypass Switch Node - /// @param [in] switch_node - /// @param [out] peer_data_anchor - /// @param [out] peer_cond_anchor - /// @return Status - /// - Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, - OutDataAnchorPtr &peer_cond_anchor); - - /// - /// @brief Find Switch cond input - /// @param [in] pass_switch_flag - /// @param [out] peer_cond_anchor - /// @return Status - /// Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); - /// - /// @brief Create StreamSwitch Node - /// @param [in] graph - /// @param [in] switch_node - /// @param [in] suffix - /// @param [in] peer_cond_anchor - /// @return ge::NodePtr - /// - NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, - const OutDataAnchorPtr &peer_cond_anchor); - - /// - /// @brief Mark Switch Branch - /// @param [in] peer_cond_anchor - /// @param [in] stream_switch - /// @param [in] true_branch_flag - /// @return Status - /// - Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node, - bool true_branch_flag); - - /// - /// @brief Get group_id for switch_node - /// @param [in] node - /// @return group_id - /// - int64_t GetGroupId(const NodePtr &node); + Status MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &stream_switch_node, bool true_branch_flag); + + NodePtr CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &peer_cond_anchor); + + Status AddConstNode(ComputeGraphPtr &graph, NodePtr &stream_switch_node); + + Status UpdateCondBranch(NodePtr &node); + + Status UpdateAttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, + bool &net_output_flag); + + Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); + + Status UpdateEnterNode(); - /// - /// @brief Combine switch nodes link to same cond - /// @param [in] graph - /// 
@return Status - /// - Status CombineSwitchNode(const ComputeGraphPtr &graph); - - /// - /// @brief Create cast node - /// @param [in] graph - /// @param [in] peer_cond_anchor - /// @return NodePtr - /// - NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor); - - /// - /// @brief Create Active Op - /// @param [in] graph - /// @param [in] cond_node - /// @return ge::NodePtr - /// - NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); - - /// - /// @brief Add const node as switch input1 - /// @param [in] graph - /// @param [in] stream_switch - /// @return Status - /// - Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node); - - /// - /// @brief Modify in ctl edge for switch_node - /// @param [in] switch_node - /// @param [in] cast_node - /// @param [in] same_cond_switch - /// @return Status - /// - Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, - const std::set &same_cond_switch); - - /// - /// @brief Modify out ctl edge for switch_node - /// @param [in] switch_node - /// @param [in] stream_switch - /// @param [in] active_node - /// @return Status - /// - Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node); - - /// - /// @brief Check duplicate node_name - /// @param [in] node_name - /// @return std::string - /// std::string CheckDuplicateName(const std::string &node_name); - /// - /// @brief Move Control Edges - /// @param [in] old_node - /// @param [in] new_node - /// @return void - /// - void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node); + Status CheckCycleDependence(ComputeGraphPtr &graph); + + void MarkCycleDependence(const std::unordered_map> &cond_switch_map); + + Status ModifySwitchInCtlEdges(NodePtr &switch_node, NodePtr &cast_node, const std::set &same_cond_switch); + + Status ModifySwitchOutCtlEdges(NodePtr &switch_node, NodePtr &stream_switch, 
NodePtr &active_node); + + void CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag = false); + + void RemoveControlEdges(NodePtr &node); + + void ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node); + + int64_t GetGroupId(const NodePtr &node); + + void MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch); std::vector switch_nodes_; + std::vector merge_nodes_; + std::vector enter_nodes_; std::unordered_map> switch_cyclic_map_; + std::set bypass_nodes_; + std::unordered_map branch_head_nodes_; std::vector stream_switch_nodes_; + std::vector need_label_nodes_; std::unordered_map>>> cond_node_map_; std::unordered_map> switch_node_map_; std::unordered_map node_num_map_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ +#endif // GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.cc b/src/ge/graph/passes/switch_to_stream_switch_pass.cc deleted file mode 100644 index ef8879dd..00000000 --- a/src/ge/graph/passes/switch_to_stream_switch_pass.cc +++ /dev/null @@ -1,755 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "graph/passes/switch_to_stream_switch_pass.h" -#include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/ge_context.h" -#include "graph/utils/type_utils.h" - -namespace ge { -Status SwitchToStreamSwitchPass::Run(ComputeGraphPtr graph) { - GELOGD("SwitchToStreamSwitchPass Enter"); - - GE_CHK_STATUS_RET(CheckCycleDependence(graph), "Check cyclic dependence failed."); - for (const auto &switch_node : switch_nodes_) { - GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Replace Switch by StreamSwitch failed."); - } - GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes failed."); - - for (const auto &node : bypass_nodes_) { - GE_CHK_BOOL_EXEC(graph->IsolateNode(node) == GRAPH_SUCCESS, return FAILED, "Isolate node failed."); - GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS, return FAILED, - "Remove switch node failed."); - } - - GELOGD("SwitchToStreamSwitchPass Leave"); - return SUCCESS; -} - -/// -/// @brief Clear Status, used for subgraph pass -/// @return -/// -Status SwitchToStreamSwitchPass::ClearStatus() { - switch_nodes_.clear(); - switch_cyclic_map_.clear(); - bypass_nodes_.clear(); - stream_switch_nodes_.clear(); - cond_node_map_.clear(); - switch_node_map_.clear(); - node_num_map_.clear(); - return SUCCESS; -} - -/// -/// @brief Check cyclic dependence -/// @param [in] graph -/// @return Status -/// -Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &graph) { - std::string type; - std::unordered_map> cond_switch_map; - for (const NodePtr &node : graph->GetDirectNode()) { - GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); - if ((type == SWITCH) || (type == 
REFSWITCH)) { - InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); - GE_CHECK_NOTNULL(in_cond_anchor); - OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); - return FAILED; - } - - NodePtr cond_node = peer_out_anchor->GetOwnerNode(); - auto iter = cond_switch_map.find(cond_node); - if (iter == cond_switch_map.end()) { - cond_switch_map[cond_node] = {node}; - } else { - iter->second.emplace_back(node); - } - switch_nodes_.emplace_back(node); - } - } - - MarkCycleDependence(cond_switch_map); - return SUCCESS; -} - -/// -/// @brief Mark cyclic dependence -/// @param [in] graph -/// @param [in] cond_switch_map -/// @return void -/// -void SwitchToStreamSwitchPass::MarkCycleDependence( - const std::unordered_map> &cond_switch_map) { - std::stack out_nodes; - NodePtr tmp_node = nullptr; - std::unordered_set visited; - for (const auto &iter : cond_switch_map) { - std::set switch_nodes(iter.second.begin(), iter.second.end()); - for (const auto &switch_node : switch_nodes) { - GELOGD("MarkCycleDependence: cond_node=%s, switch=%s.", iter.first->GetName().c_str(), - switch_node->GetName().c_str()); - for (const auto &node : switch_node->GetOutAllNodes()) { - out_nodes.push(node); - } - } - visited.clear(); - while (!out_nodes.empty()) { - tmp_node = out_nodes.top(); - out_nodes.pop(); - if (visited.count(tmp_node) > 0) { - continue; - } - GELOGD("MarkCycleDependence: tmp_node=%s.", tmp_node->GetName().c_str()); - for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) { - if (switch_nodes.find(out_node) == switch_nodes.end()) { - out_nodes.push(out_node); - continue; - } - GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GELOGW("set cyclic dependence attr failed."); - return ); - auto map_iter = switch_cyclic_map_.find(out_node); - 
if (map_iter == switch_cyclic_map_.end()) { - switch_cyclic_map_[out_node] = {tmp_node->GetName()}; - } else { - map_iter->second.insert(tmp_node->GetName()); - } - } - visited.insert(tmp_node); - } - } - - return; -} - -/// -/// @brief Replace Switch Op -/// @param [in] graph -/// @param [in] switch_node -/// @return Status -/// -Status SwitchToStreamSwitchPass::ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node) { - OutDataAnchorPtr peer_data_anchor = nullptr; - OutDataAnchorPtr peer_cond_anchor = nullptr; - GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED, - "Bypass switch node %s failed.", switch_node->GetName().c_str()); - GE_CHECK_NOTNULL(peer_data_anchor); - GE_CHECK_NOTNULL(peer_cond_anchor); - OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(cond_desc); - DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType(); - GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL, return FAILED, - "pred_input of Switch only support DT_BOOL data_type, but %s exactly.", - TypeUtils::DataTypeToSerialString(cond_data_type).c_str()); - - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG); - std::set out_node_list; - for (const auto &out_data_anchor : switch_node->GetAllOutDataAnchors()) { - bool true_branch_flag = (static_cast(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT); - NodePtr stream_switch = nullptr; - out_node_list.clear(); - for (const auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) { - GE_IF_BOOL_EXEC(stream_switch == nullptr, { - stream_switch = CreateStreamSwitchNode(graph, switch_node, true_branch_flag ? 
"_t" : "_f", peer_cond_anchor); - GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node failed."); - if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) { - GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s failed.", stream_switch->GetName().c_str()); - return FAILED; - } - if (MarkBranches(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) { - GELOGE(FAILED, "Mark branches for stream_switch %s failed.", stream_switch->GetName().c_str()); - return FAILED; - } - - if (!cyclic_flag) { - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(), - stream_switch->GetInControlAnchor()), - "StreamSwitch node add ctl edge failed."); - } - }); - - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output failed."); - - NodePtr out_node = peer_in_anchor->GetOwnerNode(); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()), - "StreamSwitch node add ctl edge failed."); - out_node_list.insert(out_node->GetName()); - } - - GE_IF_BOOL_EXEC(stream_switch != nullptr, { - MoveCtrlEdges(switch_node, stream_switch); - switch_node_map_[stream_switch] = out_node_list; - if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) { - GELOGE(FAILED, "SetOriginalNodeName for node %s failed.", stream_switch->GetName().c_str()); - return FAILED; - } - }); - } - - (void)bypass_nodes_.insert(switch_node); - return SUCCESS; -} - -/// -/// @brief Bypass Switch Node -/// @param [in] switch_node -/// @param [out] peer_data_anchor -/// @param [out] peer_cond_anchor -/// @return Status -/// -Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, - OutDataAnchorPtr &peer_cond_anchor) { - for (uint32_t idx = 0; idx < 
SWITCH_INPUT_NUM; ++idx) { - InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); - GE_CHECK_NOTNULL(in_data_anchor); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - // Remove Switch data input. - if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove data edge %s->%s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), - switch_node->GetName().c_str()); - return FAILED; - } - - if (idx == SWITCH_DATA_INPUT) { - peer_data_anchor = peer_out_anchor; - } else { - if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); - return FAILED; - } - peer_cond_anchor = peer_out_anchor; - } - } - - return SUCCESS; -} - -/// -/// @brief Find Switch cond input -/// @param [in] pass_switch_flag -/// @param [out] peer_cond_anchor -/// @return Status -/// -Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { - NodePtr tmp_node = nullptr; - string type; - bool need_pass_type = true; - while (need_pass_type) { - if (tmp_node == nullptr) { - tmp_node = peer_cond_anchor->GetOwnerNode(); - } else { - InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT); - GE_CHECK_NOTNULL(in_data_anchor); - peer_cond_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_cond_anchor); - tmp_node = peer_cond_anchor->GetOwnerNode(); - } - - GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); - need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); - } - - return SUCCESS; -} - -/// -/// @brief Create StreamSwitch Node -/// @param [in] graph -/// @param [in] switch_node -/// @param [in] suffix -/// @param [in] peer_cond_anchor -/// @return ge::NodePtr -/// -NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const 
ComputeGraphPtr &graph, const NodePtr &switch_node, - const std::string &suffix, - const OutDataAnchorPtr &peer_cond_anchor) { - OpDescPtr switch_op_desc = switch_node->GetOpDesc(); - GE_CHK_BOOL_EXEC(switch_op_desc != nullptr, return nullptr, "OpDesc of Switch node is invalid."); - GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, { - GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u.", switch_op_desc->GetInputsSize(), - SWITCH_INPUT_NUM); - return nullptr; - }); - - const std::string &node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix; - GELOGI("Create StreamSwitch, name=%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMSWITCH); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, StreamSwitch:%s.", node_name.c_str()); - return nullptr; - } - - // mark hccl group id - std::string hccl_group_id; - if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { - (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); - GELOGD("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch %s, value is %s.", node_name.c_str(), - hccl_group_id.c_str()); - } - - if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || - !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { - GELOGE(INTERNAL_ERROR, "set int failed"); - return nullptr; - } - - // Already checked, first input is Variable will passed, second is condition will checked. 
- GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT); - GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32); - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, - "Create StreamSwitch node: add input desc failed."); - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, - "Create StreamSwitch node: add input desc failed."); - - NodePtr stream_switch = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(stream_switch != nullptr, return nullptr, "Insert StreamSwitch node failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), - "StreamSwitch node add cond edge failed."); - - return stream_switch; -} - -/// -/// @brief Mark Switch Branch -/// @param [in] peer_cond_anchor -/// @param [in] stream_switch -/// @param [in] true_branch_flag -/// @return Status -/// -Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch, - bool true_branch_flag) { - uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT; - auto it = cond_node_map_.find(peer_cond_anchor); - if (it != cond_node_map_.end()) { - int64_t switch_group_id = GetGroupId(stream_switch); - auto switch_group_it = it->second.find(switch_group_id); - if (switch_group_it == it->second.end()) { - std::list false_node_list; - std::list true_node_list; - std::list &node_list = true_branch_flag ? 
true_node_list : false_node_list; - node_list.emplace_back(stream_switch); - std::vector> switch_list; - switch_list.emplace_back(false_node_list); - switch_list.emplace_back(true_node_list); - it->second[switch_group_id] = switch_list; - } else { - GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, { - GELOGE(INTERNAL_ERROR, "Check size failed, node: %s", stream_switch->GetName().c_str()); - return FAILED; - }); - switch_group_it->second[index].emplace_back(stream_switch); - } - } else { - int64_t switch_group_id = GetGroupId(stream_switch); - map>> switch_group_map; - std::list false_node_list; - std::list true_node_list; - std::list &node_list = true_branch_flag ? true_node_list : false_node_list; - node_list.emplace_back(stream_switch); - std::vector> switch_list; - switch_list.emplace_back(false_node_list); - switch_list.emplace_back(true_node_list); - switch_group_map[switch_group_id] = switch_list; - cond_node_map_[peer_cond_anchor] = switch_group_map; - } - return SUCCESS; -} - -/// -/// @brief Get group_id for switch_node -/// @param [in] node -/// @return group_id -/// -int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { - string tailing_optimization_option; - bool is_tailing_optimization = false; - if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { - // "1" means it's True from frontend option - is_tailing_optimization = (tailing_optimization_option == "1"); - GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str()); - } - if (!is_tailing_optimization) { - return 0; - } - - string hccl_group_id; - if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { - GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); - return 0; - } - auto key_index = hccl_group_id.find_last_of('_'); - auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index); - 
GELOGI("Node:%s, hccl_group_id=%s, key_num=%s", node->GetName().c_str(), hccl_group_id.c_str(), key_num.c_str()); - int64_t num = atoi(key_num.c_str()); - if (num == 0) { - return 0; - } - - GELOGI("Hccl_group_id is %s, group_id is %ld", hccl_group_id.c_str(), num); - return num; -} - -/// -/// @brief Combine switch nodes link to same cond -/// @param [in] graph -/// @return Status -/// -Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { - for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { - for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { - std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; - std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; - std::set same_cond_switch; - same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); - same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); - - OutDataAnchorPtr peer_cond_anchor = iter->first; - NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); - GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); - - NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor); - GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node failed."); - - NodePtr active_node = CreateActiveNode(graph, cond_node); - GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), - "StreamActive add ctl edge failed."); - if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) { - GELOGE(FAILED, "Set active_label_list attr for node %s failed.", active_node->GetName().c_str()); - return FAILED; - } - - const std::string &cond_group = cond_node->GetName(); - for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { - bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); - std::list 
&switch_list = (true_branch_flag ? true_switch_list : false_switch_list); - GE_IF_BOOL_EXEC(switch_list.empty(), continue); - - // select first stream_switch - NodePtr stream_switch = switch_list.front(); - OpDescPtr switch_desc = stream_switch->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"))); - stream_switch_nodes_.emplace_back(stream_switch); - - // 0_input: original pred input, 1_input: constant node - GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node failed."); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), - "StreamSwitch remove data edge failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), - "Cast add data edge failed."); - - for (const NodePtr &node : switch_list) { - GE_IF_BOOL_EXEC(node != stream_switch, { - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), - "StreamSwitch remove data edge failed."); - }); - GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges failed."); - GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges failed."); - } - - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), - "StreamActive add ctl edge failed."); - } - } - } - return SUCCESS; -} - -/// -/// @brief Create Active Op -/// @param [in] graph -/// @param [in] cond_node -/// @return ge::NodePtr -/// -NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) { - const std::string &node_name = CheckDuplicateName(node->GetName() + "_" + STREAMACTIVE); - GELOGI("Create StreamActive op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc 
failed, StreamActive:%s.", node_name.c_str()); - return nullptr; - } - - NodePtr active_node = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node failed."); - - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, - GELOGE(INTERNAL_ERROR, "add edge failed"); - return nullptr); - - GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, - GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); - return nullptr); - - return active_node; -} - -/// -/// @brief Create cast node -/// @param [in] graph -/// @param [in] peer_cond_anchor -/// @return NodePtr -/// -NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor) { - OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc failed."); - - const std::string &cast_name = CheckDuplicateName(cond_desc->GetName() + "_" + CAST); - GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str()); - OpDescPtr cast_desc = MakeShared(cast_name, CAST); - if (cast_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, Cast:%s.", cast_name.c_str()); - return nullptr; - } - if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) && - AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) && - AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) && - AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { - GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE failed, node: %s.", - cast_name.c_str()); - return nullptr; - } - - GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()); - tensor_desc.SetDataType(DT_BOOL); - GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS, return nullptr, 
"Cast_node add input desc failed."); - tensor_desc.SetDataType(DT_INT32); - GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS, return nullptr, - "Cast_node add output desc failed."); - - NodePtr cast_node = graph->AddNode(cast_desc); - GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); - - return cast_node; -} - -/// -/// @brief Add const node as switch input1 -/// @param [in] graph -/// @param [in] stream_switch -/// @return Status -/// -Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch) { - OpDescPtr op_desc = stream_switch->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - bool value = false; - GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, - "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); - - const std::string &const_node_name = op_desc->GetName() + "_Constant_" + (value ? 
"t" : "f"); - GELOGI("Create const op: %s", const_node_name.c_str()); - OpDescPtr const_op_desc = MakeShared(const_node_name, CONSTANT); - if (const_op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc failed, Constant:%s.", const_node_name.c_str()); - return FAILED; - } - - auto resize_value = (int32_t)value; - GeTensorDesc data_desc = op_desc->GetInputDesc(1); - GeTensorPtr const_value = - MakeShared(data_desc, reinterpret_cast(&resize_value), sizeof(int32_t)); - if (const_value == nullptr) { - GELOGE(FAILED, "Create tensor failed."); - return FAILED; - } - GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value), return FAILED); - GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS, return FAILED, - "Create Const op: add output desc failed."); - - NodePtr const_node = graph->AddNode(const_op_desc); - GE_CHK_BOOL_EXEC(const_node != nullptr, return FAILED, "Insert Const node failed."); - GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)), - "StreamSwitch node add ctl edge failed."); - - return SUCCESS; -} - -/// -/// @brief Modify in ctl edge for switch_node -/// @param [in] switch_node -/// @param [in] cast_node -/// @param [in] same_cond_switch -/// @return Status -/// -Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, - const std::set &same_cond_switch) { - GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), - cast_node->GetName().c_str()); - std::string orig_switch_name = switch_node->GetName(); - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) { - GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s", switch_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - - for (const NodePtr 
&in_ctl_node : switch_node->GetInControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), - "Remove ctl edge failed."); - GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), - "Add ctl edge failed."); - }); - - GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); - if (same_cond_switch.count(in_ctl_node) > 0) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), - "Remove ctl edge failed."); - continue; - } - - auto find_res1 = switch_node_map_.find(in_ctl_node); - GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { - GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); - return INTERNAL_ERROR; - }); - auto find_res2 = find_res1->second.find(orig_switch_name); - auto find_res3 = find_res1->second.find(cast_node->GetName()); - GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), { - find_res1->second.erase(find_res2); - find_res1->second.insert(cast_node->GetName()); - continue; - }); - } - - return SUCCESS; -} - -/// -/// @brief Modify out ctl edge for switch_node -/// @param [in] switch_node -/// @param [in] stream_switch -/// @param [in] active_node -/// @return Status -/// -Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, - const NodePtr &active_node) { - GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), - stream_switch->GetName().c_str(), active_node->GetName().c_str()); - auto find_res = switch_node_map_.find(switch_node); - GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { - GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found 
in switch_node_map_.", switch_node->GetName().c_str()); - return INTERNAL_ERROR; - }); - GE_IF_BOOL_EXEC(find_res->second.empty(), { - GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str()); - return INTERNAL_ERROR; - }); - - for (const NodePtr &node : switch_node->GetOutControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "Remove ctl edge failed."); - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::string orig_name = op_desc->GetName(); - GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), { - if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) { - GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s.", op_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - }); - if (find_res->second.find(orig_name) == find_res->second.end()) { - auto active_out_ctrl_anchor = active_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(active_out_ctrl_anchor); - GE_IF_BOOL_EXEC(!active_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(active_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed."); - }); - } else { - auto switch_out_ctrl_anchor = stream_switch->GetOutControlAnchor(); - GE_CHECK_NOTNULL(switch_out_ctrl_anchor); - GE_IF_BOOL_EXEC(!switch_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(switch_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed."); - }); - } - } - - GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node)); - return SUCCESS; -} - -/// -/// @brief Check duplicate node_name -/// @param [in] node_name -/// @return std::string -/// -std::string SwitchToStreamSwitchPass::CheckDuplicateName(const std::string &node_name) { - std::string tmp_name = node_name; - auto iter = 
node_num_map_.find(tmp_name); - if (iter != node_num_map_.end()) { - tmp_name = tmp_name + "_" + std::to_string(iter->second); - (iter->second)++; - } else { - node_num_map_[tmp_name] = 1; - } - return tmp_name; -} - -/// -/// @brief Move Control Edges -/// @param [in] old_node -/// @param [in] new_node -/// @return void -/// -void SwitchToStreamSwitchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) { - GE_IF_BOOL_EXEC(old_node == new_node, return ); - auto iter = switch_cyclic_map_.find(old_node); - bool check_flag = (iter != switch_cyclic_map_.end()); - for (const NodePtr &in_node : old_node->GetInControlNodes()) { - auto out_ctrl_anchor = in_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL_JUST_RETURN(out_ctrl_anchor); - if (check_flag && (iter->second.count(in_node->GetName()) > 0)) { - for (const auto &out_node : old_node->GetOutAllNodes()) { - GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(out_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, out_node->GetInControlAnchor()), - "Add in ctrl edge failed."); - }); - } - } else { - GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(new_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, new_node->GetInControlAnchor()), "Add in ctrl edge failed."); - }); - } - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_ctrl_anchor, old_node->GetInControlAnchor()), - "Remove in ctrl edge failed."); - } - - for (const NodePtr &out_node : old_node->GetOutControlNodes()) { - GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), - "Add out ctrl edge failed."); - }); - GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), - "Remove out ctrl edge failed."); - } -} -} // namespace ge diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.cc 
b/src/ge/graph/passes/transop_breadth_fusion_pass.cc index d8df4a22..53f9e825 100644 --- a/src/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_breadth_fusion_pass.cc @@ -19,12 +19,14 @@ #include #include +#include "framework/common/debug/ge_log.h" #include "common/types.h" #include "graph/common/transop_util.h" #include "graph/utils/node_utils.h" namespace ge { Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { + GE_TIMESTAMP_START(TransOpBreadthFusionPass); if (graph == nullptr) { return SUCCESS; } @@ -45,6 +47,7 @@ Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { } } } + GE_TIMESTAMP_END(TransOpBreadthFusionPass, "GraphManager::TransOpBreadthFusionPass"); return SUCCESS; } diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc index afeca3c4..c0c854b6 100644 --- a/src/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_depth_fusion_pass.cc @@ -17,6 +17,7 @@ #include "graph/passes/transop_depth_fusion_pass.h" #include +#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" #include "graph/compute_graph.h" @@ -28,6 +29,7 @@ namespace ge { graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { + GE_TIMESTAMP_START(TransOpDepthFusionPass); GELOGI("[TransOpDepthFusionPass]: optimize in depth begin..."); if (graph == nullptr) { return GRAPH_SUCCESS; @@ -51,6 +53,7 @@ graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { } } GELOGI("[TransOpDepthFusionPass]: Optimize in depth success..."); + GE_TIMESTAMP_END(TransOpDepthFusionPass, "GraphManager::TransOpDepthFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc index 2ff7cd82..38b6684b 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc +++ 
b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -24,6 +24,7 @@ namespace { const int kTransOpOutIndex = 0; static std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; + } // namespace namespace ge { Status TransOpSymmetryEliminationPass::Run(NodePtr &node) { diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc index 1d97d9a1..ba4cd031 100644 --- a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -22,6 +22,7 @@ #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" +#include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" @@ -732,6 +733,7 @@ void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &g } graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { + GE_TIMESTAMP_START(TransOpWithoutReshapeFusionPass); GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin."); if (graph == nullptr) { return GRAPH_SUCCESS; @@ -784,6 +786,7 @@ graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { } } GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end."); + GE_TIMESTAMP_END(TransOpWithoutReshapeFusionPass, "GraphManager::TransOpWithoutReshapeFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc index 8c34cd36..175a049a 100644 --- a/src/ge/graph/passes/variable_op_pass.cc +++ b/src/ge/graph/passes/variable_op_pass.cc @@ -20,6 +20,7 @@ #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" +#include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "graph/graph.h" #include "graph/manager/graph_var_manager.h" @@ -114,6 +115,7 @@ bool IsTransSupport(const 
TransNodeInfo &trans_info) { } // namespace Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { + GE_TIMESTAMP_START(VariableOpPass); if (graph == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph"); return INTERNAL_ERROR; @@ -188,15 +190,9 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) { return GE_GRAPH_VARIABLE_OP_PASS_FAILED; } - - // renew var desc if the trans_road is all reshape or reformat - ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road); - if (ret != SUCCESS) { - GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str()); - return FAILED; - } } + GE_TIMESTAMP_END(VariableOpPass, "GraphManager::VariableOpPass"); return SUCCESS; } @@ -608,28 +604,4 @@ Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) { } return SUCCESS; } - -Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) { - // renew var desc if the trans_road is all reshape or reformat - for (auto &road : fusion_road) { - if (road.node_type != RESHAPE && road.node_type != REFORMAT) { - return SUCCESS; - } - } - - if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) { - GELOGD("var manager does not exist var node[%s]", node->GetName().c_str()); - return SUCCESS; - } - GELOGD("var manager exist var node[%s]", node->GetName().c_str()); - GE_CHECK_NOTNULL(node->GetOpDesc()); - Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc()); - if (ret != SUCCESS) { - GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str()); - return FAILED; - } - - return SUCCESS; -} - } // namespace ge diff --git a/src/ge/graph/passes/variable_op_pass.h b/src/ge/graph/passes/variable_op_pass.h index e17980e9..4e194a0c 100644 --- a/src/ge/graph/passes/variable_op_pass.h +++ b/src/ge/graph/passes/variable_op_pass.h @@ 
-66,7 +66,6 @@ class VariableOpPass : public GraphPass { Status UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set &nodes); Status RenewVarDesc(ge::ComputeGraphPtr &graph); - Status RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road); std::map> var_and_var_ref_map_; diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc index d93e1003..4db78a46 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.cc +++ b/src/ge/graph/passes/variable_prepare_op_pass.cc @@ -30,7 +30,6 @@ namespace ge { std::map> VariablePrepareOpPass::ref_node_without_prototype_map_{ {REFSWITCH, {{0, 0}, {0, 1}}}}; - Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { @@ -63,6 +62,7 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GELOGI("{ %d : %d }", index_iter->first, index_iter->second); } } + return SUCCESS; } @@ -73,13 +73,10 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { GE_CHECK_NOTNULL(dst_node); InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; GE_CHECK_NOTNULL(dst_in_data_anchor); - auto input_index = dst_in_data_anchor->GetIdx(); - int out_index = GetWritableNodeOutIndex(dst_node, input_index); + int out_index = GetWritableNodeOutIndex(dst_node, dst_in_data_anchor->GetIdx()); if (out_index >= 0) { - Status ret = DealWritableNode(dst_node, input_index, var_node); + Status ret = DealWritableNode(dst_node, var_node, out_index); if (ret != SUCCESS) { - GELOGE(FAILED, "Deal writable node[%s] failed, input index: %d, var: %s.", dst_node->GetName().c_str(), - input_index, var_node->GetName().c_str()); return FAILED; } } @@ -87,97 +84,84 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { return SUCCESS; } -Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, int input_index, - const ge::NodePtr &var_node) { - 
// Find the last ref node: - // If the ref input has corresponding output, add variable ref after it. - // If the ref input has no corresponding output, insert RefIdentity and variable ref before it. - // If ref node with control output was found while finding the last ref node, add variable ref after it. - std::stack> nodes_to_check; - nodes_to_check.push({writable_node, input_index}); - while (!nodes_to_check.empty()) { - auto node_index = nodes_to_check.top(); - nodes_to_check.pop(); - auto cur_node = node_index.first; - int cur_input_index = node_index.second; - // Collect ref node after cur node - const auto nodes_size = nodes_to_check.size(); - // Add peer ref output node of current node to stack - CHECK_FALSE_EXEC(GetPeerNodeOfRefInput(cur_node, cur_input_index, nodes_to_check) == SUCCESS, - GELOGE(FAILED, "GetPeerNodeOfRefInput for node[%s] failed.", cur_node->GetName().c_str()); - return FAILED); - auto output_index = GetWritableNodeOutIndex(cur_node, cur_input_index); - CHECK_FALSE_EXEC(output_index >= 0, - GELOGE(FAILED, "Get writable node[%s] ref input[%d]'s corresponding out index failed: %d.", - cur_node->GetName().c_str(), cur_input_index, output_index); - return FAILED); - if (nodes_size == nodes_to_check.size()) { - const auto &op_desc = cur_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - // No need to add variable_ref for frameworkop - if (op_desc->GetType() == FRAMEWORKOP) { - GELOGD("No need to add variable_ref for frameworkop"); - continue; - } - if (static_cast(output_index) < op_desc->GetOutputsSize()) { - // Add variable ref node after ref output for final ref node - CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, - GELOGE(FAILED, "Add variable ref failed"); - return FAILED); - } else { - // Insert variable ref node before ref input without corresponding ref output - CHECK_FALSE_EXEC(InsertVariableRef(cur_node, cur_input_index, var_node) == SUCCESS, - GELOGE(FAILED, "Insert variable ref and ref identity 
failed"); - return FAILED); - } - continue; +Status VariablePrepareOpPass::DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index) { + GE_CHECK_NOTNULL(writable_node); + GE_CHECK_NOTNULL(var_node); + NodePtr final_writable_node = writable_node; + bool is_have_peer_node = false; + for (auto &dst_node_and_inanchor : writable_node->GetOutDataNodesAndAnchors()) { + NodePtr dst_node = dst_node_and_inanchor.first; + GE_CHECK_NOTNULL(dst_node); + InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; + GE_CHECK_NOTNULL(dst_in_data_anchor); + is_have_peer_node = true; + int current_out_index = GetWritableNodeOutIndex(dst_node, dst_in_data_anchor->GetIdx()); + if (current_out_index >= 0) { + final_writable_node = GetFinalWritableNode(dst_node, current_out_index); + out_index = current_out_index; } - if (HasControlOut(cur_node)) { - // Add variable ref node after ref output for ref node has control output. - CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, - GELOGE(FAILED, "Add variable ref failed"); - return FAILED); + + GE_CHECK_NOTNULL(final_writable_node); + Status ret = AddVariableRef(final_writable_node, var_node, out_index); + if (ret != SUCCESS) { + GELOGE(FAILED, "add variable ref failed"); + return FAILED; + } + } + if (final_writable_node->GetName() == writable_node->GetName() && !is_have_peer_node) { + Status ret = AddVariableRef(final_writable_node, var_node, out_index); + if (ret != SUCCESS) { + return FAILED; } } return SUCCESS; } -Status VariablePrepareOpPass::GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, - std::stack> &nodes) { - auto output_index = GetWritableNodeOutIndex(node, input_index); - if (output_index == -1) { - GELOGE(PARAM_INVALID, "Node[%s] is not a ref node.", node->GetName().c_str()); - return PARAM_INVALID; - } - const auto &op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (static_cast(output_index) == op_desc->GetOutputsSize()) { - return 
SUCCESS; - } - if (output_index >= static_cast(node->GetAllOutDataAnchorsSize())) { - GELOGW("Can not get %d th output anchor of %s", output_index, node->GetName().c_str()); - return SUCCESS; - } - const auto &out_anchor = node->GetOutDataAnchor(output_index); - GE_CHECK_NOTNULL(out_anchor); - for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { - auto peer_node = peer_in_anchor->GetOwnerNode(); - if (peer_node == nullptr) { - continue; +NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index) { + NodePtr current_node = writable_node; + std::unordered_set seen_node; + while (true) { + if (seen_node.count(current_node.get())) { + GELOGE(FAILED, "There is a ring structure in the graph"); + return nullptr; + } + seen_node.insert(current_node.get()); + OutDataAnchorPtr out_anchor = current_node->GetOutDataAnchor(out_index); + if (out_anchor == nullptr) { + GELOGE(FAILED, "Failed to get data anchor by index %d", out_index); + return nullptr; + } + bool found_writeable_node = false; + auto peer_in_anchors = out_anchor->GetPeerInDataAnchors(); + for (auto &peer_in_anchor : peer_in_anchors) { + if (peer_in_anchor == nullptr) { + GELOGE(FAILED, "peer in data anchor is nullptr, node %s:%s", current_node->GetType().c_str(), + current_node->GetName().c_str()); + continue; + } + + NodePtr peer_node = peer_in_anchor->GetOwnerNode(); + int current_out_index = GetWritableNodeOutIndex(peer_node, peer_in_anchor->GetIdx()); + if (current_out_index >= 0) { + current_node = peer_node; + out_index = current_out_index; + found_writeable_node = true; + break; + } } - const int peer_in_index = peer_in_anchor->GetIdx(); - if (GetWritableNodeOutIndex(peer_node, peer_in_index) != -1) { - nodes.push({peer_node, peer_in_index}); + if (!found_writeable_node) { + GELOGD("final writable node is %s", current_node->GetName().c_str()); + return current_node; } } - return SUCCESS; } -Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr 
&final_writable_node, const ge::NodePtr &var_node, int index) { +Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node, int index) { GE_CHECK_NOTNULL(final_writable_node); GE_CHECK_NOTNULL(var_node); - if (index >= static_cast(final_writable_node->GetAllOutDataAnchorsSize())) { - GELOGW("Can not get %d th output anchor of %s", index, final_writable_node->GetName().c_str()); + + if (final_writable_node->GetType() == FRAMEWORKOP) { + GELOGD("No need to add variable_ref for frameworkop"); return SUCCESS; } // Check for duplicate creation @@ -197,8 +181,7 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, c // creat variable_ref std::stringstream variable_ref_name; variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index; - NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); - GE_CHECK_NOTNULL(variable_ref_node); + NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); Status ret_check = CheckStreamLabel(variable_ref_node, final_writable_node); if (ret_check != SUCCESS) { GELOGE(FAILED, "check stream lable failed"); @@ -206,12 +189,23 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, c } GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str()); - // add control anchor between variable_ref and final peer node + GE_CHECK_NOTNULL(variable_ref_node); + // add control anchor between variable_ref and final peer node // variable_ref_node need to execute before other nodes - CHECK_FALSE_EXEC(AddControlEdge(final_writable_node, variable_ref_node) == SUCCESS, - GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed"); - return FAILED); - + auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); + for (auto &final_writable_outAnchor 
: final_writable_outAnchors) { + GE_CHECK_NOTNULL(final_writable_outAnchor); + for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { + GE_CHECK_NOTNULL(final_writable_peerAnchor); + NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); + graphStatus ret = + ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed"); + return FAILED; + } + } + } graphStatus ret = ge::GraphUtils::AddEdge(out_anchor, variable_ref_node->GetInDataAnchor(0)); if (ret != GRAPH_SUCCESS) { GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed"); @@ -220,110 +214,7 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, c return SUCCESS; } -Status VariablePrepareOpPass::InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(var_node); - // Check connection between two nodes - const auto in_anchor = node->GetInDataAnchor(in_index); - GE_CHECK_NOTNULL(in_anchor); - auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - auto peer_in_node = peer_out_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_in_node); - - // Create ref_identity - std::stringstream ref_identity_name; - ref_identity_name << "RefIdentity_" << peer_in_node->GetName() << "_" << peer_out_anchor->GetIdx() << "_TO_" - << node->GetName() << "_" << in_index; - NodePtr ref_identity_node = CreateRefIdentity(ref_identity_name.str(), node, static_cast(in_index)); - GE_CHECK_NOTNULL(ref_identity_node); - - // Create variable_ref - std::stringstream variable_ref_name; - variable_ref_name << "_TO_" << node->GetName() << "_REF_" << in_index; - NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); - 
GE_CHECK_NOTNULL(variable_ref_node); - Status ret_check = CheckStreamLabel(variable_ref_node, node); - if (ret_check != SUCCESS) { - GELOGE(FAILED, "check stream lable failed"); - return FAILED; - } - - GELOGI("Insert variable_ref of [%s] between [%s] and [%s]", var_node->GetName().c_str(), - peer_in_node->GetName().c_str(), node->GetName().c_str()); - // add control anchor between variable_ref and node - // variable_ref_node need to execute before other nodes - CHECK_FALSE_EXEC(AddControlEdge(node, variable_ref_node) == SUCCESS, - GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed"); - return FAILED); - - // Insert variable ref node between two nodes and remove the original edge. - CHECK_FALSE_EXEC(ge::GraphUtils::RemoveEdge(peer_out_anchor, in_anchor) == SUCCESS, - GELOGE(FAILED, "Remove edge between ref node and its peer node failed"); - return FAILED); - CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(peer_out_anchor, ref_identity_node->GetInDataAnchor(0)) == SUCCESS, - GELOGE(FAILED, "Add data edge between pre node and ref_identity failed"); - return FAILED); - CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), in_anchor) == SUCCESS, - GELOGE(FAILED, "Add data edge between ref_identity and ref node failed"); - return FAILED); - - // Add edge from ref identity node to variable ref node. 
- CHECK_FALSE_EXEC( - ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), variable_ref_node->GetInDataAnchor(0)) == SUCCESS, - GELOGE(FAILED, "Add data edge between ref_identity and variable_ref failed"); - return FAILED); - CHECK_FALSE_EXEC( - ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), variable_ref_node->GetInControlAnchor()) == SUCCESS, - GELOGE(FAILED, "Add control edge between ref_identity and variable_ref failed"); - return FAILED); - return SUCCESS; -} - -Status VariablePrepareOpPass::AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node) { - auto out_anchors = node->GetAllOutAnchors(); - for (auto &out_anchor : out_anchors) { - GE_CHECK_NOTNULL(out_anchor); - for (auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { - GE_CHECK_NOTNULL(peer_in_anchor); - NodePtr peer_node = peer_in_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(peer_node); - CHECK_FALSE_EXEC( - ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()) == SUCCESS, - GELOGE(FAILED, "Add control edge between variable_ref and ref node's peer node failed"); - return FAILED); - } - } - return SUCCESS; -} - -ge::NodePtr VariablePrepareOpPass::CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, - uint32_t input_index) { - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(FAILED, "opdesc is nullptr"); - return nullptr; - } - - OpDescPtr ref_identity_op_desc = MakeShared(ref_identity_name.c_str(), REFIDENTITY); - if (ref_identity_op_desc == nullptr) { - GELOGE(FAILED, "ref_identity op desc is nullptr"); - return nullptr; - } - - GE_IF_BOOL_EXEC(ref_identity_op_desc->AddOutputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS, - GELOGW("add output desc edge failed"); - return nullptr); - GE_IF_BOOL_EXEC(ref_identity_op_desc->AddInputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS, - GELOGW("add input desc edge failed"); - return nullptr); - NodePtr 
ref_identity_node = node->GetOwnerComputeGraph()->AddNode(ref_identity_op_desc); - GE_IF_BOOL_EXEC(ref_identity_node == nullptr, GELOGW("ref_identity_node is null"); return nullptr); - return ref_identity_node; -} - -ge::NodePtr VariablePrepareOpPass::CreateVariableRef(const std::string &variable_ref_name, - const ge::NodePtr &var_node) { +ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) { OpDescPtr var_op_desc = var_node->GetOpDesc(); if (var_op_desc == nullptr) { GELOGE(FAILED, "get var opdesc is nullptr"); @@ -359,6 +250,7 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu } GELOGD("get writable node and input index %s:%d", node->GetName().c_str(), input_index); auto node_type = node->GetType(); + if (node_type == FRAMEWORKOP) { std::string original_type; GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); @@ -374,17 +266,25 @@ void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node GELOGW("op_desc in null, please check node:[%s]", node->GetName().c_str()); return; } - for (const auto &name_index : op_desc->GetAllInputName()) { - // Record the index of output with the same name as input, thinking of them as a pair of ref input and output. - const int out_index = op_desc->GetOutputIndexByName(name_index.first); - if (out_index != -1) { - ref_input_output_map_[node->GetType()][name_index.second] = out_index; + for (const auto &out_ancohor : node->GetAllOutDataAnchors()) { + int output_index = out_ancohor->GetIdx(); + string output_name = op_desc->GetOutputNameByIndex(output_index); + GELOGD("output name:[%s]", output_name.c_str()); + + int input_index = op_desc->GetInputIndexByName(output_name); + if (input_index == -1) { continue; } - // Record the ref input without corresponding output. 
- const auto &input_desc = op_desc->GetInputDesc(name_index.second); - if (!input_desc.GetRefPortIndex().empty()) { - ref_input_output_map_[node->GetType()][name_index.second] = static_cast(op_desc->GetOutputsSize()); + auto ref_type_and_input_output_iter = ref_input_output_map_.find(node->GetType()); + if (ref_type_and_input_output_iter != ref_input_output_map_.end()) { + auto &input_output_index_map = ref_type_and_input_output_iter->second; + if (input_output_index_map.find(input_index) == input_output_index_map.end()) { + input_output_index_map.emplace(input_index, output_index); + GELOGD("Add RefInputOutputMap %s:{ %d, %d }", node->GetType().c_str(), input_index, output_index); + } + } else { + ref_input_output_map_.insert({node->GetType(), {{input_index, output_index}}}); + GELOGD("Create RefInputOutputMap { %s:{ %d, %d } }", node->GetType().c_str(), input_index, output_index); } } } @@ -417,15 +317,4 @@ Status VariablePrepareOpPass::CheckStreamLabel(const ge::NodePtr &var_ref_node, } return SUCCESS; } - -bool VariablePrepareOpPass::HasControlOut(const ge::NodePtr &node) { - const auto &out_control_anchor = node->GetOutControlAnchor(); - for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { - if (peer_in_control_anchor == nullptr || peer_in_control_anchor->GetOwnerNode() == nullptr) { - continue; - } - return true; - } - return false; -} } // namespace ge diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h index f024a464..c8b9883e 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.h +++ b/src/ge/graph/passes/variable_prepare_op_pass.h @@ -18,7 +18,6 @@ #define GE_GRAPH_PASSES_VARIABLE_PREPARE_OP_PASS_H_ #include -#include #include #include "framework/common/ge_inner_error_codes.h" @@ -31,19 +30,15 @@ class VariablePrepareOpPass : public GraphPass { private: Status DealVariableNode(ge::NodePtr &node); - Status DealWritableNode(const ge::NodePtr &writable_node, 
int input_index, const ge::NodePtr &var_node); - Status GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, std::stack> &nodes); - Status AddVariableRef(ge::NodePtr &node, const ge::NodePtr &var_node, int index); - Status InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node); - Status AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node); - NodePtr CreateVariableRef(const std::string &variable_ref_name, const ge::NodePtr &var_node); - NodePtr CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, uint32_t input_index); + Status DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index); + NodePtr GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index); + Status AddVariableRef(ge::NodePtr &node, ge::NodePtr &var_node, int index); + NodePtr CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node); int GetWritableNodeOutIndex(const NodePtr &node, int input_index); void GenerateRefTypeAndInputOutputMap(const NodePtr &node); int FindRefOutIndex(const std::string &node_type, int input_index, const std::map> &ref_map); Status CheckStreamLabel(const ge::NodePtr &var_ref_node, const ge::NodePtr &final_writable_node); - bool HasControlOut(const ge::NodePtr &node); std::map> ref_input_output_map_; static std::map> ref_node_without_prototype_map_; diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index 32236814..cd5b9fe9 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -16,10 +16,18 @@ #include "graph/passes/variable_ref_delete_op_pass.h" #include +#include "framework/common/debug/ge_log.h" namespace ge { Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { + GE_TIMESTAMP_START(VariableRefDeleteOpPass); GE_CHECK_NOTNULL(graph); + + for (auto &node : graph->GetDirectNode()) { + 
GELOGD("before VariableRefDeleteOpPass, graph has node: %s, and node name: %s", node->GetType().c_str(), + node->GetName().c_str()); + } + for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); std::string ref_var_src_var_name; @@ -34,6 +42,13 @@ Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { return FAILED; } } + + for (auto &node : graph->GetDirectNode()) { + GELOGD("after VariableRefDeleteOpPass, graph has node: %s, and node name: %s", node->GetType().c_str(), + node->GetName().c_str()); + } + GE_TIMESTAMP_END(VariableRefDeleteOpPass, "GraphManager::VariableRefDeleteOpPass"); + return SUCCESS; } @@ -53,21 +68,21 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: // get previous node of variable_ref NodePtr peer_node = inAnchor0->GetPeerOutAnchor()->GetOwnerNode(); - // add attr [REF_VAR_SRC_VAR_NAME] to the previous op output desc of the variable_ref - auto op_desc = peer_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - auto out_desc = op_desc->GetOutputDesc(static_cast(index)); - bool is_set_str = ge::AttrUtils::SetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - (void)op_desc->UpdateOutputDesc(static_cast(index), out_desc); + // add attr [REF_VAR_SRC_VAR_NAME] to the previous node of the variable_ref + GE_CHECK_NOTNULL(peer_node->GetOpDesc()); + bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + ge::NodePtr ref_var_src_var = GraphUtils::FindNodeFromAllNodes(graph, ref_var_src_var_name); if (ref_var_src_var == nullptr) { - GELOGE(FAILED, "Can not find source variable[%s] of variable ref[%s]", ref_var_src_var_name.c_str(), - variable_ref->GetName().c_str()); + GELOGE(FAILED, "get ref_var_src_var failed"); return FAILED; } - if (is_set_str) { - GELOGI("[%s-%d]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), index, - ref_var_src_var_name.c_str()); + + 
GE_CHECK_NOTNULL(ref_var_src_var->GetOpDesc()); + bool is_set_index = ge::AttrUtils::SetInt(ref_var_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, index); + if (is_set_str && is_set_index) { + GELOGI("[%s]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), ref_var_src_var_name.c_str()); + GELOGI("[%s]: add attr [REF_VAR_PRE_PEER_OUT_INDEX: %d]", ref_var_src_var->GetName().c_str(), index); } // remove variable_ref diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc index 1321cf20..bd153184 100644 --- a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc +++ b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc @@ -17,6 +17,7 @@ #include "variable_ref_useless_control_out_delete_pass.h" namespace ge { + Status VariableRefUselessControlOutDeletePass::Run(ge::ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index 94818698..9c82a06d 100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -19,12 +19,9 @@ #include #include #include -#include "common/formats/format_transfers/format_transfer_fractal_nz.h" -#include "common/formats/format_transfers/format_transfer_fractal_z.h" #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_transpose.h" -#include "common/formats/utils/formats_trans_utils.h" #include "common/helper/model_helper.h" #include "common/math/math_util.h" #include "common/op/ge_op_utils.h" @@ -83,9 +80,7 @@ #include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_fusion_pass.h" #include "graph/passes/switch_logic_remove_pass.h" 
-#include "graph/passes/merge_to_stream_merge_pass.h" -#include "graph/passes/switch_to_stream_switch_pass.h" -#include "graph/passes/attach_stream_label_pass.h" +#include "graph/passes/switch_op_pass.h" #include "graph/passes/switch_split_pass.h" #include "graph/passes/unused_const_pass.h" #include "graph/passes/unused_op_remove_pass.h" @@ -101,6 +96,7 @@ #include "runtime/dev.h" #include "graph/passes/dimension_adjust_pass.h" +#include "graph/passes/identify_reference_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" #include "graph/passes/permute_pass.h" #include "graph/passes/reshape_remove_pass.h" @@ -138,14 +134,14 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { auto dim_cnt = static_cast(dst_ge_shape.GetDimNum()); if (dim_cnt == 0) { // if the dim_cnt is 0, the tensor is a scalar tensor->MutableTensorDesc().SetShape(GeShape()); - int32_t dst_shape = 1; - if (tensor->SetData(reinterpret_cast(&dst_shape), sizeof(int32_t)) != GRAPH_SUCCESS) { + int64_t dst_shape = 1; + if (tensor->SetData(reinterpret_cast(&dst_shape), sizeof(int64_t)) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "tensor set data failed"); return nullptr; } } else { tensor->MutableTensorDesc().SetShape(GeShape(std::vector({dim_cnt}))); - unique_ptr dst_shape(new (std::nothrow) int32_t[dim_cnt]()); + unique_ptr dst_shape(new (std::nothrow) int64_t[dim_cnt]()); if (dst_shape == nullptr) { GELOGE(INTERNAL_ERROR, "Create unique ptr failed"); return nullptr; @@ -155,7 +151,7 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { } GE_IF_BOOL_EXEC( - tensor->SetData(reinterpret_cast(dst_shape.get()), dim_cnt * sizeof(int32_t)) != GRAPH_SUCCESS, + tensor->SetData(reinterpret_cast(dst_shape.get()), dim_cnt * sizeof(int64_t)) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "tensor set data failed"); return nullptr;) } @@ -652,39 +648,7 @@ Status ModifyFormatAndShapeForSingleTensor(const GeTensorDescPtr &input_output) input_output->SetShape(ge::GeShape(dst_shape_dims)); return 
SUCCESS; } -Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, Format storage_format, - vector &dst_shape_dims) { - GE_CHECK_NOTNULL(op_desc); - const GeTensorDescPtr &input = op_desc->MutableInputDesc(index); - GE_CHECK_NOTNULL(input); - ge::Format old_format = input->GetFormat(); - std::vector old_shape = input->GetShape().GetDims(); - - input->SetShape(ge::GeShape(dst_shape_dims)); - input->SetFormat(storage_format); - auto output = op_desc->MutableOutputDesc(index); - GE_CHECK_NOTNULL(output); - output->SetShape(ge::GeShape(dst_shape_dims)); - output->SetFormat(storage_format); - - int64_t size = 0; - graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(*output, size); - if (graph_status != ge::GRAPH_SUCCESS) { - GELOGE(graph_status, "GetTensorSizeInBytes failed!"); - return FAILED; - } - ge::TensorUtils::SetSize(*input, size); - ge::TensorUtils::SetSize(*output, size); - - GELOGI( - "Modify Data NetOutput format and shape success, node:%s, index:%d, old_shape:%s, old_Format:%s, " - "new_shape:%s, new_format:%s, new_size:%u", - op_desc->GetName().c_str(), index, formats::JoinToString(old_shape).c_str(), - ge::TypeUtils::FormatToSerialString(old_format).c_str(), formats::JoinToString(dst_shape_dims).c_str(), - ge::TypeUtils::FormatToSerialString(storage_format).c_str(), size); - return SUCCESS; -} Status ProcessInputNC1HWC0(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { GE_CHECK_NOTNULL(node_ptr); auto op_desc = node_ptr->GetOpDesc(); @@ -1090,6 +1054,7 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP return SUCCESS; } input->SetDataType(DT_FLOAT16); + input->SetOriginDataType(DT_FLOAT16); int64_t input_shape_size = 0; int64_t output_shape_size = 0; ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); @@ -1102,6 +1067,7 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP const 
GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetDataType(DT_FLOAT16); + output->SetOriginDataType(DT_FLOAT16); ge::TensorUtils::SetSize(*output, output_shape_size); if (is_dynamic_batch) { GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); @@ -1110,10 +1076,12 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP auto switchn_input = switchn_op_desc->MutableInputDesc(0); GE_CHECK_NOTNULL(switchn_input); switchn_input->SetDataType(DT_FLOAT16); + switchn_input->SetOriginDataType(DT_FLOAT16); for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); GE_CHECK_NOTNULL(switchn_output); switchn_output->SetDataType(DT_FLOAT16); + switchn_output->SetOriginDataType(DT_FLOAT16); } } return SUCCESS; @@ -1132,6 +1100,10 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); return FAILED; } + if (old_format == FORMAT_NC1HWC0) { + GELOGI("No need to transfer format"); + return SUCCESS; + } if (ModifyInputFormatAndShape(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "modify format and shape failed"); return FAILED; @@ -1167,7 +1139,7 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { } for (auto const &next_node : node_ptr->GetOutNodes()) { if (next_node->GetType() == AIPP) { - ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"opname"}, {node_ptr->GetName()}); + ErrorManager::GetInstance().ATCReportErrMessage("E10049", {"opname"}, {node_ptr->GetName()}); GELOGE(INTERNAL_ERROR, "This input op [%s] is linked to aipp, can not be set to fp16, " "please check your atc parameter --insert_op_conf, --input_fp16_nodes.", @@ -1199,42 +1171,6 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { return SUCCESS; } -Status 
GetStorageFormatAndShape(OpDescPtr &op_desc, const GeTensorDescPtr &tensor_desc_ptr, Format &storage_format, - vector &dst_shape_dims) { - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(tensor_desc_ptr); - - storage_format = FORMAT_RESERVED; - int64_t format = FORMAT_RESERVED; - dst_shape_dims.clear(); - if (ge::AttrUtils::GetInt(*tensor_desc_ptr, ATTR_NAME_STORAGE_FORMAT, format)) { - storage_format = static_cast(format); - vector storage_shape; - if (ge::AttrUtils::GetListInt(*tensor_desc_ptr, ATTR_NAME_STORAGE_SHAPE, storage_shape)) { - for (auto dim : storage_shape) { - dst_shape_dims.push_back(static_cast(dim)); - } - GELOGI("Update node by storage format, node: [%s], storage_format: [%s], storage_shape:[%s]", - op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(), - formats::JoinToString(storage_shape).c_str()); - } else { - GELOGE(PARAM_INVALID, - "Update node by storage format failed, storage_shape not set. " - "node: [%s], storage_format [%s]", - op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str()); - return FAILED; - } - - ge::Format old_format = tensor_desc_ptr->GetFormat(); - auto old_shape = tensor_desc_ptr->GetShape().GetDims(); - if (old_format == storage_format && old_shape == dst_shape_dims) { - GELOGI("Update node by storage format, not changed."); - storage_format = FORMAT_RESERVED; - return SUCCESS; - } - } - return SUCCESS; -} Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorDescPtr &net_output_input_desc, NodePtr &node) { bool is_dynamic = CheckOpType(node, MERGE); @@ -1244,16 +1180,24 @@ Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorD ge::Format src_format = src_desc.GetFormat(); net_output_input_desc->SetDataType(DT_FLOAT16); + net_output_input_desc->SetOriginDataType(DT_FLOAT16); if (is_dynamic) { auto merge_output = src_op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(merge_output); 
merge_output->SetDataType(DT_FLOAT16); + merge_output->SetOriginDataType(DT_FLOAT16); for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { auto merge_input = src_op_desc->MutableInputDesc(i); GE_CHECK_NOTNULL(merge_input); merge_input->SetDataType(DT_FLOAT16); + merge_input->SetOriginDataType(DT_FLOAT16); } } + + if (src_format == FORMAT_NC1HWC0) { + GELOGI("Format is NC1HWC0, no need to transfer"); + return SUCCESS; + } std::vector dst_shape_dims; std::vector src_shape_dims = src_shape.GetDims(); if (TransferShape2NC1HWC0(src_format, src_shape_dims, DT_FLOAT16, FORMAT_NC1HWC0, dst_shape_dims) != SUCCESS) { @@ -1347,14 +1291,17 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { GELOGI("Enter into process output_type schedule"); net_output_input_desc->SetDataType(output_data_type); + net_output_input_desc->SetOriginDataType(output_data_type); if (is_dynamic) { auto merge_output = src_op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(merge_output); merge_output->SetDataType(output_data_type); + merge_output->SetOriginDataType(output_data_type); for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { auto merge_input = src_op_desc->MutableInputDesc(i); GE_CHECK_NOTNULL(merge_input); merge_input->SetDataType(output_data_type); + merge_input->SetOriginDataType(output_data_type); } } continue; @@ -1390,6 +1337,7 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { } return SUCCESS; } + } // namespace GraphPrepare::GraphPrepare() : compute_graph_(nullptr) {} @@ -1483,8 +1431,6 @@ Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { if (compute_graph_ != nullptr) { compute_graph_->SetSessionID(session_id); } - session_id_ = session_id; - Status ret = CheckGraph(); if (ret != SUCCESS) { GELOGE(ret, "RunGraph graph check fail, ret:%u", ret); @@ -1496,6 +1442,7 @@ Status 
GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { GELOGE(ret, "RunGraph check ref op fail, ret:%u", ret); return ret; } + return SUCCESS; } @@ -1520,13 +1467,13 @@ Status GraphPrepare::CheckGraph() { } Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &input_name, - const std::set &ref_nodes) { + const std::unordered_set &ref_nodes) { // Acceptable input types should be ref node, variable or Switch operator, which is issued by ME for dynamic - // lossscale and would be optimized in SwitchToStreamSwitchPass. - // Since ME dont differentiate between RefSwitch and Switch, and only issue Switch. - static std::set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, - ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, - ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; + // lossscale and would be optimized in SwitchOpPass. Since ME dont differentiate between RefSwitch and Switch, + // and only issue Switch. + static std::unordered_set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, + ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, + ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; GE_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -1552,6 +1499,7 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i } } bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end()); + if (!is_acceptable) { GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.", node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(), @@ -1564,7 +1512,7 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i Status GraphPrepare::CheckRefOp() { GE_CHECK_NOTNULL(compute_graph_); - std::set ref_nodes; + std::unordered_set ref_nodes; for (const NodePtr &node : compute_graph_->GetDirectNode()) { if (node == nullptr) { GELOGE(PARAM_INVALID, "param [node] 
must not be null."); @@ -1576,15 +1524,20 @@ Status GraphPrepare::CheckRefOp() { return PARAM_INVALID; } - auto input_name_index = op_desc->GetAllInputName(); + auto input_names = op_desc->GetAllInputNames(); auto outputs = op_desc->GetAllOutputName(); - for (const auto &name_index : input_name_index) { - if (op_desc->GetOutputIndexByName(name_index.first) != -1) { - if (CheckRefInputNode(node, name_index.first, ref_nodes) != SUCCESS) { + std::unordered_set all_output_name; + + for (auto &output : outputs) { + all_output_name.insert(output.first); + } + for (const auto &input_name : input_names) { + if (all_output_name.find(input_name) != all_output_name.end()) { + if (CheckRefInputNode(node, input_name, ref_nodes) != SUCCESS) { GELOGE(PARAM_INVALID, "CheckRefInputNode failed."); return PARAM_INVALID; } - (void)ref_nodes.insert(node); // no need to check value + (void)ref_nodes.insert(node); } } } @@ -1595,7 +1548,7 @@ Status GraphPrepare::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddRtContext(session_id_, rt_context); + RtContextUtil::GetInstance().AddrtContext(rt_context); return SUCCESS; } @@ -1613,8 +1566,6 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { int64_t tensor_size = 0; graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); if (graph_status != GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, - {"GetTensorMemorySizeInBytes", "opname is " + node->GetName()}); GELOGE(graph_status, "GetTensorMemorySizeInBytes failed!"); return FAILED; } @@ -1648,16 +1599,12 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GeTensorDesc desc(user_input[index].GetTensorDesc()); auto format = 
desc.GetFormat(); auto origin_format = desc.GetOriginFormat(); - // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. - bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op); + bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); + bool need_check_internal_format = (!options_.is_single_op) && is_internal; if (need_check_internal_format) { - bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); - if (is_internal) { - GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", - TypeUtils::FormatToSerialString(format).c_str(), - TypeUtils::FormatToSerialString(origin_format).c_str()); - return FAILED; - } + GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", + TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::FormatToSerialString(origin_format).c_str()); + return FAILED; } auto data_type = desc.GetDataType(); @@ -1676,8 +1623,7 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "TensorUtils GetSize failed"); return FAILED); - bool size_check = (size != 0 && shape_size != size); - if (size_check) { + if ((size != 0) && (shape_size != size)) { GELOGE(PARAM_INVALID, "input data size =%ld, shape_size =%ld.", size, shape_size); return FAILED; } @@ -1825,55 +1771,6 @@ Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { return SUCCESS; } -Status GraphPrepare::UpdateDataNetOutputByStorageFormat() { - for (auto &node_ptr : compute_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node_ptr); - if (node_ptr->GetType() == DATA) { - uint32_t index = 0; - auto op_desc = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - const GeTensorDescPtr input = op_desc->MutableInputDesc(index); - Format storage_format = FORMAT_RESERVED; - vector dst_shape_dims; - if 
(GetStorageFormatAndShape(op_desc, input, storage_format, dst_shape_dims) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get storage format for input failed"); - return FAILED; - } - - if (storage_format == FORMAT_RESERVED) { - continue; - } - - if (ModifyDataNetOutputFormatAndShape(op_desc, index, storage_format, dst_shape_dims) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Modify format and shape for inputfailed"); - return FAILED; - } - } - - if (node_ptr->GetType() == ge::NETOUTPUT) { - auto op_desc = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - for (uint32_t index = 0; index < op_desc->GetOutputsSize(); index++) { - const GeTensorDescPtr output = op_desc->MutableOutputDesc(index); - Format storage_format = FORMAT_RESERVED; - vector dst_shape_dims; - if (GetStorageFormatAndShape(op_desc, output, storage_format, dst_shape_dims) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get storage format from output failed"); - return FAILED; - } - if (storage_format == FORMAT_RESERVED) { - continue; - } - if (ModifyDataNetOutputFormatAndShape(op_desc, index, storage_format, dst_shape_dims) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Modify format and shape for output failed"); - return FAILED; - } - } - } - } - return SUCCESS; -} - void GraphPrepare::ProcessCCEFormat() { static const char *const parser_priority = std::getenv("PARSER_PRIORITY"); static const bool keep_cce = parser_priority != nullptr && string(parser_priority) == "cce"; @@ -2058,7 +1955,9 @@ Status GraphPrepare::PrepareDynShape(ConstGraphPtr graph, const std::vector(options_.framework_type); const Graph &const_graph = *graph; @@ -2090,6 +1989,7 @@ Status GraphPrepare::PrepareRunningFormatRefiner() { PassManager pass_manager; GE_CHK_STATUS_RET(pass_manager.AddPass("PrepareRunningFormatRefiner::VariablePrepareOpPass", new (std::nothrow) VariablePrepareOpPass)) + GE_CHK_STATUS_RET(pass_manager.AddPass("PrepareRunningFormatRefiner::SubgraphPass", new (std::nothrow) SubgraphPass)) GE_TIMESTAMP_START(pass_manager); auto ret = 
pass_manager.Run(compute_graph); GE_TIMESTAMP_END(pass_manager, "GraphPrepare::PrepareRunningFormatRefiner"); @@ -2153,6 +2053,10 @@ Status GraphPrepare::GenerateInfershapeGraph(ConstGraphPtr graph) { Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id) { + // train graph flag + if (options_.train_graph_flag) { + domi::GetContext().train_flag = true; + } domi::GetContext().type = static_cast(options_.framework_type); if (graph == nullptr) { @@ -2167,7 +2071,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u } GraphOptimize graph_optimize; - if (!options_.train_graph_flag && !domi::GetContext().train_flag) { + if (!domi::GetContext().train_flag) { GE_DUMP(compute_graph_, "BeforeOriginalGraphForQuantize"); GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); @@ -2398,10 +2302,10 @@ Status GraphPrepare::PrepareOptimize() { GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; EnterPass enter_pass; + PrintOpPass print_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); CondPass cond_pass; names_to_passes.emplace_back("CondPass", &cond_pass); - PrintOpPass print_pass; if (options_.enable_print_op_pass) { names_to_passes.emplace_back("PrintOpPass", &print_pass); } @@ -2574,9 +2478,7 @@ Status GraphPrepare::OptimizeForPreprocess() { (void)graph_pass.AddPass("OptimizeForPreprocess::PrunePass", new PrunePass); (void)graph_pass.AddPass("OptimizeForPreprocess::NextIterationPass", new NextIterationPass); (void)graph_pass.AddPass("OptimizeForPreprocess::ControlTriggerPass", new ControlTriggerPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::MergeToStreamMergePass", new MergeToStreamMergePass); - (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchToStreamSwitchPass", new SwitchToStreamSwitchPass); - 
(void)graph_pass.AddPass("OptimizeForPreprocess::AttachStreamLabelPass", new AttachStreamLabelPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchOpPass", new SwitchOpPass); (void)graph_pass.AddPass("OptimizeForPreprocess::HcclMemcpyPass", new HcclMemcpyPass); GE_IF_BOOL_EXEC(options_.train_graph_flag, (void)graph_pass.AddPass("OptimizeForPreprocess::FlowCtrlPass", new FlowCtrlPass);); @@ -2658,6 +2560,8 @@ Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_c GEPass ge_passes_for_shape(compute_graph_); NamesToPass names_to_passes_for_shape; + IdentifyReferencePass identify_reference_pass; + names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); CastRemovePass cast_remove_pass; names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); TransposeTransDataPass transpose_transdata_pass; @@ -2789,12 +2693,6 @@ Status GraphPrepare::CheckAndUpdateInput(const std::vector &user_input return SUCCESS; } Status GraphPrepare::UpdateInputOutputByOptions() { - auto ret = UpdateDataNetOutputByStorageFormat(); - if (ret != SUCCESS) { - GELOGE(ret, "Update format acoording to storage format failed."); - return ret; - } - if (options_.train_graph_flag) { GELOGI("This is train mode, no need to do this schedule."); return SUCCESS; @@ -2838,21 +2736,6 @@ bool GraphPrepare::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } -bool GraphPrepare::IsTansDataOpData(const ge::NodePtr &var_node) { - for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { - GE_RT_FALSE_CHECK_NOTNULL(out_anchor); - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { - GE_RT_FALSE_CHECK_NOTNULL(in_anchor); - ge::NodePtr dst_node = in_anchor->GetOwnerNode(); - GE_RT_FALSE_CHECK_NOTNULL(dst_node); - if (dst_node->GetType() == TRANSDATA) { - return true; - } - } - } - return false; -} - bool GraphPrepare::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, const map> &confirm_ops, 
ge::NodePtr &use_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h index bae2a885..b90caa86 100644 --- a/src/ge/graph/preprocess/graph_preprocess.h +++ b/src/ge/graph/preprocess/graph_preprocess.h @@ -59,7 +59,8 @@ class GraphPrepare { Status Init(const ge::Graph &graph, uint64_t session_id = 0); Status Preprocess(const std::vector &user_input); Status CheckGraph(); - Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, const std::set &ref_nodes); + Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, + const std::unordered_set &ref_nodes); Status CheckRefOp(); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); @@ -68,7 +69,6 @@ class GraphPrepare { Status CheckConstOp(); Status VerifyConstOp(const NodePtr &node); Status CheckUserInput(const std::vector &user_input); - Status UpdateDataNetOutputByStorageFormat(); Status OptimizeForPreprocess(); Status PrepareOptimize(); Status InferShapeForPreprocess(); @@ -88,8 +88,6 @@ class GraphPrepare { Status UpdateInputOutputByOptions(); bool IsBroadCastOpData(const ge::NodePtr &var_node); - bool IsTansDataOpData(const ge::NodePtr &var_node); - void AdjustBroadCastOpData(const ge::NodePtr &var_node); bool IsAssignOpData(const ge::NodePtr &var_node); @@ -106,7 +104,6 @@ class GraphPrepare { ge::ComputeGraphPtr compute_graph_; GraphManagerOptions options_; - uint64_t session_id_ = 0; }; } // namespace ge #endif // GE_GRAPH_PREPROCESS_GRAPH_PREPROCESS_H_ diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 8bb0c6c4..5fe19869 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -45,7 +45,7 @@ static void ConvertShape2Nhwc(Format &format, vector &shape_vec) { return; 
} if (format != FORMAT_NCHW) { - GELOGW("The format is not NCHW, current format is %s.", TypeUtils::FormatToSerialString(format).c_str()); + GELOGW("The format is not NCHW, current format is %s", TypeUtils::FormatToSerialString(format).c_str()); return; } vector shape_vec_tmp; @@ -245,6 +245,7 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s GELOGE(FAILED, "Can not get size from aipp [%s]", aipp_op_desc->GetName().c_str()); return FAILED; } + // Save the input size of aipp node, which will be used in dumping aipp node or fused aipp node (void)AttrUtils::SetInt(aipp_input, ATTR_NAME_INPUT_ORIGIN_SIZE, size); auto in_data_anchor = node->GetInDataAnchor(0); diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc index d06a493d..fbe935ec 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -126,12 +126,8 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { for (size_t i = 0; i < data_shape.GetDimNum(); ++i) { if (data_shape.GetDim(i) < 0) { if (batch_shape_index >= batch_shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19012", {"function", "reason"}, - {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + - " does not match the data shape " + data_shape.ToString()}); GELOGE(PARAM_INVALID, - "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", + "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", batch_shape.size(), data_shape.ToString().c_str()); return PARAM_INVALID; } @@ -139,10 +135,6 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { } } if (batch_shape_index != batch_shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E19012", {"function", "reason"}, - {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size())
+ " does not match the data shape " + - data_shape.ToString()}); GELOGE(PARAM_INVALID, "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", batch_shape.size(), data_shape.ToString().c_str()); return PARAM_INVALID; @@ -207,7 +199,7 @@ Status CheckDataShape(const std::vector &nodes) { } } if (unknown_shape_count == 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10040"); + ErrorManager::GetInstance().ATCReportErrMessage("E10055"); GELOGE(PARAM_INVALID, "Need unknow shape data when user set --dynamic_batch_size or --dynamic_image_size, please check."); return PARAM_INVALID; @@ -287,8 +279,6 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { case kNodeOutBatchBranch: ret = InsertMergeForEdgeNode(node); break; - case kNodeNotSupportNode: - break; default: GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast(branch_status), node->GetName().c_str()); @@ -301,13 +291,7 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { } return SUCCESS; } - NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { - // node with subgraph is not supported - if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) { - return kNodeNotSupportNode; - } - if (node->GetType() == NETOUTPUT) { return kNodeOutBatchBranch; } @@ -321,7 +305,6 @@ NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { } return kNodeOutBatchBranch; } - NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) { if (index < 0) { // the merge node must has data inputs, if origin connection is a control @@ -494,7 +477,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { return PARAM_INVALID; } if (shapes_.size() < kMinShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"shapesize", "minshapesize"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10050", {"shapesize", "minshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMinShapesCount)}); GELOGE(PARAM_INVALID, 
"Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " @@ -503,7 +486,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { return PARAM_INVALID; } if (shapes_.size() > kMaxShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"shapesize", "maxshapesize"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10051", {"shapesize", "maxshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMaxShapesCount)}); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " @@ -515,7 +498,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { size_t shape_size = shapes_.at(0).size(); for (auto &shape : shapes_) { if (shape_size != shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"shapesize1", "shapesize2"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10052", {"shapesize1", "shapesize2"}, {std::to_string(shape_size), std::to_string(shape.size())}); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " @@ -525,7 +508,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { } for (auto dim : shape) { if (dim <= 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"dim"}, {std::to_string(dim)}); + ErrorManager::GetInstance().ATCReportErrMessage("E10053", {"dim"}, {std::to_string(dim)}); GELOGE(PARAM_INVALID, "Invalid dim %ld, all dims must be greater than 0", dim); return PARAM_INVALID; } @@ -533,7 +516,7 @@ Status MultiBatchGraphCopyer::CheckArguments() { shapes_set.insert(shape); } if (shapes_set.size() != shapes_.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10039"); + ErrorManager::GetInstance().ATCReportErrMessage("E10054"); GELOGE(PARAM_INVALID, "Input parameter[--dynamic_batch_size or --dynamic_image_size] exist duplicate shapes, please check"); return PARAM_INVALID; diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.h b/src/ge/graph/preprocess/multi_batch_copy_graph.h index bf1d53b3..2500645f 100644 
--- a/src/ge/graph/preprocess/multi_batch_copy_graph.h +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.h @@ -33,7 +33,6 @@ enum NodeStatus { kNodeInBatchBranch, kNodeOutBatchBranch, kNodeStartNode, - kNodeNotSupportNode, }; class MultiBatchGraphCopyer { diff --git a/src/ge/host_kernels/identity_kernel.cc b/src/ge/host_kernels/identity_kernel.cc deleted file mode 100644 index 16bd3138..00000000 --- a/src/ge/host_kernels/identity_kernel.cc +++ /dev/null @@ -1,63 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "identity_kernel.h" -#include "inc/kernel_factory.h" - -namespace { -constexpr uint32_t kInputDescIndex = 0; -constexpr uint32_t kOutputDescIndex = 0; -} // namespace - -namespace ge { -Status IdentityKernel::Compute(const ge::OpDescPtr op_desc, const std::vector &input, - std::vector &v_output) { - if (op_desc == nullptr) { - GELOGE(PARAM_INVALID, "IdentityKernel op_desc is null."); - return NOT_CHANGED; - } - if (input.empty()) { - GELOGE(PARAM_INVALID, "Node [%s] inputs is empty.", op_desc->GetName().c_str()); - return NOT_CHANGED; - } - if (op_desc->GetOutputsSize() < 1) { - GELOGE(PARAM_INVALID, "Node [%s] output is empty.", op_desc->GetName().c_str()); - return NOT_CHANGED; - } - GELOGD("IdentityKernel in: node[%s]", op_desc->GetName().c_str()); - - auto out_tensor_desc = op_desc->GetOutputDesc(kOutputDescIndex); - GeTensorPtr output_ptr = MakeShared(out_tensor_desc); - if (output_ptr == nullptr) { - GELOGE(OUT_OF_MEMORY, "Node [%s] make shared failed.", op_desc->GetName().c_str()); - return OUT_OF_MEMORY; - } - auto input_tensor_ptr = input.at(kInputDescIndex); - if (input_tensor_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Node [%s] get input failed.", op_desc->GetName().c_str()); - return NOT_CHANGED; - } - if (output_ptr->SetData(input_tensor_ptr->GetData()) != GRAPH_SUCCESS) { - GELOGW("Compute: SetData failed"); - return NOT_CHANGED; - } - v_output.emplace_back(output_ptr); - GELOGD("IdentityKernel success: node[%s]", op_desc->GetName().c_str()); - - return SUCCESS; -} -REGISTER_KERNEL(IDENTITY, IdentityKernel); -} // namespace ge diff --git a/src/ge/host_kernels/identity_kernel.h b/src/ge/host_kernels/identity_kernel.h deleted file mode 100644 index 2164d880..00000000 --- a/src/ge/host_kernels/identity_kernel.h +++ /dev/null @@ -1,31 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ -#define GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ - -#include "inc/kernel.h" -#include - -namespace ge { -class IdentityKernel : public Kernel { - public: - Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, - std::vector &v_output) override; -}; -} // namespace ge - -#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ diff --git a/src/ge/inc/kernel_factory.h b/src/ge/inc/kernel_factory.h index 61455836..c0624e14 100644 --- a/src/ge/inc/kernel_factory.h +++ b/src/ge/inc/kernel_factory.h @@ -103,5 +103,5 @@ class KernelFactory { return ptr; \ } \ KernelFactory::Registerar g_##type##_Kernel_Creator(type, Creator_##type##_Kernel) -} // namespace ge +}; // end namespace ge #endif // GE_INC_KERNEL_FACTORY_H_ diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc index f7740a3c..5fcb0cd7 100644 --- a/src/ge/init/gelib.cc +++ b/src/ge/init/gelib.cc @@ -37,7 +37,6 @@ #include "graph/load/new_model_manager/model_manager.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" -#include "graph/common/ge_call_wrapper.h" #include "omm/csa_interact.h" #include "runtime/kernel.h" @@ -47,9 +46,6 @@ namespace ge { namespace { const int kDecimal = 10; const int kSocVersionLen = 50; -const uint32_t kAicoreOverflow = (0x1 << 0); -const uint32_t kAtomicOverflow = (0x1 << 1); -const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); } // namespace static std::shared_ptr instancePtr_ = nullptr; @@ -79,7 +75,7 @@ 
Status GELib::Initialize(const map &options) { instancePtr_ = nullptr; return ret; } - GE_TIMESTAMP_EVENT_END(Init, "GELib::Initialize"); + GE_TIMESTAMP_END(Init, "GELib::Initialize"); return SUCCESS; } @@ -130,6 +126,16 @@ Status GELib::InnerInitialize(const map &options) { return initSmStatus; } + GELOGI("memoryMallocSize initial."); + GE_TIMESTAMP_START(SetMemoryMallocSize); + Status initMemStatus = VarManager::Instance(0)->SetMemoryMallocSize(options); + GE_TIMESTAMP_END(SetMemoryMallocSize, "InnerInitialize::SetMemoryMallocSize"); + if (initMemStatus != SUCCESS) { + GELOGE(initMemStatus, "failed to set malloc size"); + RollbackInit(); + return initMemStatus; + } + GELOGI("Start to initialize HostCpuEngine"); GE_TIMESTAMP_START(HostCpuEngineInitialize); Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize(); @@ -154,6 +160,37 @@ Status GELib::SystemInitialize(const map &options) { } } + iter = options.find(HEAD_STREAM); + head_stream_ = (iter != options.end()) ? std::strtol(iter->second.c_str(), nullptr, kDecimal) : false; + + iter = options.find(OPTION_EXEC_ENABLE_DUMP); + if (iter != options.end()) { + int32_t enable_dump_flag = 1; + auto path_iter = options.find(OPTION_EXEC_DUMP_PATH); + if (iter->second == std::to_string(enable_dump_flag) && path_iter != options.end()) { + std::string dump_path = path_iter->second; + if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { + dump_path = dump_path + "/" + CurrentTimeInStr() + "/"; + } + + PropertiesManager::Instance().AddDumpPropertyValue(DUMP_ALL_MODEL, {}); + GELOGD("Get dump path %s successfully", dump_path.c_str()); + PropertiesManager::Instance().SetDumpOutputPath(dump_path); + } + auto step_iter = options.find(OPTION_EXEC_DUMP_STEP); + if (step_iter != options.end()) { + std::string dump_step = step_iter->second; + GELOGD("Get dump step %s successfully", dump_step.c_str()); + PropertiesManager::Instance().SetDumpStep(dump_step); + } + auto mode_iter = 
options.find(OPTION_EXEC_DUMP_MODE); + if (mode_iter != options.end()) { + std::string dump_mode = mode_iter->second; + GELOGD("Get dump mode %s successfully", dump_mode.c_str()); + PropertiesManager::Instance().SetDumpMode(dump_mode); + } + } + // In train and infer, profiling is always needed. InitOptions(options); InitProfiling(this->options_); diff --git a/src/ge/init/gelib.h b/src/ge/init/gelib.h index b5621dfd..0dfec391 100644 --- a/src/ge/init/gelib.h +++ b/src/ge/init/gelib.h @@ -62,6 +62,9 @@ class GELib { // get TrainMode flag bool isTrainMode() { return is_train_mode_; } + // add head stream to model + bool HeadStream() const { return head_stream_; } + // get incre build flag bool IsIncreBuild() const { return is_incre_build_; } @@ -83,8 +86,6 @@ class GELib { Status SetRTSocVersion(const map &options, map &new_options); void RollbackInit(); void InitOptions(const map &options); - void SetDumpModelOptions(const map &options); - void SetOpDebugOptions(const map &options); DNNEngineManager engineManager_; OpsKernelManager opsManager_; @@ -97,6 +98,7 @@ class GELib { bool is_shutdown = false; bool is_use_hcom = false; bool is_incre_build_ = false; + bool head_stream_ = false; std::string incre_build_cache_path_; }; } // namespace ge diff --git a/src/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc index 352e5dc2..12c85bc0 100644 --- a/src/ge/ir_build/atc_ir_common.cc +++ b/src/ge/ir_build/atc_ir_common.cc @@ -32,29 +32,8 @@ const int64_t kDynamicImageSizeNum = 2; // datatype/formats from user to GE, Unified to util interface file later const std::map kOutputTypeSupportDatatype = { {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; -const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; -const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize", - "l1_and_l2_optimize"}; -// The function is incomplete. Currently, only l2_optimize, off_optimize is supported. 
-const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; +const std::set kBufferOptimizeSupportOption = {"l2_optimize", "off_optimize"}; const std::string IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; -const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; -const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; -const char *const kSplitError1 = "size not equal to 2 split by \":\""; -const char *const kEmptyError = "can not be empty"; -const char *const kFloatNumError = "exist float number"; -const char *const kDigitError = "is not digit"; -const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; - -vector SplitInputShape(const std::string &input_shape) { - vector shape_pair_vec; - size_t pos = input_shape.rfind(":"); - if (pos != std::string::npos) { - shape_pair_vec.emplace_back(input_shape.substr(0, pos)); - shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); - } - return shape_pair_vec; -} } // namespace bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, @@ -63,7 +42,7 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { vector shape = iter->second; if (shape.size() < 1) { - ErrorManager::GetInstance().ATCReportErrMessage("E10012"); + ErrorManager::GetInstance().ATCReportErrMessage("E10017"); GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size."); return false; } @@ -82,14 +61,14 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> } if (size == 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10031"); + ErrorManager::GetInstance().ATCReportErrMessage("E10043"); GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size."); return false; } for (char c : dynamic_batch_size) { if (!isdigit(c) && (c != 
',') && (c != ' ')) { - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"value"}, {dynamic_batch_size}); + ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"value"}, {dynamic_batch_size}); GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid.", dynamic_batch_size.c_str()); return false; @@ -190,7 +169,7 @@ Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_siz vector>> user_shape_map; is_dynamic_input = true; if (input_shape.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"input_shape"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"input_shape"}); GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic batchsize scenario."); return ge::PARAM_INVALID; } @@ -221,19 +200,21 @@ bool ParseInputShape(const string &input_shape, unordered_map shape_vec = StringUtils::Split(input_shape, ';'); const int DEFAULT_SHAPE_PAIR_SIZE = 2; for (const auto &shape : shape_vec) { - vector shape_pair_vec = SplitInputShape(shape); + vector shape_pair_vec = StringUtils::Split(shape, ':'); if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, - {shape, kSplitError1, kInputShapeSample1}); - GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", - shape.c_str(), kSplitError1, kInputShapeSample1); + ErrorManager::GetInstance().ATCReportErrMessage("E10010", {"shape"}, {shape}); + GELOGW( + "Input parameter[--input_shape]’s shape is [%s], " + "correct sample is input_name1:n1,c1,h1,w1", + shape.c_str()); return false; } if (shape_pair_vec[1].empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, - {shape, kEmptyError, kInputShapeSample1}); - GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", - 
shape.c_str(), kEmptyError, kInputShapeSample1); + ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape"}, {shape}); + GELOGW( + "Input parameter[--input_shape]’s shape is [%s], can not empty, " + "correct sample is input_name1:n1,c1,h1,w1", + shape.c_str()); return false; } @@ -242,48 +223,34 @@ bool ParseInputShape(const string &input_shape, unordered_map caffe_support_input_format = {"NCHW", "ND"}; static std::set tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; static std::set onnx_support_input_format = {"NCHW", "ND"}; -static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; -static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; -static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; static std::map input_format_str_to_geformat = { {"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index a64591da..0be75b51 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -296,6 +296,7 @@ graphStatus Impl::BuildModel(const Graph &graph, const std::map(model.data.get()), static_cast(model.length)); } + } // namespace ge diff --git a/src/ge/model/ge_model.h b/src/ge/model/ge_model.h index be4b65bc..6305211a 100644 --- a/src/ge/model/ge_model.h +++ b/src/ge/model/ge_model.h @@ -87,6 +87,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder uint8_t platform_type_ = {0}; uint32_t model_id_ = INVALID_MODEL_ID; }; -} // namespace ge +}; // namespace ge using GeModelPtr = std::shared_ptr; #endif // GE_MODEL_GE_MODEL_H_ diff --git a/src/ge/offline/main.cc b/src/ge/offline/main.cc index fad4134c..61a843c3 100644 --- a/src/ge/offline/main.cc +++ b/src/ge/offline/main.cc @@ -66,10 +66,6 @@ static bool is_dynamic_input = false; // 310 limited 8G size const char 
*const kGraphMemoryManagerMallocMaxSize = "8*1024*1024*1024"; -const char *const kModeSupport = - "only support 0(model to framework model), " - "1(framework model to json), 3(only pre-check), 5(pbtxt to json)"; -const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow)"; DEFINE_string(model, "", "The model file."); DEFINE_string(output, "", "The output file path&name."); @@ -142,6 +138,10 @@ DEFINE_string(optypelist_for_implmode, "", "Optional; Nodes need use implmode selected in op_select_implmode " "Format:\"node_name1,node_name2\""); +DEFINE_string(head_stream, "0", + "Optional; Is need head stream, default is not need." + "Format: \"0: no head stream; 1: add head stream;\""); + DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); @@ -173,8 +173,7 @@ DEFINE_string(dynamic_image_size, "", DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); -DEFINE_string(enable_compress_weight, "false", - "Optional; enable compress weight. true: enable; false(default): disable"); +DEFINE_bool(enable_compress_weight, false, "Optional; enable compress weight. true: enable; false(default): disable"); DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); @@ -184,10 +183,6 @@ DEFINE_string(log, "default", "Optional; generate atc log. Support debug, info, DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); -DEFINE_int32(op_debug_level, 0, - "Optional; configure debug level of compiler. 0(default): close debug;" - "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); - class GFlagUtils { public: /** @@ -207,10 +202,7 @@ class GFlagUtils { "arguments explain:\n" " --model Model file\n" " --singleop Single op definition file. 
atc will generate offline " - "model(s) for single op if --singleop is set. \n" - " Note: Only output, soc_verion, core_type, aicore_num, auto_tune_mode, precision_mode, " - "op_select_implmode, enable_small_channel, enable_compress_weight, compress_weight_conf " - "enable_single_stream and log are valid in this mode \n" + "model(s) for single op if --singleop is set.\n" " --weight Weight file. Required when framework is Caffe\n" " --framework Framework type(0:Caffe; 1:MindSpore; 3:Tensorflow)\n" " --output Output file path&name(needn't suffix, will add " @@ -240,7 +232,7 @@ class GFlagUtils { "\"check_result.json\"\n" " --disable_reuse_memory The switch of reuse memory. Default value is : 0." "0 means reuse memory, 1 means do not reuse memory.\n" - " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons " + " --input_fp16_nodes Input node datatype is fp16 and format is NCHW. Separate multiple nodes with semicolons " "(;)." "Use double quotation marks (\") to enclose each argument." "E.g.: \"node_name1;node_name2\"\n" @@ -260,6 +252,7 @@ class GFlagUtils { " --optypelist_for_implmode Appoint which op to use op_select_implmode, used with op_select_implmode ." "Separate multiple nodes with commas (,). Use double quotation marks (\") to enclose each argument." "E.g.: \"node_name1,node_name2\"\n" + " --head_stream Add head stream. 0(default): disable; 1: enable\n" " --soc_version The soc version. E.g.: \"Ascend310\"\n" " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. 
" "Default value is: AiCore\n" @@ -287,7 +280,7 @@ class GFlagUtils { static Status CheckDumpInfershapeJsonFlags() { Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "weight"), return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", FLAGS_weight.c_str()); return domi::SUCCESS; @@ -296,7 +289,7 @@ class GFlagUtils { static Status CheckFlags() { // No model file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_model == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"model"}); return domi::PARAM_INVALID, "Input parameter[--model]'s value is empty!"); // check param disable_reuse_memory GE_CHK_BOOL_EXEC(ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) == ge::SUCCESS, @@ -308,7 +301,7 @@ class GFlagUtils { return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); // No output file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); return domi::PARAM_INVALID, "Input parameter[--output]'s value is empty!"); Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); @@ -327,16 +320,16 @@ class GFlagUtils { GELOGI("domi will run with encrypt!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), return domi::FAILED, - "encrypt_key file not found!!"); + "encrypt_key file %s not found!!", 
FLAGS_encrypt_key.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), return domi::FAILED, - "certificate file not found!!"); + "certificate file %s not found!!", FLAGS_certificate.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), return domi::FAILED, - "hardware_key file not found!!"); + "hardware_key file %s not found!!", FLAGS_hardware_key.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), return domi::FAILED, - "private_key file not found!!"); + "private_key file %s not found!!", FLAGS_private_key.c_str()); } else { // No encryption GELOGI("domi will run without encrypt!"); } @@ -345,37 +338,43 @@ class GFlagUtils { /** * Check the validity of the I / O file path */ - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "--model"), return domi::FAILED, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "model"), return domi::FAILED, "model file %s not found!!", FLAGS_model.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "weight"), return domi::FAILED, "weight file %s not found!!", FLAGS_weight.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "cal_conf"), return domi::FAILED, "calibration config file %s not found!!", FLAGS_cal_conf.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), return domi::FAILED, + FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "op_name_map"), return domi::FAILED, "op config file %s not found!!", FLAGS_op_name_map.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + 
FLAGS_head_stream != "" && FLAGS_head_stream != "0" && FLAGS_head_stream != "1", + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"head_stream"}); + return domi::FAILED, "Input parameter[--head_stream] must be 0 or 1!!"); + GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, return ge::FAILED, "check insert op conf failed!"); GE_CHK_BOOL_EXEC( - ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, + ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight ? std::string("true") : std::string("false"), + FLAGS_compress_weight_conf) == ge::SUCCESS, return ge::FAILED, "check compress weight failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), return domi::FAILED, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "check_report"), return domi::FAILED, "check_report file %s not found!!", FLAGS_check_report.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output, "--output") || - !CheckPathWithName(FLAGS_output)), - return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output) || !CheckPathWithName(FLAGS_output)), + return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( FLAGS_save_original_model != "" && FLAGS_save_original_model != "true" && FLAGS_save_original_model != "false", - ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model}); return domi::FAILED, "Input parameter[--save_original_model]'s value[%s] must be true or false.", FLAGS_save_original_model.c_str()); 
@@ -396,18 +395,18 @@ class GFlagUtils { static Status CheckConverJsonParamFlags() { // No model path passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"om"}); return domi::PARAM_INVALID, "Input parameter[--om]'s value is empty!!"); // JSON path not passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"json"}); return domi::PARAM_INVALID, "Input parameter[--json]'s value is empty!!"); // Check if the model path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "--om"), return domi::PARAM_INVALID, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "om"), return domi::PARAM_INVALID, "model file path is invalid: %s.", FLAGS_om.c_str()); // Check whether the JSON path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "--json"), return domi::PARAM_INVALID, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "om"), return domi::PARAM_INVALID, "json file path is invalid: %s.", FLAGS_json.c_str()); return domi::SUCCESS; @@ -444,8 +443,7 @@ class GFlagUtils { if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { // No framework information was passed in or the entered framework is illegal - ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter", "support"}, - {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter"}, {"framework"}); DOMI_LOGE( "Input parameter[--framework] is mandatory and it's value must be: " "0(Caffe) or 1(MindSpore) or 3(TensorFlow)."); 
@@ -518,29 +516,31 @@ static bool CheckInputFormat() { if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { return true; } + ErrorManager::GetInstance().ATCReportErrMessage("E10031", {"value"}, {FLAGS_input_format}); // only support NCHW ND - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport}); - GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), - ge::kCaffeFormatSupport); + GELOGE(ge::FAILED, + "Input parameter[--input_format]'s value[%s] is wrong, " + "only support NCHW, ND in Caffe model.", + FLAGS_input_format.c_str()); return false; } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { return true; } + ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"value"}, {FLAGS_input_format}); // only support NCHW NHWC ND NCDHW NDHWC - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--input_format", FLAGS_input_format, ge::kTFFormatSupport}); - GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport); + GELOGE(ge::FAILED, + "Input parameter[--input_format]'s value[%s] is wrong, " + "only support NCHW, NHWC, ND, NCDHW, NDHWC in tf model", + FLAGS_input_format.c_str()); return false; } else if (FLAGS_framework == static_cast(domi::ONNX)) { if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { return true; } // only support NCHW ND - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport}); - GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport); + 
GELOGE(ge::FAILED, "Input parameter[--input_format]'s value[%s] is error, Only support NCHW, ND in onnx model", + FLAGS_input_format.c_str()); return false; } return true; @@ -622,7 +622,8 @@ void LoadModelParserLib(std::string caffe_parser_path) { return; } -void LoadCustomOpLib(bool need_load_ops_plugin) { +void LoadCustomOpLib() { + OpRegistry::Instance()->registrationDatas.clear(); std::string plugin_path; GetCustomOpPath(plugin_path); @@ -638,11 +639,7 @@ void LoadCustomOpLib(bool need_load_ops_plugin) { } LoadModelParserLib(caffe_parser_path); - if (!need_load_ops_plugin) { - GELOGI("No need to load ops plugin so."); - return; - } - OpRegistry::Instance()->registrationDatas.clear(); + // load other so files except lib_caffe_parser.so in the plugin so path for (auto elem : fileList) { ge::StringUtils::Trim(elem); @@ -657,21 +654,17 @@ void LoadCustomOpLib(bool need_load_ops_plugin) { std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; for (OpRegistrationData reg_data : registrationDatas) { - (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); - (void)OpRegistry::Instance()->Register(reg_data); + bool ret = ge::OpRegistrationTbe::Instance()->Finalize(reg_data); + if (ret) { + OpRegistry::Instance()->Register(reg_data); + } } } void SaveCustomCaffeProtoPath() { GELOGI("Enter save custom caffe proto path."); - - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); - ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; - string customop_path; + const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { std::string path = path_env; @@ -680,6 +673,10 @@ void SaveCustomCaffeProtoPath() { ge::GetParserContext().custom_proto_path = customop_path; return; } + std::string path_base = ge::GELib::GetPath(); + GELOGI("path_base is %s", 
path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); customop_path = path_base + "ops/framework/custom/caffe/"; ge::GetParserContext().custom_proto_path = customop_path; return; @@ -723,6 +720,15 @@ Status CreateInputsForInference(const ge::Graph &graph, vector &in return ge::SUCCESS; } +void ChangeStringToBool(std::string &arg_s, bool arg_b) { + if (arg_s == "true") { + arg_b = true; + } else { + arg_b = false; + } + return; +} + domi::Status GenerateInfershapeJson() { if (!CheckInputFormat()) { GELOGE(ge::FAILED, "Check input_format failed"); @@ -731,6 +737,8 @@ domi::Status GenerateInfershapeJson() { Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); + // Load custom operator Library + LoadCustomOpLib(); ge::GeGenerator ge_generator; std::map options; ge::Status geRet = ge_generator.Initialize(options); @@ -772,25 +780,24 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s return ret; } - if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); - GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); + if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10068", {"param", "value", "supports"}, + {"framework", std::to_string(fwk_type), "only support 0(Caffe) 3(TensorFlow)"}); + GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow)."); return ge::FAILED; } + // Since the Caffe model's conversion to JSON file depends on lib_caffe_parser.so, loadcustomoplib is called here. 
+ LoadCustomOpLib(); + if (FLAGS_dump_mode == "0") { - // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_tensorflow_parser.so. - LoadCustomOpLib(false); ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); return ret; } else if (FLAGS_dump_mode == "1") { - // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_tensorflow_parser.so and ops plugin so. - LoadCustomOpLib(true); ret = GenerateInfershapeJson(); return ret; } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); return ge::FAILED; } @@ -821,7 +828,7 @@ domi::Status GenerateModel(std::map &options, std::string output ge::Model load_model = ge::Model("loadmodel", "version2"); auto ret1 = load_model.LoadFromFile(FLAGS_model); if (ret1 != ge::GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); + ErrorManager::GetInstance().ATCReportErrMessage("E10056", {"parameter"}, {FLAGS_model}); DOMI_LOGE( "Load model from %s failed, please check model file or " "input parameter[--framework] is correct", @@ -924,11 +931,10 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize); - options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); } domi::Status GenerateSingleOp(const std::string &json_file_path) { - if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { + if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output)) { DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str()); return domi::FAILED; } @@ -994,7 +1000,7 @@ domi::Status GenerateOmModel() { "quotation marks (\") to enclose each argument 
such as out_nodes, input_shape, dynamic_image_size"); #if !defined(__ANDROID__) && !defined(ANDROID) // Load custom operator Library - LoadCustomOpLib(true); + LoadCustomOpLib(); SaveCustomCaffeProtoPath(); @@ -1032,6 +1038,8 @@ domi::Status GenerateOmModel() { options.insert(std::pair(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes)); } + options.insert(std::pair(string(ge::HEAD_STREAM), FLAGS_head_stream)); + options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); options.insert( @@ -1049,7 +1057,7 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); - options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), (FLAGS_enable_compress_weight == "true") + options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), FLAGS_enable_compress_weight ? ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); @@ -1064,8 +1072,6 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); } - options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); - // print atc option map ge::PrintOptionMap(options, "atc option"); @@ -1089,8 +1095,8 @@ domi::Status ConvertModelToJson() { return domi::SUCCESS; } -bool CheckRet(domi::Status ret) { - if (ret != domi::SUCCESS) { +bool CheckRet(domi::Status ret, ge::Status geRet) { + if (ret != domi::SUCCESS || geRet != ge::SUCCESS) { if (FLAGS_mode == ONLY_PRE_CHECK) { GELOGW("ATC precheck failed."); } else if (FLAGS_mode == GEN_OM_MODEL) { @@ -1139,7 +1145,7 @@ int init(int argc, char *argv[]) { int ret = -1; const std::set log_level = {"default", "null", "debug", "info", "warning", "error"}; if (log_level.count(FLAGS_log) == 0) { - std::cout << "E10010: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, error, null" + std::cout << "E10016: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, 
error, null" << std::endl; return ret; } @@ -1149,18 +1155,12 @@ int init(int argc, char *argv[]) { return ret; } - std::string path_base = ge::GELib::GetPath(); - ret = ErrorManager::GetInstance().Init(path_base); - if (ret != 0) { - DOMI_LOGE("ErrorManager init fail !"); - return ret; - } - return 0; } int main(int argc, char *argv[]) { Status ret = domi::SUCCESS; + ge::Status geRet = ge::SUCCESS; std::cout << "ATC start working now, please wait for a moment." << std::endl; try { // Initialize @@ -1185,9 +1185,12 @@ int main(int argc, char *argv[]) { GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; break, "ATC convert pbtxt to json execute failed!!"); } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--mode", std::to_string(FLAGS_mode), kModeSupport}); - GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); + ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"value"}, {std::to_string(FLAGS_mode)}); + DOMI_LOGE( + "Invalid value for --mode[%d], only support " + "0(model to framework model), 1(framework model to json), 3(only pre-check), " + "5(pbtxt to json)!", + FLAGS_mode); ret = domi::FAILED; break; } @@ -1202,12 +1205,8 @@ int main(int argc, char *argv[]) { std::cout << "ATC run failed, some exceptions occur !" << std::endl; } - if (!CheckRet(ret)) { + if (!CheckRet(ret, geRet)) { std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; - int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); - if (result != 0) { - DOMI_LOGE("ErrorManager outputErrMessage fail !"); - } return ret; } else { std::cout << "ATC run success, welcome to the next use." 
<< std::endl; diff --git a/src/ge/offline/single_op_parser.cc b/src/ge/offline/single_op_parser.cc index b8947a65..4d589565 100644 --- a/src/ge/offline/single_op_parser.cc +++ b/src/ge/offline/single_op_parser.cc @@ -200,13 +200,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { for (auto &tensor_desc : op_desc.input_desc) { if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input's dataType is invalid when the index is %d", index); + GELOGE(false, "Input index[%d]'s dataType is invalid", index); return false; } if (tensor_desc.format == FORMAT_RESERVED) { ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); + GELOGE(PARAM_INVALID, "Input index[%d]'s format is invalid", index); return false; } ++index; @@ -216,13 +216,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { for (auto &tensor_desc : op_desc.output_desc) { if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); + GELOGE(PARAM_INVALID, "Output[%d] dataType is invalid", index); return false; } if (tensor_desc.format == FORMAT_RESERVED) { ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); + GELOGE(PARAM_INVALID, "Output[%d] format is invalid", index); return false; } ++index; @@ -316,15 +316,17 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector &options_const) { GetExternalEnginePath(extern_engine_path); GELOGI("OPTION_EXEC_EXTERN_PLUGIN_PATH=%s.", 
extern_engine_path.c_str()); - op_tiling_manager_.LoadSo(); - ret = plugin_manager_.LoadSo(extern_engine_path, func_check_list); if (ret == SUCCESS) { initialize_ = options; diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h index 1d464201..8d98ad3f 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.h +++ b/src/ge/opskernel_manager/ops_kernel_manager.h @@ -24,7 +24,6 @@ #include "common/debug/log.h" #include "common/ge/plugin_manager.h" -#include "common/ge/op_tiling_manager.h" #include "common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_store.h" #include "common/optimizer/graph_optimizer.h" @@ -106,7 +105,6 @@ class OpsKernelManager { Status InitGraphOptimizerPriority(); PluginManager plugin_manager_; - OpTilingManager op_tiling_manager_; // opsKernelInfoStore map ops_kernel_store_{}; // graph_optimizer diff --git a/src/ge/session/inner_session.cc b/src/ge/session/inner_session.cc index b97862e1..74495e82 100644 --- a/src/ge/session/inner_session.cc +++ b/src/ge/session/inner_session.cc @@ -29,34 +29,6 @@ #include "runtime/mem.h" namespace ge { -namespace { -Status CheckReuseMemoryOption(const std::map &options) { - const int kDecimal = 10; - auto dump_op_env = std::getenv("DUMP_OP"); - int dump_op_flag = (dump_op_env != nullptr) ? 
std::strtol(dump_op_env, nullptr, kDecimal) : 0; - auto iter = options.find(OPTION_EXEC_DISABLE_REUSED_MEMORY); - if (iter != options.end()) { - if (iter->second == "0") { - GELOGD("%s=0, reuse memory is open", OPTION_EXEC_DISABLE_REUSED_MEMORY); - if (dump_op_flag) { - GELOGW("Will dump incorrect op data with ge option %s=0", OPTION_EXEC_DISABLE_REUSED_MEMORY); - } - } else if (iter->second == "1") { - GELOGD("%s=1, reuse memory is close", OPTION_EXEC_DISABLE_REUSED_MEMORY); - } else { - GELOGE(PARAM_INVALID, "option %s=%s is invalid", OPTION_EXEC_DISABLE_REUSED_MEMORY, iter->second.c_str()); - return FAILED; - } - } else { - if (dump_op_flag) { - GELOGW("Will dump incorrect op data with default reuse memory"); - } - } - - return SUCCESS; -} -} // namespace - static std::mutex mutex_; // BuildGraph and RunGraph use InnerSession::InnerSession(uint64_t session_id, const std::map &options) @@ -67,36 +39,13 @@ Status InnerSession::Initialize() { GELOGW("[InnerSession:%lu] session already initialize.", session_id_); return SUCCESS; } - - // If the global options and the session options are duplicated, the session options is preferred. 
- auto all_options = options_; - all_options.insert(GetMutableGlobalOptions().begin(), GetMutableGlobalOptions().end()); - - Status ret = CheckReuseMemoryOption(all_options); - if (ret != SUCCESS) { - GELOGE(ret, "[InnerSession:%lu] check reuse memory option failed.", session_id_); - return ret; - } - UpdateThreadContext(std::map{}); GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); - PropertiesManager::Instance().GetDumpProperties(session_id_).InitByOptions(); - - ret = graph_manager_.Initialize(options_); + Status ret = graph_manager_.Initialize(options_); if (ret != SUCCESS) { GELOGE(ret, "[InnerSession:%lu] initialize failed.", session_id_); - PropertiesManager::Instance().RemoveDumpProperties(session_id_); - return ret; - } - - ret = VarManager::Instance(session_id_)->SetMemoryMallocSize(all_options); - if (ret != SUCCESS) { - GELOGE(ret, "failed to set malloc size"); - (void)graph_manager_.Finalize(); - PropertiesManager::Instance().RemoveDumpProperties(session_id_); - GE_CHK_RT(rtDeviceReset(static_cast(GetContext().DeviceId()))); return ret; } @@ -106,7 +55,6 @@ Status InnerSession::Initialize() { ret = VarManager::Instance(session_id_)->Init(version, session_id_, DEFAULT_DEVICE_ID, DEFAULT_JOB_ID); if (ret != SUCCESS) { GELOGE(ret, "failed to init session instance"); - PropertiesManager::Instance().RemoveDumpProperties(session_id_); } init_flag_ = true; return SUCCESS; @@ -130,9 +78,6 @@ Status InnerSession::Finalize() { // release var memory GELOGI("VarManager free var memory."); (void)VarManager::Instance(session_id_)->FreeVarMemory(); - - PropertiesManager::Instance().RemoveDumpProperties(session_id_); - GE_CHK_RT(rtDeviceReset(static_cast(GetContext().DeviceId()))); return ret; @@ -278,7 +223,6 @@ void InnerSession::UpdateThreadContext(const std::map GetThreadLocalContext().SetGlobalOption(GetMutableGlobalOptions()); GetThreadLocalContext().SetSessionOption(options_); GetThreadLocalContext().SetGraphOption(options); - 
GetContext().SetSessionId(session_id_); } void InnerSession::UpdateThreadContext(uint32_t graph_id) { diff --git a/src/ge/session/omg.cc b/src/ge/session/omg.cc index 26103063..71dd631e 100644 --- a/src/ge/session/omg.cc +++ b/src/ge/session/omg.cc @@ -65,9 +65,6 @@ namespace ge { namespace { const std::string kGraphDefaultName = "domi_default"; const std::string kScopeIdAttr = "fusion_scope"; -const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; -const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; -const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; } // namespace // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored @@ -81,7 +78,7 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p if ((s == "true") || (s == "false")) { return true; } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s}); + ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s}); GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); return false; } @@ -100,12 +97,12 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph) { std::string node_name = it.first; ge::NodePtr node = graph->FindNode(node_name); if (node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); + ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"parameter", "opname"}, {"input_shape", node_name}); GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str()); return PARAM_INVALID; } if (node->GetType() != DATA) { - ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); + 
ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"parameter", "opname"}, {"input_shape", node_name}); GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str()); return PARAM_INVALID; } @@ -136,19 +133,18 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) { ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]); if (node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"parameter", "opname"}, {"input_fp16_nodes", input_fp16_nodes_vec[i]}); - GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model", + GELOGE(PARAM_INVALID, "Can not find node [%s] in graph, please check input_fp16_nodes param", input_fp16_nodes_vec[i].c_str()); return PARAM_INVALID; } auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); if (op_desc->GetType() != DATA) { - ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"parameter", "opname"}, {"input_fp16_nodes", input_fp16_nodes_vec[i]}); - GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname", - input_fp16_nodes_vec[i].c_str()); + GELOGE(PARAM_INVALID, "input_fp16_nodes: %s is not a input node name", input_fp16_nodes_vec[i].c_str()); return PARAM_INVALID; } if (ge::AttrUtils::SetBool(op_desc, "input_fp16", true)) { @@ -306,32 +302,14 @@ Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format forma return domi::SUCCESS; } -bool CheckDigitStr(std::string &str) { - for (char c : str) { - if (!isdigit(c)) { - GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); - return false; - } - } - return true; -} - Status StringToInt(std::string &str, int32_t &value) { try { - if 
(!CheckDigitStr(str)) { - GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", str, "is not positive integer"}); - return PARAM_INVALID; - } value = stoi(str); } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", str.c_str()); return PARAM_INVALID; } catch (std::out_of_range &) { - GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", str.c_str()); return PARAM_INVALID; } return SUCCESS; @@ -347,9 +325,8 @@ Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { } for (uint32_t i = 0; i < out_type_vec.size(); ++i) { if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", out_type_vec[i], kOutputTypeError}); - GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); + ErrorManager::GetInstance().ATCReportErrMessage("E10059", {"value"}, {out_type_vec[i]}); + GELOGE(domi::FAILED, "Can not find this node (%s) in out_nodes.", out_type_vec[i].c_str()); return domi::FAILED; } } @@ -362,9 +339,9 @@ Status ParseOutputType(const std::string &output_type, std::map node_index_type_v = StringUtils::Split(node, ':'); if (node_index_type_v.size() != 3) { // The size must be 3. 
- ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--output_type", node, kOutputTypeSample}); - GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); + ErrorManager::GetInstance().ATCReportErrMessage("E10058", {"value"}, {node}); + GELOGE(PARAM_INVALID, + "The param of output_type is invalid, the correct format is [opname:index:dtype]," + "while the actual input is %s.", + node.c_str()); return domi::FAILED; } ge::DataType tmp_dt; @@ -384,15 +363,13 @@ Status ParseOutputType(const std::string &output_type, std::mapsecond; @@ -419,22 +396,6 @@ Status ParseOutputType(const std::string &output_type, std::mapGetOutputsSize(); - if (index < 0 || index >= out_size) { - GELOGE(domi::FAILED, - "out_node [%s] output index:%d must be smaller " - "than node output size:%d and can not be negative!", - op_desc->GetName().c_str(), index, out_size); - std::string fail_reason = "output index:" + to_string(index) + - " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; - ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, - {"out_nodes", op_desc->GetName(), fail_reason}); - return domi::FAILED; - } - return domi::SUCCESS; -} - Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output) { ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); @@ -443,6 +404,7 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::vector output_formats = domi::GetContext().output_formats; std::vector> output_nodes_info; std::vector output_nodes_name; + std::map> out_type_index_map; std::map> out_type_dt_map; if (!output_type.empty()) { @@ -461,10 +423,6 @@ Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const } auto op_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - if 
(CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { - GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); - return domi::FAILED; - } if (i < output_formats.size()) { if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); @@ -603,9 +561,8 @@ Status ParseOutNodes(const string &out_nodes) { for (const string &node : nodes_v) { vector key_value_v = StringUtils::Split(node, ':'); if (key_value_v.size() != 2) { // The size must be 2. - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, - {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); + ErrorManager::GetInstance().ATCReportErrMessage("E10069", {"param", "value", "supports"}, + {"out_nodes", node, "opname:index"}); GELOGE(PARAM_INVALID, "The input format of --out_nodes is invalid, the correct format is " "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", @@ -614,12 +571,6 @@ Status ParseOutNodes(const string &out_nodes) { } auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); // stoi: The method may throw an exception: invalid_argument/out_of_range - if (!CheckDigitStr(key_value_v[1])) { - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, - {"--out_nodes", out_nodes, "is not positive integer"}); - GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); - return PARAM_INVALID; - } int32_t index = stoi(StringUtils::Trim(key_value_v[1])); if (iter != domi::GetContext().out_nodes_map.end()) { iter->second.emplace_back(index); @@ -633,11 +584,9 @@ Status ParseOutNodes(const string &out_nodes) { } } catch (std::invalid_argument &) { GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, 
{"out_nodes", out_nodes}); return PARAM_INVALID; } catch (std::out_of_range &) { GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes}); return PARAM_INVALID; } @@ -649,7 +598,7 @@ Status ParseOutNodes(const string &out_nodes) { /// @param [in] graph Input network graph /// @return SUCCESS: Input parameters are correct; PARAM_INVALID: Input parameters are incorrect /// -static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { +static Status CheckOpNameMap(const ComputeGraphPtr &graph) { GE_CHECK_NOTNULL(graph); unordered_map graphNodeTypes; for (const NodePtr &node : graph->GetAllNodes()) { @@ -664,9 +613,7 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(propertiesMap.empty(), "op_name_map file is empty, please check file!"); for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), - ErrorManager::GetInstance().ATCReportErrMessage( - "E10003", {"parameter", "value", "reason"}, - {"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10060", {"parameter"}, {"op_name_map"}); GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;); } return SUCCESS; @@ -723,8 +670,7 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::mapCreateModelParser(framework); GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, framework:%d.", framework); return model_parser->ToJson(model_file, json_file); } - ErrorManager::GetInstance().ATCReportErrMessage( - "E10001", {"parameter", "value", "reason"}, - {"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow)"}); + 
ErrorManager::GetInstance().ATCReportErrMessage("E10045", {"parameter"}, {"model"}); GELOGE(PARAM_INVALID, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow)."); return PARAM_INVALID; } diff --git a/src/ge/session/session_manager.cc b/src/ge/session/session_manager.cc index 68a8aa70..c3439b0b 100644 --- a/src/ge/session/session_manager.cc +++ b/src/ge/session/session_manager.cc @@ -51,11 +51,11 @@ Status SessionManager::Finalize() { return SUCCESS; } -Status SessionManager::SetRtContext(SessionId session_id, rtContext_t rt_context) { +Status SessionManager::SetrtContext(rtContext_t rt_context) { GELOGI("set rt_context RT_CTX_NORMAL_MODE, device id:%u.", GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, RT_CTX_NORMAL_MODE, static_cast(GetContext().DeviceId()))); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddRtContext(session_id, rt_context); + RtContextUtil::GetInstance().AddrtContext(rt_context); return SUCCESS; } @@ -85,7 +85,7 @@ Status SessionManager::CreateSession(const std::map &o session_id = next_session_id; // create a context - ret = SetRtContext(session_id, rtContext_t()); + ret = SetrtContext(rtContext_t()); return ret; } @@ -106,7 +106,7 @@ Status SessionManager::DestroySession(SessionId session_id) { } // Unified destruct rt_context - RtContextUtil::GetInstance().DestroyRtContexts(session_id); + RtContextUtil::GetInstance().DestroyrtContexts(); SessionPtr innerSession = it->second; Status ret = innerSession->Finalize(); @@ -300,4 +300,4 @@ bool SessionManager::IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id) } return innerSession->IsGraphNeedRebuild(graph_id); } -} // namespace ge +}; // namespace ge diff --git a/src/ge/session/session_manager.h b/src/ge/session/session_manager.h index 5cdb849f..5cce5214 100644 --- a/src/ge/session/session_manager.h +++ b/src/ge/session/session_manager.h @@ -33,6 +33,7 @@ class SessionManager { friend class GELib; public: + 
Status SetrtContext(rtContext_t rtContext); /// /// @ingroup ge_session /// @brief create session @@ -162,12 +163,10 @@ class SessionManager { Status GetNextSessionId(SessionId &next_session_id); - Status SetRtContext(SessionId session_id, rtContext_t rtContext); - std::map session_manager_map_; std::mutex mutex_; bool init_flag_ = false; }; -} // namespace ge +}; // namespace ge #endif // GE_SESSION_SESSION_MANAGER_H_ diff --git a/src/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc index e2d756df..9578471a 100644 --- a/src/ge/single_op/single_op.cc +++ b/src/ge/single_op/single_op.cc @@ -50,13 +50,9 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_inputs; ++i) { // preventing from read out of bound size_t aligned_size = GetAlignedSize(inputs[i].length); - GELOGI("Input [%zu], aligned_size:%zu, inputs.length:%u, input_sizes_:%u", i, aligned_size, inputs[i].length, - input_sizes_[i]); if (aligned_size < input_sizes_[i]) { - GELOGE(PARAM_INVALID, - "Input size mismatch. index = %zu, model expect %zu," - " but given %zu(after align)", - i, input_sizes_[i], aligned_size); + GELOGE(PARAM_INVALID, "Input size mismatch. index = %zu, model expect %zu, but given %zu(after align)", i, + input_sizes_[i], aligned_size); return PARAM_INVALID; } } @@ -70,13 +66,9 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_outputs; ++i) { // preventing from write out of bound size_t aligned_size = GetAlignedSize(outputs[i].length); - GELOGI("Output [%zu], aligned_size:%zu, outputs.length:%u, output_sizes_:%u", i, aligned_size, outputs[i].length, - output_sizes_[i]); if (aligned_size < output_sizes_[i]) { - GELOGE(PARAM_INVALID, - "Output size mismatch. index = %zu, model expect %zu," - "but given %zu(after align)", - i, output_sizes_[i], aligned_size); + GELOGE(PARAM_INVALID, "Output size mismatch. 
index = %zu, model expect %zu, but given %zu(after align)", i, + output_sizes_[i], aligned_size); return PARAM_INVALID; } } @@ -89,11 +81,23 @@ Status SingleOp::GetArgs(const std::vector &inputs, const std::vecto if (use_physical_addr_) { for (auto &input : inputs) { auto *addr = reinterpret_cast(input.data); + size_t aligned_size = GetAlignedSize(input.length); + auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, aligned_size, addr); + if (ret != SUCCESS) { + GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Arg index = %zu", arg_index); + return ret; + } args_[arg_index++] = reinterpret_cast(addr); } for (auto &output : outputs) { auto *addr = reinterpret_cast(output.data); + size_t aligned_size = GetAlignedSize(output.length); + auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, aligned_size, addr); + if (ret != SUCCESS) { + GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Arg index = %zu", arg_index); + return ret; + } args_[arg_index++] = reinterpret_cast(addr); } } else { @@ -113,7 +117,6 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve if (ret != SUCCESS) { return ret; } - // update tbe task args size_t num_args = arg_table_.size(); for (size_t i = 0; i < num_args; ++i) { std::vector &ptr_to_arg_in_tasks = arg_table_[i]; @@ -126,34 +129,18 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve *arg_addr = args_[i]; } } - // update aicpu_TF or aicpu_CC args for (auto &task : tasks_) { - size_t io_addr_num = args_.size(); if (task->GetOpTaskType() == OP_TASK_AICPU) { - GELOGD("Update aicpu_TF task args"); + GELOGD("Update aicpu task args"); AiCpuTask *task_aicpu = dynamic_cast(task); GE_CHECK_NOTNULL(task_aicpu); - auto *dst_io_addr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); - GE_CHECK_NOTNULL(dst_io_addr); - auto rt_ret = rtMemcpyAsync(dst_io_addr, sizeof(uint64_t) * args_.size(), &args_[0], + auto *dstIOAddr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); + auto rt_ret 
= rtMemcpyAsync(dstIOAddr, sizeof(uint64_t) * args_.size(), &args_[0], sizeof(uint64_t) * args_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpyAsync addresses failed, ret = %d", rt_ret); return RT_FAILED; } - } else if (task->GetOpTaskType() == OP_TASK_AICPUCC) { - GELOGD("Update aicpu_CC task args"); - AiCpuCCTask *task_aicpu_cc = dynamic_cast(task); - GE_CHECK_NOTNULL(task_aicpu_cc); - const uintptr_t *task_io_addr = reinterpret_cast(task_aicpu_cc->GetIOAddr()); - GE_CHECK_NOTNULL(task_io_addr); - auto io_addr = reinterpret_cast(const_cast(task_io_addr)); - for (size_t i = 0; i < io_addr_num; ++i) { - io_addr[i] = reinterpret_cast(args_[i]); - } - } else { - GELOGW("Only TF_kernel aicpu and aicpu_CC are supported, but got %u", task->GetOpTaskType()); - continue; } } return SUCCESS; @@ -177,7 +164,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c return ret; } } - return ret; } diff --git a/src/ge/single_op/single_op_manager.cc b/src/ge/single_op/single_op_manager.cc index 79f3f044..990ca9cc 100644 --- a/src/ge/single_op/single_op_manager.cc +++ b/src/ge/single_op/single_op_manager.cc @@ -41,17 +41,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFr uintptr_t resource_id; // runtime uses NULL to denote a default stream for each device if (stream == nullptr) { - // use device id as resource key instead - int32_t dev_id = 0; - auto rt_err = rtGetDevice(&dev_id); + // get current context + rtContext_t rt_cur_ctx = nullptr; + auto rt_err = rtCtxGetCurrent(&rt_cur_ctx); if (rt_err != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Get current device id failed. ret = %d", static_cast(rt_err)); + GELOGE(RT_FAILED, "get current context failed, runtime result is %d", static_cast(rt_err)); return RT_FAILED; } - - GELOGI("GetOpFromModel with default stream. 
device id = %d", dev_id); - resource_id = static_cast(dev_id); + // use current context as resource key instead + GELOGI("use context as resource key instead when default stream"); + resource_id = reinterpret_cast(rt_cur_ctx); } else { + GELOGI("use stream as resource key instead when create stream"); resource_id = reinterpret_cast(stream); } diff --git a/src/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc index b72a41fc..9decdf75 100644 --- a/src/ge/single_op/single_op_model.cc +++ b/src/ge/single_op/single_op_model.cc @@ -28,7 +28,6 @@ #include "graph/utils/tensor_utils.h" #include "runtime/rt.h" #include "task/aicpu_task_builder.h" -#include "task/aicpu_kernel_task_builder.h" #include "task/tbe_task_builder.h" using domi::TaskDef; @@ -199,6 +198,11 @@ Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { int arg_index = 0; for (size_t i = 0; i < input_offset_list_.size(); ++i) { auto *addr = model_params_.mem_base + input_offset_list_[i]; + auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, input_sizes_[i], addr); + if (ret != SUCCESS) { + GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Input index = %zu", i); + return ret; + } model_params_.addr_mapping_.emplace(reinterpret_cast(addr), arg_index++); single_op.input_sizes_.emplace_back(input_sizes_[i]); single_op.input_addr_list_.emplace_back(addr); @@ -206,6 +210,11 @@ Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { for (size_t i = 0; i < output_offset_list_.size(); ++i) { auto *addr = model_params_.mem_base + output_offset_list_[i]; + auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, output_sizes_[i], addr); + if (ret != SUCCESS) { + GELOGE(ret, "ConvertVirtualAddressToPhysical failed. 
Output index = %zu", i); + return ret; + } model_params_.addr_mapping_.emplace(reinterpret_cast(addr), arg_index++); single_op.output_sizes_.emplace_back(output_sizes_[i]); single_op.output_addr_list_.emplace_back(addr); @@ -225,31 +234,16 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { task_def.DebugString().c_str()); auto task_type = static_cast(task_def.type()); if (task_type == RT_MODEL_TASK_KERNEL) { - const domi::KernelDef &kernel_def = task_def.kernel(); - const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type == cce::ccKernelType::TE) { - GELOGD("Building TBE task"); - OpTask *task = nullptr; - auto ret = BuildKernelTask(task_def.kernel(), single_op, &task); - if (ret != SUCCESS) { - return ret; - } - single_op.tasks_.emplace_back(task); - } else if (kernel_type == cce::ccKernelType::AI_CPU) { - GELOGD("Building AICPU_CC task"); - OpTask *task = nullptr; - auto ret = BuildCpuKernelTask(task_def.kernel(), &task); - if (ret != SUCCESS) { - return ret; - } - single_op.tasks_.emplace_back(task); - } else { - GELOGE(UNSUPPORTED, "Only TBE kernel and AI_CPU kernek are supported, but got %u", context.kernel_type()); - return UNSUPPORTED; + GELOGD("Building TBE task"); + OpTask *task = nullptr; + auto ret = BuildKernelTask(task_def.kernel(), single_op, &task); + if (ret != SUCCESS) { + return ret; } + + single_op.tasks_.emplace_back(task); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { - GELOGD("Building AICPU_TF task"); + GELOGD("Building AICPU task"); OpTask *task = nullptr; auto ret = BuildKernelExTask(task_def.kernel_ex(), single_op, &task); if (ret != SUCCESS) { @@ -287,6 +281,12 @@ void SingleOpModel::ParseArgTable(TbeOpTask *task, SingleOp &op) { Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task) { GE_CHECK_NOTNULL(task); const auto &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); 
+ if (kernel_type != cce::ccKernelType::TE) { + GELOGE(UNSUPPORTED, "Only TBE kernel is supported, but got %u", context.kernel_type()); + return UNSUPPORTED; + } + auto iter = op_list_.find(context.op_index()); if (iter == op_list_.end()) { GELOGE(INTERNAL_ERROR, "op desc not found. op index = %u", context.op_index()); @@ -323,13 +323,13 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin std::unique_ptr aicpu_task(new (std::nothrow) AiCpuTask()); if (aicpu_task == nullptr) { - GELOGE(MEMALLOC_FAILED, "create aicpu_TF op task failed"); + GELOGE(MEMALLOC_FAILED, "create aicpu op task failed"); return MEMALLOC_FAILED; } auto builder = AiCpuTaskBuilder(iter->second, kernel_def); auto ret = builder.BuildTask(*aicpu_task, model_params_); if (ret != SUCCESS) { - GELOGE(ret, "build aicpu_TF op task failed"); + GELOGE(ret, "build aicpu op task failed"); return ret; } @@ -337,24 +337,6 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin return SUCCESS; } -Status SingleOpModel::BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task) { - std::unique_ptr aicpucc_task(new (std::nothrow) AiCpuCCTask()); - if (aicpucc_task == nullptr) { - GELOGE(MEMALLOC_FAILED, "create aicpu_CC op task failed"); - return MEMALLOC_FAILED; - } - - auto builder = AiCpuCCTaskBuilder(kernel_def); - auto ret = builder.BuildTask(*aicpucc_task); - if (ret != SUCCESS) { - GELOGE(ret, "build aicpu_CC op task failed"); - return ret; - } - - *task = aicpucc_task.release(); - return SUCCESS; -} - Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { auto ret = InitModelMem(resource); if (ret != SUCCESS) { diff --git a/src/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h index 3b8c2616..4d8aae30 100644 --- a/src/ge/single_op/single_op_model.h +++ b/src/ge/single_op/single_op_model.h @@ -64,7 +64,6 @@ class SingleOpModel { Status BuildTaskList(SingleOp &single_op); Status 
BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, SingleOp &single_op, OpTask **task); - Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(TbeOpTask *task, SingleOp &op); diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.cc b/src/ge/single_op/task/aicpu_kernel_task_builder.cc deleted file mode 100644 index 936c7b67..00000000 --- a/src/ge/single_op/task/aicpu_kernel_task_builder.cc +++ /dev/null @@ -1,56 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#include "single_op/task/aicpu_kernel_task_builder.h" - -namespace ge { -AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def) : kernel_def_(kernel_def) {} - -Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { - size_t aicpu_arg_size = kernel_def_.args_size(); - if (aicpu_arg_size <= 0) { - GELOGE(RT_FAILED, "aicpu_arg_size is invalid, value = %zu", aicpu_arg_size); - return RT_FAILED; - } - void *aicpu_args = malloc(aicpu_arg_size); - if (aicpu_args == nullptr) { - GELOGE(RT_FAILED, "malloc failed, size = %zu", aicpu_arg_size); - return RT_FAILED; - } - - task.SetKernelArgs(aicpu_args, aicpu_arg_size); - auto err = memcpy_s(aicpu_args, aicpu_arg_size, kernel_def_.args().data(), aicpu_arg_size); - if (err != EOK) { - GELOGE(RT_FAILED, "memcpy_s args failed, size = %zu, err = %d", aicpu_arg_size, err); - return RT_FAILED; - } - - task.SetIoAddr(static_cast(aicpu_args) + sizeof(aicpu::AicpuParamHead)); - return SUCCESS; -} - -Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task) { - auto ret = SetKernelArgs(task); - if (ret != SUCCESS) { - return ret; - } - const std::string &so_name = kernel_def_.so_name(); - const std::string &kernel_name = kernel_def_.kernel_name(); - task.SetSoName(so_name); - task.SetkernelName(kernel_name); - return SUCCESS; -} -} // namespace ge \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.h b/src/ge/single_op/task/aicpu_kernel_task_builder.h deleted file mode 100644 index c445132e..00000000 --- a/src/ge/single_op/task/aicpu_kernel_task_builder.h +++ /dev/null @@ -1,40 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ -#define GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ - -#include -#include "aicpu/common/aicpu_task_struct.h" -#include "single_op/single_op.h" -#include "single_op/single_op_model.h" -#include "runtime/mem.h" - -namespace ge { -class AiCpuCCTaskBuilder { - public: - explicit AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def); - ~AiCpuCCTaskBuilder() = default; - - Status BuildTask(AiCpuCCTask &task); - - private: - Status SetKernelArgs(AiCpuCCTask &task); - const domi::KernelDef &kernel_def_; -}; -} // namespace ge - -#endif // GE_SINGLE_OP_TASK_AICPUCC_TASK_BUILDER_H_ \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc index bc2c76f6..1a4c37ca 100644 --- a/src/ge/single_op/task/aicpu_task_builder.cc +++ b/src/ge/single_op/task/aicpu_task_builder.cc @@ -129,8 +129,7 @@ Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam task.task_info_ = kernel_def_.task_info(); task.workspace_addr_ = ws_addr_vec[0]; - auto debug_info = BuildTaskUtils::GetTaskInfo(op_desc_); - GELOGI("[TASK_INFO] %s %s", task.task_info_.c_str(), debug_info.c_str()); return SUCCESS; } + } // namespace ge diff --git a/src/ge/single_op/task/build_task_utils.cc b/src/ge/single_op/task/build_task_utils.cc index 9e97ee57..883679be 100644 --- a/src/ge/single_op/task/build_task_utils.cc +++ b/src/ge/single_op/task/build_task_utils.cc @@ -19,9 +19,7 @@ #include "runtime/rt.h" #include 
"graph/load/new_model_manager/model_utils.h" #include "graph/manager/graph_var_manager.h" -#include "graph/utils/type_utils.h" #include "framework/common/debug/ge_log.h" -#include "framework/common/types.h" namespace ge { namespace { @@ -64,42 +62,4 @@ std::vector BuildTaskUtils::GetKernelArgs(const OpDescPtr &op_desc, cons auto addresses = GetAddresses(op_desc, param); return JoinAddresses(addresses); } - -std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { - std::stringstream ss; - if (op_desc != nullptr) { - auto op_type = op_desc->GetType(); - if (op_type == ge::NETOUTPUT || op_type == ge::DATA) { - return ss.str(); - } - // Conv2D IN[DT_FLOAT16 NC1HWC0[256, 128, 7, 7, 16],DT_FLOAT16 FRACTAL_Z[128, 32, 16, 16]] - // OUT[DT_FLOAT16 NC1HWC0[256, 32, 7, 7, 16]] - ss << op_type << " IN["; - for (uint32_t idx = 0; idx < op_desc->GetInputsSize(); idx++) { - const GeTensorDescPtr &input = op_desc->MutableInputDesc(idx); - ss << TypeUtils::DataTypeToSerialString(input->GetDataType()) << " "; - ss << TypeUtils::FormatToSerialString(input->GetFormat()); - ss << VectorToString(input->GetShape().GetDims()); - if (idx < op_desc->GetInputsSize() - 1) { - ss << ","; - } - } - ss << "] OUT["; - - for (uint32_t idx = 0; idx < op_desc->GetOutputsSize(); idx++) { - const GeTensorDescPtr &output = op_desc->MutableOutputDesc(idx); - ss << TypeUtils::DataTypeToSerialString(output->GetDataType()) << " "; - Format out_format = output->GetFormat(); - const GeShape &out_shape = output->GetShape(); - const auto &dims = out_shape.GetDims(); - ss << TypeUtils::FormatToSerialString(out_format); - ss << VectorToString(dims); - if (idx < op_desc->GetOutputsSize() - 1) { - ss << ","; - } - } - ss << "]\n"; - } - return ss.str(); -} } // namespace ge diff --git a/src/ge/single_op/task/build_task_utils.h b/src/ge/single_op/task/build_task_utils.h index f5885fd2..a5030e69 100644 --- a/src/ge/single_op/task/build_task_utils.h +++ b/src/ge/single_op/task/build_task_utils.h @@ -18,7 
+18,6 @@ #define GE_SINGLE_OP_TASK_BUILD_TASK_UTILS_H_ #include -#include #include "graph/op_desc.h" #include "single_op/single_op.h" @@ -32,21 +31,6 @@ class BuildTaskUtils { static std::vector> GetAddresses(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); static std::vector JoinAddresses(const std::vector> &addresses); static std::vector GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); - static std::string GetTaskInfo(const OpDescPtr &op_desc); - template - static std::string VectorToString(const std::vector &values) { - std::stringstream ss; - ss << '['; - auto size = values.size(); - for (size_t i = 0; i < size; ++i) { - ss << values[i]; - if (i != size - 1) { - ss << ", "; - } - } - ss << ']'; - return ss.str(); - } }; } // namespace ge #endif // GE_SINGLE_OP_TASK_BUILD_TASK_UTILS_H_ diff --git a/src/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc index 19e8b6a4..e93fad71 100644 --- a/src/ge/single_op/task/op_task.cc +++ b/src/ge/single_op/task/op_task.cc @@ -16,18 +16,10 @@ #include "single_op/task/op_task.h" -#include -#include - #include "runtime/rt.h" #include "framework/common/debug/ge_log.h" namespace ge { -namespace { -constexpr int kLaunchRetryTimes = 1000; -constexpr int kSleepTime = 10; -} // namespace - void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { this->stub_name_ = name; this->stub_func_ = stub_func; @@ -61,20 +53,12 @@ Status TbeOpTask::LaunchKernel(rtStream_t stream) { GELOGD("To invoke rtKernelLaunch. 
task = %s, block_dim = %u", this->stub_name_.c_str(), block_dim_); auto *sm_desc = reinterpret_cast(sm_desc_); auto ret = rtKernelLaunch(stub_func_, block_dim_, args_, static_cast(arg_size_), sm_desc, stream); - int retry_times = 0; - while (ret != RT_ERROR_NONE && retry_times < kLaunchRetryTimes) { - retry_times++; - GELOGW("Retry after %d ms, retry_times: %d", kSleepTime, retry_times); - std::this_thread::sleep_for(std::chrono::milliseconds(kSleepTime)); - ret = rtKernelLaunch(stub_func_, block_dim_, args_, arg_size_, sm_desc, stream); - } - if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->stub_name_.c_str()); return RT_FAILED; } - GELOGI("[TASK_INFO] %s", this->stub_name_.c_str()); + GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->stub_name_.c_str()); return SUCCESS; } @@ -104,49 +88,8 @@ Status AiCpuTask::LaunchKernel(rtStream_t stream) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->op_type_.c_str()); return RT_FAILED; } - GELOGI("[TASK_INFO] %s", this->task_info_.c_str()); - return SUCCESS; -} - -void AiCpuCCTask::SetKernelArgs(void *args, size_t arg_size) { - args_ = args; - arg_size_ = arg_size; - // the blockdim value is defult "1" for rtCpuKernelLaunch - block_dim_ = 1; -} - -void AiCpuCCTask::SetSoName(const std::string &so_name) { so_name_ = so_name; } - -void AiCpuCCTask::SetkernelName(const std::string &kernel_Name) { kernel_name_ = kernel_Name; } - -void AiCpuCCTask::SetIoAddr(void *io_addr) { io_addr_ = io_addr; } - -const void *AiCpuCCTask::GetIOAddr() const { return io_addr_; } - -const void *AiCpuCCTask::GetArgs() const { return args_; } - -size_t AiCpuCCTask::GetArgSize() const { return arg_size_; } - -AiCpuCCTask::~AiCpuCCTask() { - if (args_ != nullptr) { - free(args_); - args_ = nullptr; - } -} -Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { - GELOGI("To invoke rtCpuKernelLaunch. 
block_dim = %u, so_name is %s, kernel_name is %s", block_dim_, so_name_.data(), - kernel_name_.data()); - // sm_desc is nullptr, because l2 buffer does not support - auto *sm_desc = reinterpret_cast(sm_desc_); - auto ret = - rtCpuKernelLaunch(static_cast(so_name_.data()), static_cast(kernel_name_.data()), - block_dim_, args_, static_cast(arg_size_), sm_desc, stream); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Invoke rtCpuKernelLaunch failed. ret = %d", ret); - return RT_FAILED; - } - GELOGD("Invoke rtCpuKernelLaunch succeeded"); + GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->op_type_.c_str()); return SUCCESS; } } // namespace ge diff --git a/src/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h index fd4cc96f..168a71b3 100644 --- a/src/ge/single_op/task/op_task.h +++ b/src/ge/single_op/task/op_task.h @@ -28,7 +28,6 @@ namespace ge { enum OpTaskType { OP_TASK_TBE = 0, OP_TASK_AICPU, - OP_TASK_AICPUCC, OP_TASK_INVALID, }; @@ -80,34 +79,6 @@ class AiCpuTask : public OpTask { std::string op_type_; void *io_addr_ = nullptr; }; - -class AiCpuCCTask : public OpTask { - public: - AiCpuCCTask() = default; - ~AiCpuCCTask() override; - AiCpuCCTask(const AiCpuCCTask &) = delete; - AiCpuCCTask &operator=(const AiCpuCCTask &) = delete; - - Status LaunchKernel(rtStream_t stream) override; - OpTaskType GetOpTaskType() override { return OP_TASK_AICPUCC; } - const void *GetIOAddr() const; - const void *GetArgs() const; - void SetKernelArgs(void *args, size_t arg_size); - void SetSoName(const std::string &so_name); - void SetkernelName(const std::string &kernel_Name); - void SetIoAddr(void *io_addr); - size_t GetArgSize() const; - - private: - friend class AiCpuCCTaskBuilder; - std::string so_name_; - std::string kernel_name_; - void *args_ = nullptr; - size_t arg_size_ = 0; - uint32_t block_dim_ = 1; - void *sm_desc_ = nullptr; - void *io_addr_ = nullptr; -}; } // namespace ge #endif // GE_SINGLE_OP_TASK_OP_TASK_H_ diff --git 
a/src/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc index a422fb96..c0f6877f 100644 --- a/src/ge/single_op/task/tbe_task_builder.cc +++ b/src/ge/single_op/task/tbe_task_builder.cc @@ -290,8 +290,6 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ if (ret != SUCCESS) { return ret; } - auto task_info = BuildTaskUtils::GetTaskInfo(op_desc_); - GELOGI("[TASK_INFO] %s %s", stub_name_.c_str(), task_info.c_str()); void *stub_func = nullptr; auto rtRet = rtGetFunctionByName(stub_name_.c_str(), &stub_func); diff --git a/src/ge/stub/Makefile b/src/ge/stub/Makefile new file mode 100644 index 00000000..a0b35b42 --- /dev/null +++ b/src/ge/stub/Makefile @@ -0,0 +1,6 @@ +inc_path := $(shell pwd)/inc/external/ +out_path := $(shell pwd)/out/atc/lib64/stub/ +stub_path := $(shell pwd)/framework/domi/stub/ + +mkdir_stub := $(shell mkdir -p $(out_path)) +local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/src/ge/stub/README b/src/ge/stub/README new file mode 100644 index 00000000..ca98ce85 --- /dev/null +++ b/src/ge/stub/README @@ -0,0 +1,4 @@ +################################################################################### +the directory (stub) saves the stub file +gen_stubapi.py is using for retrieving API and generating stub functions +################################################################################### diff --git a/src/ge/stub/gen_stubapi.py b/src/ge/stub/gen_stubapi.py new file mode 100644 index 00000000..6185c479 --- /dev/null +++ b/src/ge/stub/gen_stubapi.py @@ -0,0 +1,573 @@ +import os +import re +import sys +import logging + +logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', + level=logging.INFO) + +""" + this attr is used for symbol table visible +""" +GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' + +""" + generate stub func body by return type +""" +RETURN_STATEMENTS = { + 
'graphStatus': ' return GRAPH_SUCCESS;', + 'Status': ' return SUCCESS;', + 'Graph': ' return Graph();', + 'Graph&': ' return *this;', + 'Format': ' return Format();', + 'Format&': ' return *this;', + 'Shape': ' return Shape();', + 'Shape&': ' return *this;', + 'TensorDesc': ' return TensorDesc();', + 'TensorDesc&': ' return *this;', + 'Tensor': ' return Tensor();', + 'Tensor&': ' return *this;', + 'Operator': ' return Operator();', + 'Operator&': ' return *this;', + 'Ptr': ' return nullptr;', + 'std::string': ' return "";', + 'std::string&': ' return "";', + 'string': ' return "";', + 'int': ' return 0;', + 'DataType': ' return DT_FLOAT;', + 'InferenceContextPtr': ' return nullptr;', + 'SubgraphBuilder': ' return nullptr;', + 'OperatorImplPtr': ' return nullptr;', + 'OutHandler': ' return nullptr;', + 'std::vector': ' return {};', + 'std::vector': ' return {};', + 'std::map': ' return {};', + 'uint32_t': ' return 0;', + 'int64_t': ' return 0;', + 'uint64_t': ' return 0;', + 'size_t': ' return 0;', + 'float': ' return 0.0f;', + 'bool': ' return false;', +} + +""" + max code len per line in hua_wei software programming specifications +""" +max_code_len_per_line = 100 + +""" + white_list_for_debug, include_dir_key_words is to + determines which header files to generate cc files from + when DEBUG on +""" +white_list_for_debug = ["operator.h", "tensor.h", + "graph.h", "operator_factory.h", + "ge_ir_build.h"] +include_dir_key_words = ["ge", "graph"] +DEBUG = True + + +def need_generate_func(func_line): + """ + :param func_line: + :return: + """ + if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ + or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): + return False + return True + + +def file_endswith_white_list_suffix(file): + """ + :param file: + :return: + """ + if DEBUG: + for suffix in white_list_for_debug: + if file.endswith(suffix): + return True + return False + else: + return True + + +""" + belows 
are patterns used for analyse .h file +""" +# pattern function +pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after +([a-zA-Z~_] # void int likely +.* +[)] #we find ) +(?!.*{) # we do not want the case int abc() const { return 1;} +.*) +(;.*) #we want to find ; and after for we will replace these later +\n$ +""", re.VERBOSE | re.MULTILINE | re.DOTALL) + +# pattern comment +pattern_comment = re.compile(r'^\s*//') +pattern_comment_2_start = re.compile(r'^\s*/[*]') +pattern_comment_2_end = re.compile(r'[*]/\s*$') +# pattern define +pattern_define = re.compile(r'^\s*#define') +pattern_define_return = re.compile(r'\\\s*$') +# blank line +pattern_blank_line = re.compile(r'^\s*$') +# virtual,explicit,friend,static +pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') +# lead space +pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') +# functions will have patterns such as func ( or func( +# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist +# format like :"operator = ()" +pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') +# template +pattern_template = re.compile(r'^\s*template') +pattern_template_end = re.compile(r'>\s*$') +# namespace +pattern_namespace = re.compile(r'namespace.*{') +# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with +pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: + line, func_name = self.handle_class_member_func(line, template_string) + # Normal functions + else: + line, func_name = self.handle_normal_func(line, template_string) + + need_generate = need_generate_func(line) + # func body + line += self.implement_function(line) + # comment + line = self.gen_comment(start_i) + line + # write to out file + self.write_func_content(line, func_name, need_generate) + # next loop + self.line_index += 1 
+ + logging.info('Added %s functions', len(self.func_list_exist)) + logging.info('Successfully converted,please see ' + self.output_file) + + def handle_func1(self, line): + """ + :param line: + :return: + """ + find1 = re.search('[(]', line) + if not find1: + self.line_index += 1 + return "continue", line, None + find2 = re.search('[)]', line) + start_i = self.line_index + space_match = pattern_leading_space.search(line) + # deal with + # int abc(int a, + # int b) + if find1 and (not find2): + self.line_index += 1 + line2 = self.input_content[self.line_index] + if space_match: + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): + self.line_index += 1 + line2 = self.input_content[self.line_index] + line2 = re.sub('^' + space_match.group(1), '', line2) + line += line2 + + match_start = pattern_start.search(self.input_content[self.line_index]) + match_end = pattern_end.search(self.input_content[self.line_index]) + if match_start: # like ) { or ) {} int the last line + if not match_end: + self.stack.append('normal_now') + ii = start_i + while ii <= self.line_index: + ii += 1 + self.line_index += 1 + return "continue", line, start_i + logging.info("line[%s]", line) + # ' int abc();'->'int abc()' + (line, match) = pattern_func.subn(r'\2\n', line) + logging.info("line[%s]", line) + # deal with case: + # 'int \n abc(int a, int b)' + if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): + line = self.input_content[start_i - 1] + line + line = line.lstrip() + if not match: + self.line_index += 1 + return "continue", line, start_i + return "pass", line, start_i + + def handle_stack(self, match_start): + """ + :param match_start: + :return: + """ + line = self.input_content[self.line_index] + match_end = pattern_end.search(line) + if match_start: + self.stack.append('normal_now') + if match_end: + top_status = self.stack.pop() + if top_status 
== 'namespace_now': + self.output_fd.write(line + '\n') + elif top_status == 'class_now': + self.stack_class.pop() + self.stack_template.pop() + if match_start or match_end: + self.line_index += 1 + return "continue" + + if len(self.stack) > 0 and self.stack[-1] == 'normal_now': + self.line_index += 1 + return "continue" + return "pass" + + def handle_class(self, template_string, line, match_start, match_class): + """ + :param template_string: + :param line: + :param match_start: + :param match_class: + :return: + """ + if match_class: # we face a class + self.stack_template.append(template_string) + self.stack.append('class_now') + class_name = match_class.group(3) + + # class template specializations: class A > + if '<' in class_name: + k = line.index('<') + fit = 1 + for ii in range(k + 1, len(line)): + if line[ii] == '<': + fit += 1 + if line[ii] == '>': + fit -= 1 + if fit == 0: + break + class_name += line[k + 1:ii + 1] + logging.info('class_name[%s]', class_name) + self.stack_class.append(class_name) + while not match_start: + self.line_index += 1 + line = self.input_content[self.line_index] + match_start = pattern_start.search(line) + self.line_index += 1 + return "continue" + return "pass" + + def handle_template(self): + line = self.input_content[self.line_index] + match_template = pattern_template.search(line) + template_string = '' + if match_template: + match_template_end = pattern_template_end.search(line) + template_string = line + while not match_template_end: + self.line_index += 1 + line = self.input_content[self.line_index] + template_string += line + match_template_end = pattern_template_end.search(line) + self.line_index += 1 + return template_string + + def handle_namespace(self): + line = self.input_content[self.line_index] + match_namespace = pattern_namespace.search(line) + if match_namespace: # we face namespace + self.output_fd.write(line + '\n') + self.stack.append('namespace_now') + self.line_index += 1 + + def handle_normal_func(self, 
line, template_string): + template_line = '' + self.stack_template.append(template_string) + if self.stack_template[-1] != '': + template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) + # change '< class T = a, class U = A(3)>' to '' + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = re.sub(r'\s*=.*,', ',', line) + line = re.sub(r'\s*=.*\)', ')', line) + line = template_line + line + self.stack_template.pop() + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def handle_class_member_func(self, line, template_string): + template_line = '' + x = '' + if template_string != '': + template_string = re.sub(r'\s*template', 'template', template_string) + template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) + template_string = re.sub(r'\s*=.*,', ',', template_string) + template_string = re.sub(r'\s*=.*', '', template_string) + if self.stack_template[-1] != '': + if not (re.search(r'<\s*>', stack_template[-1])): + template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) + if not (re.search(r'<.*>', self.stack_class[-1])): + # for x we get like template -> + x = re.sub(r'template\s*<', '<', template_line) # remove template -> + x = re.sub(r'\n', '', x) + x = re.sub(r'\s*=.*,', ',', x) + x = re.sub(r'\s*=.*\>', '>', x) + x = x.rstrip() # remove \n + x = re.sub(r'(class|typename)\s+|(|\s*class)', '', + x) # remove class,typename -> + x = re.sub(r'<\s+', '<', x) + x = re.sub(r'\s+>', '>', x) + x = re.sub(r'\s+,', ',', x) + x = re.sub(r',\s+', ', ', x) + line = re.sub(r'\s*=\s+0', '', line) + line = re.sub(r'\s*=\s+.*,', ',', line) + line = re.sub(r'\s*=\s+.*\)', ')', line) + logging.info("x[%s]\nline[%s]", x, line) + # if the function is long, void ABC::foo() + # breaks 
into two lines void ABC::\n foo() + temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) + if len(temp_line) > max_code_len_per_line: + line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) + else: + line = temp_line + logging.info("line[%s]", line) + # add template as the above if there is one + template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) + template_line = re.sub(r'\s*=.*,', ',', template_line) + template_line = re.sub(r'\s*=.*', '', template_line) + line = template_line + template_string + line + func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() + logging.info("line[%s]", line) + logging.info("func_name[%s]", func_name) + return line, func_name + + def write_func_content(self, content, func_name, need_generate): + if not (func_name in self.func_list_exist) and need_generate: + self.output_fd.write(content) + self.func_list_exist.append(func_name) + logging.info('add func:[%s]', func_name) + + def gen_comment(self, start_i): + comment_line = '' + # Function comments are on top of function declarations, copy them over + k = start_i - 1 # one line before this func start + if pattern_template.search(self.input_content[k]): + k -= 1 + if pattern_comment_2_end.search(self.input_content[k]): + comment_line = self.input_content[k].lstrip() + while not pattern_comment_2_start.search(self.input_content[k]): + k -= 1 + comment_line = self.input_content[k].lstrip() + comment_line + else: + for j in range(k, 0, -1): + c_line = self.input_content[j] + if pattern_comment.search(c_line): + c_line = re.sub(r'\s*//', '//', c_line) + comment_line = c_line + comment_line + else: + break + return comment_line + + @staticmethod + def implement_function(func): + function_def = '' + function_def += '{\n' + + all_items = func.split() + start = 0 + return_type = all_items[start] + if return_type == "const": + start += 1 + return_type = all_items[start] + if 
return_type.startswith(('std::map', 'std::set', 'std::vector')): + return_type = "std::map" + if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): + return_type = "Ptr" + if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): + return_type += "&" + if RETURN_STATEMENTS.__contains__(return_type): + function_def += RETURN_STATEMENTS[return_type] + else: + logging.warning("Unhandled return type[%s]", return_type) + + function_def += '\n' + function_def += '}\n' + function_def += '\n' + return function_def + + +def collect_header_files(path): + """ + :param path: + :return: + """ + header_files = [] + shared_includes_content = [] + for root, dirs, files in os.walk(path): + files.sort() + for file in files: + if file.find("git") >= 0: + continue + if not file.endswith('.h'): + continue + file_path = os.path.join(root, file) + file_path = file_path.replace('\\', '/') + header_files.append(file_path) + include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) + shared_includes_content.append(include_str) + return header_files, shared_includes_content + + +def generate_stub_file(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + target_header_files, shared_includes_content = collect_header_files(inc_dir) + for header_file in target_header_files: + if not file_endswith_white_list_suffix(header_file): + continue + cc_file = re.sub('.h*$', '.cc', header_file) + h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) + h_2_cc.h2cc() + + +def gen_code(inc_dir, out_cc_dir): + """ + :param inc_dir: + :param out_cc_dir: + :return: + """ + if not inc_dir.endswith('/'): + inc_dir += '/' + if not out_cc_dir.endswith('/'): + out_cc_dir += '/' + for include_dir_key_word in include_dir_key_words: + generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) + + +if __name__ == '__main__': + inc_dir = sys.argv[1] + out_cc_dir 
= sys.argv[2] + gen_code(inc_dir, out_cc_dir) diff --git a/tests/depends/cce/src/op_kernel_registry.cc b/tests/depends/cce/src/op_kernel_registry.cc index 5ccd1391..9bb32a31 100644 --- a/tests/depends/cce/src/op_kernel_registry.cc +++ b/tests/depends/cce/src/op_kernel_registry.cc @@ -1,3 +1,19 @@ +/** + * Copyright 2019 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + #include "register/op_kernel_registry.h" namespace ge { diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h index 1d83d7bf..74163baf 100644 --- a/third_party/fwkacllib/inc/hccl/base.h +++ b/third_party/fwkacllib/inc/hccl/base.h @@ -102,11 +102,6 @@ typedef enum tagHcclDataType { HCCL_DATA_TYPE_RESERVED /**< reserved */ } hcclDataType_t; -constexpr u32 HCCL_UNIQUE_ID_BYTES = 2060; // 2060: unique id length -using hcclUniqueId = struct hcclUniqueIdDef { - char internal[HCCL_UNIQUE_ID_BYTES]; -}; - const u32 HCCL_MAX_SEGMENT_NUM = 8; // The max number of gradient segments. /** @@ -125,12 +120,6 @@ enum GradSplitForceMode { FORCE_RESERVED /**< reserved */ }; -enum OriginalGraphShapeType { - KNOWN_SHAPE, - UNKNOWN_SHAPE, - SHAPE_RESERVED /**< reserved */ -}; - /** * @brief stream handle. 
*/ diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h index 19bf4fb3..a448d411 100644 --- a/third_party/fwkacllib/inc/hccl/hcom.h +++ b/third_party/fwkacllib/inc/hccl/hcom.h @@ -22,6 +22,7 @@ #ifndef HCOM_H_ #define HCOM_H_ +#include #include #ifdef __cplusplus @@ -245,9 +246,8 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT * @param segmentIdx A list identifying the index of end gradient in each segment. * @return hcclResult_t */ -hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, - u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE, - OriginalGraphShapeType shapeType = KNOWN_SHAPE); +hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, + u32 maxSegmentNum, u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE); /** * @brief Set the gradient split strategy with in the group, according to gradient index. 
diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h index 6ac8f8f6..ce83d143 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h @@ -344,8 +344,6 @@ extern INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); extern INT32 mmDup2(INT32 oldFd, INT32 newFd); -extern INT32 mmDup(INT32 fd); - extern INT32 mmUnlink(const CHAR *filename); extern INT32 mmChmod(const CHAR *filename, INT32 mode); diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h index 68a70c27..ef15f371 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h @@ -378,7 +378,6 @@ _declspec(dllexport) INT32 mmGetRealPath(CHAR *path, CHAR *realPath); _declspec(dllexport) INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); _declspec(dllexport) INT32 mmDup2(INT32 oldFd, INT32 newFd); -_declspec(dllexport) INT32 mmDup(INT32 fd); _declspec(dllexport) INT32 mmUnlink(const CHAR *filename); _declspec(dllexport) INT32 mmChmod(const CHAR *filename, INT32 mode); _declspec(dllexport) INT32 mmFileno(FILE *stream); diff --git a/third_party/fwkacllib/inc/ops/hvd_ops.h b/third_party/fwkacllib/inc/ops/hvd_ops.h deleted file mode 100644 index 09748b8e..00000000 --- a/third_party/fwkacllib/inc/ops/hvd_ops.h +++ /dev/null @@ -1,77 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. 
- * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_OP_HVD_OPS_H_ -#define GE_OP_HVD_OPS_H_ - -#include "graph/operator_reg.h" - -namespace ge { -/** - * @brief Outputs a tensor gathering all input tensors. - * @par Inputs: - * x: A tensor. Must be one of the following types: uint8, int8, uint16, int16, int32, - * int64, float16, bool. - * @par Attributes: - * @li rank_size: A required integer identifying the number of ranks - * participating in the op. - * @par Outputs: - * y: A Tensor. Has the same type as "x". - */ -REG_OP(HorovodAllgather) - // GE not support float64 currently - .INPUT(x, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) - .OUTPUT(y, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) - // add rank_size attr - .REQUIRED_ATTR(rank_size, Int) - .OP_END_FACTORY_REG(HorovodAllgather) - -/** - * @brief Outputs a tensor containing the reduction across all input tensors - * passed to op. - * @par Inputs: - * x: A tensor. Must be one of the following types: int32, int64, float16, float32 - * @par Attributes: - * @li reduce_op: A required int identifying the reduction operation to - * perform.The supported operation are: "sum", "max", "min", "prod". - * @par Outputs: - * y: A Tensor. Has the same type as "x". 
- */ -REG_OP(HorovodAllreduce) - .INPUT(x, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT})) - .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT})) - .REQUIRED_ATTR(reduce_op, Int) - .OP_END_FACTORY_REG(HorovodAllreduce) - -/** - * @brief Broadcasts the input tensor in root rank to all ranks. - * @par Inputs: - * x: A list of dynamic input tensor. Must be one of the following types: - * int8, int32, float16, float32. - * @par Attributes: - * @li root_rank: A required integer identifying the root rank in the op - * input of this rank will be broadcast to other ranks. - * @par Outputs: - * y: A list of dynamic output tensor. Has the same type and length as "x". - */ -REG_OP(HorovodBroadcast) - .INPUT(x, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) - .OUTPUT(y, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) - .REQUIRED_ATTR(root_rank, Int) - .OP_END_FACTORY_REG(HorovodBroadcast) - -} // namespace ge -#endif // GE_OP_HVD_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/internal_ops.h b/third_party/fwkacllib/inc/ops/internal_ops.h deleted file mode 100644 index e3caa45f..00000000 --- a/third_party/fwkacllib/inc/ops/internal_ops.h +++ /dev/null @@ -1,48 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef GE_OP_INTERNAL_OPS_H_ -#define GE_OP_INTERNAL_OPS_H_ - -#include "graph/operator_reg.h" -#include "graph/operator.h" - -namespace ge { - -/** -*@brief aicpu assit help op for auxiliary matrix generation. - -*@par Inputs: -*The input is dynamic for attribute func_name \n - -*@par Attributes: -*@li func_name:An required param, for example "topkv2". \n - -*@par Outputs: -*The output is dynamic for attribute func_name. -*/ - -REG_OP(AssistHelp) - .DYNAMIC_INPUT(x, TensorType({ DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, - DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE })) - .DYNAMIC_OUTPUT(y, TensorType({ DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, - DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) - . REQUIRED_ATTR (func_name, String) - . OP_END_FACTORY_REG(AssistHelp) - -} // namespace ge - -#endif // GE_OP_INTERNAL_OPS_H_ diff --git a/third_party/fwkacllib/inc/register/op_registry.h b/third_party/fwkacllib/inc/register/op_registry.h index 1dc14b8b..1fcdf9de 100644 --- a/third_party/fwkacllib/inc/register/op_registry.h +++ b/third_party/fwkacllib/inc/register/op_registry.h @@ -35,7 +35,6 @@ enum RemoveInputType { OMG_MOVE_TYPE_SCALAR_VALUE, OMG_REMOVE_TYPE_WITH_COND = 1000, OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE, - OMG_INPUT_REORDER, }; struct RemoveInputConfigure { @@ -44,7 +43,6 @@ struct RemoveInputConfigure { RemoveInputType moveType; bool attrValue = false; std::string originalType; - std::vector input_order; }; class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { @@ -59,11 +57,11 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { void GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType &imply_type); - domi::ParseParamFunc GetParseParamFunc(const std::string &op_type, const std::string &ori_type); + domi::ParseParamFunc GetParseParamFunc(const std::string &op_type); - domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const 
std::string &ori_type); + domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &op_type); - domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type, const std::string &ori_type); + domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type); domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type); @@ -74,13 +72,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { bool GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type); private: + std::unordered_map> op_ori_optype_map_; std::unordered_map op_run_mode_map_; - std::unordered_map op_parse_params_fn_map_; + std::unordered_map opParseParamsFnMap_; std::unordered_map parse_params_by_op_func_map_; - std::unordered_map fusion_op_parse_params_fn_map_; + std::unordered_map fusionOpParseParamsFnMap_; std::unordered_map op_types_to_parse_subgraph_post_func_; std::unordered_map> remove_input_configure_map_; - std::unordered_map origin_type_to_om_type_; + std::unordered_map originOpType2OmOpType_; }; } // namespace domi #endif // INC_REGISTER_OP_REGISTRY_H_ diff --git a/third_party/fwkacllib/inc/register/op_tiling.h b/third_party/fwkacllib/inc/register/op_tiling.h deleted file mode 100644 index 92067a20..00000000 --- a/third_party/fwkacllib/inc/register/op_tiling.h +++ /dev/null @@ -1,130 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -#ifndef INC_OP_TILING_H_ -#define INC_OP_TILING_H_ - -#include "external/register/register_types.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/node.h" - -#include -#include -#include -#include -#include -#include -#include "graph/node.h" - -#define REGISTER_OP_TILING_FUNC(optype, opfunc) \ - REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, __COUNTER__) - -#define REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, counter) \ - REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) - -#define REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) \ - static OpTilingInterf g_##optype##TilingInterf##counter(#optype, opfunc) - -namespace optiling { - -enum TensorArgType { - TA_NONE, - TA_SINGLE, - TA_LIST, -}; - - -using ByteBuffer = std::stringstream; - -struct TeOpTensor { - std::vector shape; - std::vector ori_shape; - std::string format; - std::string ori_format; - std::string dtype; - std::map attrs; -}; - - -struct TeOpTensorArg { - TensorArgType arg_type; - std::vector tensor; -}; - -struct OpRunInfo { - uint32_t block_dim; - std::vector workspaces; - ByteBuffer tiling_data; -}; - - -using TeOpAttrArgs = std::vector; -using TeConstTensorData = std::tuple; - -struct TeOpParas { - std::vector inputs; - std::vector outputs; - std::map const_inputs; - TeOpAttrArgs attrs; -}; - -using OpTilingFunc = std::function; - -using OpTilingFuncPtr = bool(*)(const std::string&, const TeOpParas&, const std::string&, OpRunInfo&); - -class FMK_FUNC_HOST_VISIBILITY OpTilingInterf -{ -public: - OpTilingInterf(std::string op_type, OpTilingFunc func); - ~OpTilingInterf() = default; - static std::map &RegisteredOpInterf(); -}; - - -template -ByteBuffer& ByteBufferPut(ByteBuffer &buf, const T &value) -{ - buf.write(reinterpret_cast(&value), sizeof(value)); - buf.flush(); - return buf; -} - -template -ByteBuffer& ByteBufferGet(ByteBuffer &buf, T &value) -{ - buf.read(reinterpret_cast(&value), sizeof(value)); - return buf; -} - -inline size_t 
ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len) -{ - size_t nread = 0; - size_t rn = 0; - do { - rn = buf.readsome(dest + nread, dest_len - nread); - nread += rn; - } while (rn > 0 && dest_len > nread); - - return nread; -} - - -extern "C" ge::graphStatus OpParaCalculate(const ge::Node &node, OpRunInfo &run_info); - -} - -#endif // INC_OP_TILING_H_ diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h index 7539a549..49c9de6a 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -68,8 +68,6 @@ typedef enum tagRtError { RT_ERROR_NO_STREAM_CB_REG = 0x96, // no callback register info for stream RT_ERROR_DATA_DUMP_LOAD_FAILED = 0x97, // data dump load info fail RT_ERROR_CALLBACK_THREAD_UNSUBSTRIBE = 0x98, // callback thread unsubstribe - RT_ERROR_DEBUG_REGISTER_FAILED = 0x99, // debug register fail - RT_ERROR_DEBUG_UNREGISTER_FAILED = 0x9A, // debug unregister fail RT_ERROR_RESERVED } rtError_t; @@ -186,6 +184,14 @@ RTS_API rtError_t rtGetLastError(); */ RTS_API rtError_t rtPeekAtLastError(); +/** + * @ingroup dvrt_base + * @brief set polling receive mode for task report + * @param [out] NA + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtSetPollingMode(); + /** * @ingroup dvrt_base * @brief register callback for error code diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index 3dad53c5..2e48cc57 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -41,7 +41,8 @@ typedef enum tagRtChipType { CHIP_CLOUD, CHIP_MDC, CHIP_LHISI, - CHIP_DC, + CHIP_OTHER_PHN, + CHIP_OTHER_OLD, CHIP_END, } rtChipType_t; diff --git a/third_party/fwkacllib/inc/runtime/context.h b/third_party/fwkacllib/inc/runtime/context.h index b059268e..ed1f13c2 100644 --- a/third_party/fwkacllib/inc/runtime/context.h +++ b/third_party/fwkacllib/inc/runtime/context.h @@ -98,6 +98,14 @@ 
RTS_API rtError_t rtCtxSynchronize(void); */ RTS_API rtError_t rtCtxGetCurrent(rtContext_t *ctx); +/** + * @ingroup rt_context + * @brief returns the primary context of device. + * @param [out] ctx returned context + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPriCtxByDeviceId(int32_t device, rtContext_t *ctx); + /** * @ingroup rt_context * @brief returns the device ID for the current context @@ -106,6 +114,16 @@ RTS_API rtError_t rtCtxGetCurrent(rtContext_t *ctx); */ RTS_API rtError_t rtCtxGetDevice(int32_t *device); +/** + * @ingroup rt_context + * @brief set ctx run mode: normal or dryrun + * @param [in] ctx: context + * @param [in] enable: set true means enable dryrun mode + * @param [in] flag: reserved + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtCtxSetDryRun(rtContext_t ctx, rtDryRunFlag_t enable, uint32_t flag); + #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index 60928202..928f2822 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -32,7 +32,6 @@ typedef struct tagRTDeviceInfo { uint32_t ts_cpu_core_num; uint32_t ai_cpu_core_num; uint32_t ai_core_num; - uint32_t ai_core_freq; uint32_t ai_cpu_core_id; uint32_t ai_core_id; uint32_t aicpu_occupy_bitmap; @@ -47,13 +46,6 @@ typedef enum tagRtRunMode { RT_RUN_MODE_RESERVED } rtRunMode; -typedef enum tagRtAicpuDeployType { - AICPU_DEPLOY_CROSS_OS = 0x0, - AICPU_DEPLOY_CROSS_PROCESS = 0x1, - AICPU_DEPLOY_CROSS_THREAD = 0x2, - AICPU_DEPLOY_RESERVED -} rtAicpuDeployType_t; - /** * @ingroup dvrt_dev * @brief get total device number. @@ -70,40 +62,15 @@ RTS_API rtError_t rtGetDeviceCount(int32_t *count); * @return RT_ERROR_DRV_ERR for error */ RTS_API rtError_t rtGetDeviceIDs(uint32_t *devices, uint32_t len); - /** * @ingroup dvrt_dev - * @brief get device infomation. + * @brief get total device infomation. 
* @param [in] device the device id - * @param [in] moduleType module type - typedef enum { - MODULE_TYPE_SYSTEM = 0, system info - MODULE_TYPE_AICPU, aicpu info - MODULE_TYPE_CCPU, ccpu_info - MODULE_TYPE_DCPU, dcpu info - MODULE_TYPE_AICORE, AI CORE info - MODULE_TYPE_TSCPU, tscpu info - MODULE_TYPE_PCIE, PCIE info - } DEV_MODULE_TYPE; - * @param [in] infoType info type - typedef enum { - INFO_TYPE_ENV = 0, - INFO_TYPE_VERSION, - INFO_TYPE_MASTERID, - INFO_TYPE_CORE_NUM, - INFO_TYPE_OS_SCHED, - INFO_TYPE_IN_USED, - INFO_TYPE_ERROR_MAP, - INFO_TYPE_OCCUPY, - INFO_TYPE_ID, - INFO_TYPE_IP, - INFO_TYPE_ENDIAN, - } DEV_INFO_TYPE; - * @param [out] value the device info + * @param [out] info the device info * @return RT_ERROR_NONE for ok * @return RT_ERROR_NO_DEVICE for can not find any device */ -RTS_API rtError_t rtGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t infoType, int64_t *value); +RTS_API rtError_t rtGetDeviceInfo(int32_t device, rtDeviceInfo_t *info); /** * @ingroup dvrt_dev @@ -163,25 +130,6 @@ RTS_API rtError_t rtEnableP2P(uint32_t devIdDes, uint32_t phyIdSrc); */ RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc); -/** - * @ingroup dvrt_dev - * @brief get status - * @param [in] devIdDes the logical device id - * @param [in] phyIdSrc the physical device id - * @param [in|out] status status value - * @return RT_ERROR_NONE for ok - * @return RT_ERROR_NO_DEVICE for can not find any device - */ -RTS_API rtError_t rtGetP2PStatus(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t *status); - -/** - * @ingroup dvrt_dev - * @brief get value of current thread - * @param [in|out] pid value of pid - * @return RT_ERROR_NONE for ok - */ -RTS_API rtError_t rtDeviceGetBareTgid(uint32_t *pid); - /** * @ingroup dvrt_dev * @brief get target device of current thread @@ -264,15 +212,6 @@ RTS_API rtError_t rtSetTSDevice(uint32_t tsId); */ RTS_API rtError_t rtGetRunMode(rtRunMode *mode); -/** - * @ingroup dvrt_dev - * @brief get aicpu deploy - * @param 
[out] aicpu deploy - * @return RT_ERROR_NONE for ok - * @return RT_ERROR_DRV_ERR for can not get aicpu deploy - */ -RTS_API rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deplyType); - /** * @ingroup dvrt_dev * @brief set chipType @@ -286,17 +225,6 @@ RTS_API rtError_t rtSetSocVersion(const char *version); * @return RT_ERROR_NONE for ok */ rtError_t rtGetSocVersion(char *version, const uint32_t maxLen); - -/** - * @ingroup dvrt_dev - * @brief get status - * @param [in] devId the logical device id - * @param [in] otherDevId the other logical device id - * @param [in] infoType info type - * @param [in|out] value pair info - * @return RT_ERROR_NONE for ok - */ -RTS_API rtError_t rtGetPairDevicesInfo(uint32_t devId, uint32_t otherDevId, int32_t infoType, int64_t *value); #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index 5c85a3d7..790492fc 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -65,13 +65,6 @@ typedef enum tagModelQueueFlag { #define EXECUTOR_TS ((uint32_t)0x01) #define EXECUTOR_AICPU ((uint32_t)0x02) -/* - * @ingroup rt_model - * @brief debug flag for kernel exception dump - */ -#define RT_DEBUG_FLAG_AICORE_OVERFLOW (0x1 << 0) -#define RT_DEBUG_FLAG_ATOMIC_ADD_OVERFLOW (0x1 << 1) - /** * @ingroup * @brief the type defination of aicpu model task command @@ -410,26 +403,6 @@ RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQue */ RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId); -/* - * @ingroup rt_model - * @brief enable debug for dump overflow exception - * @param [in] addr: ddr address of kernel exception dumpped - * @param [in] model: model handle - * @param [in] flag: debug flag - * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle - */ -rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint32_t 
*streamId, uint32_t *taskId); - -/* - * @ingroup rt_model - * @brief disable debug for dump overflow exception - * @param [in] model: model handle - * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input handle - */ -RTS_API rtError_t rtDebugUnRegister(rtModel_t model); - #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index 261fe866..f77df225 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -91,10 +91,6 @@ extern "C" { * max log length */ #define MSG_LENGTH 1024 -#define DEBUG_LOG_MASK (0x00010000) -#define SECURITY_LOG_MASK (0x00100000) -#define RUN_LOG_MASK (0x01000000) -#define OPERATION_LOG_MASK (0x10000000) typedef struct tagDCODE { const char *cName; @@ -173,11 +169,83 @@ enum { PROCMGR, // Process Manager, Base Platform BBOX, AIVECTOR, - TBE, - FV, INVLID_MOUDLE_ID }; +#ifdef MODULE_ID_NAME + +/** + * @ingroup slog + * + * set module id to map + */ +#define SET_MOUDLE_ID_MAP_NAME(x) \ + { #x, x } + +static DCODE g_moduleIdName[] = {SET_MOUDLE_ID_MAP_NAME(SLOG), + SET_MOUDLE_ID_MAP_NAME(IDEDD), + SET_MOUDLE_ID_MAP_NAME(IDEDH), + SET_MOUDLE_ID_MAP_NAME(HCCL), + SET_MOUDLE_ID_MAP_NAME(FMK), + SET_MOUDLE_ID_MAP_NAME(HIAIENGINE), + SET_MOUDLE_ID_MAP_NAME(DVPP), + SET_MOUDLE_ID_MAP_NAME(RUNTIME), + SET_MOUDLE_ID_MAP_NAME(CCE), +#if (OS_TYPE == LINUX) + SET_MOUDLE_ID_MAP_NAME(HDC), +#else + SET_MOUDLE_ID_MAP_NAME(HDCL), +#endif // OS_TYPE + SET_MOUDLE_ID_MAP_NAME(DRV), + SET_MOUDLE_ID_MAP_NAME(MDCFUSION), + SET_MOUDLE_ID_MAP_NAME(MDCLOCATION), + SET_MOUDLE_ID_MAP_NAME(MDCPERCEPTION), + SET_MOUDLE_ID_MAP_NAME(MDCFSM), + SET_MOUDLE_ID_MAP_NAME(MDCCOMMON), + SET_MOUDLE_ID_MAP_NAME(MDCMONITOR), + SET_MOUDLE_ID_MAP_NAME(MDCBSWP), + SET_MOUDLE_ID_MAP_NAME(MDCDEFAULT), + SET_MOUDLE_ID_MAP_NAME(MDCSC), + SET_MOUDLE_ID_MAP_NAME(MDCPNC), + SET_MOUDLE_ID_MAP_NAME(MLL), + SET_MOUDLE_ID_MAP_NAME(DEVMM), + 
SET_MOUDLE_ID_MAP_NAME(KERNEL), + SET_MOUDLE_ID_MAP_NAME(LIBMEDIA), + SET_MOUDLE_ID_MAP_NAME(CCECPU), + SET_MOUDLE_ID_MAP_NAME(ASCENDDK), + SET_MOUDLE_ID_MAP_NAME(ROS), + SET_MOUDLE_ID_MAP_NAME(HCCP), + SET_MOUDLE_ID_MAP_NAME(ROCE), + SET_MOUDLE_ID_MAP_NAME(TEFUSION), + SET_MOUDLE_ID_MAP_NAME(PROFILING), + SET_MOUDLE_ID_MAP_NAME(DP), + SET_MOUDLE_ID_MAP_NAME(APP), + SET_MOUDLE_ID_MAP_NAME(TS), + SET_MOUDLE_ID_MAP_NAME(TSDUMP), + SET_MOUDLE_ID_MAP_NAME(AICPU), + SET_MOUDLE_ID_MAP_NAME(LP), + SET_MOUDLE_ID_MAP_NAME(TDT), + SET_MOUDLE_ID_MAP_NAME(FE), + SET_MOUDLE_ID_MAP_NAME(MD), + SET_MOUDLE_ID_MAP_NAME(MB), + SET_MOUDLE_ID_MAP_NAME(ME), + SET_MOUDLE_ID_MAP_NAME(IMU), + SET_MOUDLE_ID_MAP_NAME(IMP), + SET_MOUDLE_ID_MAP_NAME(GE), + SET_MOUDLE_ID_MAP_NAME(MDCFUSA), + SET_MOUDLE_ID_MAP_NAME(CAMERA), + SET_MOUDLE_ID_MAP_NAME(ASCENDCL), + SET_MOUDLE_ID_MAP_NAME(TEEOS), + SET_MOUDLE_ID_MAP_NAME(ISP), + SET_MOUDLE_ID_MAP_NAME(SIS), + SET_MOUDLE_ID_MAP_NAME(HSM), + SET_MOUDLE_ID_MAP_NAME(DSS), + SET_MOUDLE_ID_MAP_NAME(PROCMGR), + SET_MOUDLE_ID_MAP_NAME(BBOX), + SET_MOUDLE_ID_MAP_NAME(AIVECTOR), + { NULL, -1 }}; +#endif // MODULE_ID_NAME + #if (OS_TYPE == LINUX) /** * @ingroup slog @@ -318,11 +386,6 @@ extern int CheckLogLevel(int moduleId, int logLevel); DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } while (0) -/** - * @ingroup slog - * @brief DlogFlush: flush log buffer to file - */ -void DlogFlush(void); /** * @ingroup slog diff --git a/third_party/prebuild/aarch64/liberror_manager.so b/third_party/prebuild/aarch64/liberror_manager.so deleted file mode 100755 index 759d8e30..00000000 Binary files a/third_party/prebuild/aarch64/liberror_manager.so and /dev/null differ diff --git a/third_party/prebuild/aarch64/libslog.so b/third_party/prebuild/aarch64/libslog.so deleted file mode 100755 index 700fc118..00000000 Binary files a/third_party/prebuild/aarch64/libslog.so and /dev/null differ diff --git 
a/third_party/prebuild/x86_64/liberror_manager.so b/third_party/prebuild/x86_64/liberror_manager.so deleted file mode 100755 index cd9ad8bc..00000000 Binary files a/third_party/prebuild/x86_64/liberror_manager.so and /dev/null differ diff --git a/third_party/prebuild/x86_64/libslog.so b/third_party/prebuild/x86_64/libslog.so index 01b75e40..b476618d 100755 Binary files a/third_party/prebuild/x86_64/libslog.so and b/third_party/prebuild/x86_64/libslog.so differ