diff --git a/build.sh b/build.sh index 0afaa7fb..1871bbb8 100644 --- a/build.sh +++ b/build.sh @@ -174,9 +174,11 @@ echo "---------------- GraphEngine output generated ----------------" # generate output package in tar form, including ut/st libraries/executables cd ${BASEPATH} mkdir -p output/plugin/nnengine/ge_config/ +mkdir -p output/plugin/opskernel/ find output/ -name graphengine_lib.tar -exec rm {} \; cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/ find output/ -maxdepth 1 -name libengine.so -exec mv -f {} output/plugin/nnengine/ \; +find output/ -maxdepth 1 -name libge_local_engine.so -exec mv -f {} output/plugin/opskernel/ \; tar -cf graphengine_lib.tar output/* mv -f graphengine_lib.tar output echo "---------------- GraphEngine package archive generated ----------------" diff --git a/inc/common/opskernel/ge_task_info.h b/inc/common/opskernel/ge_task_info.h index 360f8a5d..9f3c409d 100644 --- a/inc/common/opskernel/ge_task_info.h +++ b/inc/common/opskernel/ge_task_info.h @@ -52,5 +52,23 @@ struct GETaskInfo { std::vector kernelHcclInfo; }; + +struct HcomOpertion { + std::string hcclType; + void *inputPtr; + void *outputPtr; + uint64_t count; + int32_t dataType; + int32_t opType; + int32_t root; +}; + +struct HcomRemoteAccessAddrInfo { + uint32_t remotetRankID; + uint64_t remoteAddr; // host embedding table address + uint64_t localAddr; // device HBM address + uint64_t length; // memory Length in Bytes +}; + } // namespace ge #endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ diff --git a/inc/common/opskernel/ops_kernel_info_store.h b/inc/common/opskernel/ops_kernel_info_store.h index 46338e45..ce1464d4 100644 --- a/inc/common/opskernel/ops_kernel_info_store.h +++ b/inc/common/opskernel/ops_kernel_info_store.h @@ -43,10 +43,10 @@ class OpsKernelInfoStore { virtual ~OpsKernelInfoStore() {} // initialize opsKernelInfoStore - virtual Status Initialize(const map &options) = 0; + virtual Status Initialize(const map &options) = 0; /*lint -e148*/ // close opsKernelInfoStore - virtual Status Finalize() = 0; + virtual Status Finalize() = 0; /*lint -e148*/ virtual Status CreateSession(const std::map &session_options) { return SUCCESS; } @@ -66,10 +66,11 @@ class OpsKernelInfoStore { virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; // memory allocation requirement - virtual Status CalcOpRunningParam(Node &node) = 0; + virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/ // generate task for op。 - virtual Status GenerateTask(const Node &node, RunContext &context, std::vector &tasks) = 0; + virtual Status GenerateTask(const Node &node, RunContext &context, + std::vector &tasks) = 0; /*lint -e148*/ // only call fe engine interface to compile single op virtual Status CompileOp(vector &node_vec) { return SUCCESS; } diff --git a/inc/common/opskernel/ops_kernel_info_types.h b/inc/common/opskernel/ops_kernel_info_types.h index d13840bd..684c1abc 100644 --- a/inc/common/opskernel/ops_kernel_info_types.h +++ b/inc/common/opskernel/ops_kernel_info_types.h @@ -26,6 +26,7 @@ using std::string; namespace ge { +/*lint -e148*/ struct RunContext { rtModel_t model; rtStream_t stream; @@ -40,6 +41,8 @@ struct RunContext { std::vector graphLabelList; // all labels of graph, order by ge label id(0,1,...) }; +/*lint +e148*/ + struct Task { uint32_t id; uint16_t type; @@ -48,7 +51,8 @@ struct Task { }; struct OpInfo { - string engine; // which engin + string engine; // which engin + /*lint -e148*/ string opKernelLib; // which opsKernelStore int computeCost; // compute cost bool flagPartial; // whether to support is related to shape diff --git a/inc/common/optimizer/graph_optimizer.h b/inc/common/optimizer/graph_optimizer.h index 5897842f..c330dd63 100644 --- a/inc/common/optimizer/graph_optimizer.h +++ b/inc/common/optimizer/graph_optimizer.h @@ -27,6 +27,7 @@ using std::map; using std::string; +/*lint -e148*/ namespace ge { class GraphOptimizer { public: @@ -60,4 +61,5 @@ class GraphOptimizer { virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; } }; } // namespace ge +/*lint +e148*/ #endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ diff --git a/inc/common/util/compress/compress.h b/inc/common/util/compress/compress.h index 6908fb75..e350f9e5 100644 --- a/inc/common/util/compress/compress.h +++ b/inc/common/util/compress/compress.h @@ -28,6 +28,7 @@ struct CompressConfig { size_t channel; // channels of L2 or DDR. For load balance size_t fractalSize; // size of compressing block bool isTight; // whether compose compressed data tightly + size_t init_offset; }; CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, diff --git a/inc/common/util/compress/compress_weight.h b/inc/common/util/compress/compress_weight.h new file mode 100644 index 00000000..34ea47d1 --- /dev/null +++ b/inc/common/util/compress/compress_weight.h @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef COMPRESS_WEIGHT_H +#define COMPRESS_WEIGHT_H + +#include "compress.h" + +const int SHAPE_SIZE_WEIGHT = 4; + +struct CompressOpConfig { + int64_t wShape[SHAPE_SIZE_WEIGHT]; + size_t compressTilingK; + size_t compressTilingN; + struct CompressConfig compressConfig; +}; + +extern "C" CmpStatus CompressWeightsConv2D(const char *const input, char *const zipBuffer, char *const infoBuffer, + CompressOpConfig *const param); +#endif // COMPRESS_WEIGHT_H diff --git a/inc/common/util/error_manager/error_manager.h b/inc/common/util/error_manager/error_manager.h index 76d5ce33..438e68a7 100644 --- a/inc/common/util/error_manager/error_manager.h +++ b/inc/common/util/error_manager/error_manager.h @@ -31,27 +31,37 @@ class ErrorManager { /// /// @brief init - /// @param [in] path current so path + /// @param [in] path: current so path /// @return int 0(success) -1(fail) /// int Init(std::string path); /// /// @brief Report error message - /// @param [in] errCode error code - /// @param [in] mapArgs parameter map + /// @param [in] error_code: error code + /// @param [in] args_map: parameter map /// @return int 0(success) -1(fail) /// int ReportErrMessage(std::string error_code, const std::map &args_map); + /// /// @brief output error message - /// @param [in] handle print handle + /// @param [in] handle: print handle /// @return int 0(success) -1(fail) /// int OutputErrMessage(int handle); + /// + /// @brief output message + /// @param [in] handle: print handle + /// @return int 0(success) -1(fail) + /// + int OutputMessage(int handle); + + /// /// @brief Report error message - /// @param [in] vector parameter key, vector parameter value + /// @param [in] key: vector parameter key + /// @param [in] value: vector parameter value /// void ATCReportErrMessage(std::string error_code, const std::vector &key = {}, const std::vector &value = {}); @@ -60,7 +70,7 @@ class ErrorManager { struct ErrorInfo { std::string error_id; std::string error_message; - std::vector arglist; + std::vector arg_list; }; ErrorManager() {} @@ -77,7 +87,8 @@ class ErrorManager { bool is_init_ = false; std::map error_map_; - std::vector error_message_evc_; + std::vector error_messages_; + std::vector warning_messages_; }; #endif // ERROR_MANAGER_H_ diff --git a/inc/common/util/platform_info.h b/inc/common/util/platform_info.h index cd143fcc..8d2a0579 100644 --- a/inc/common/util/platform_info.h +++ b/inc/common/util/platform_info.h @@ -27,7 +27,6 @@ using std::string; using std::vector; namespace fe { - class PlatformInfoManager { public: PlatformInfoManager(const PlatformInfoManager &) = delete; @@ -39,6 +38,8 @@ class PlatformInfoManager { uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); + uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); + void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); private: @@ -81,6 +82,8 @@ class PlatformInfoManager { void ParseVectorCoreMemoryRates(map &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); + void ParseCPUCache(map &CPUCacheMap, PlatformInfo &platformInfoTemp); + void ParseVectorCoreintrinsicDtypeMap(map &vectorCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp); @@ -94,6 +97,5 @@ class PlatformInfoManager { map platformInfoMap_; OptionalInfo optiCompilationInfo_; }; - } // namespace fe #endif diff --git a/inc/common/util/platform_info_def.h b/inc/common/util/platform_info_def.h index e840a8b9..c660e8f1 100644 --- a/inc/common/util/platform_info_def.h +++ b/inc/common/util/platform_info_def.h @@ -73,6 +73,8 @@ typedef struct tagAiCoreSpec { typedef struct tagAiCoreMemoryRates { double ddrRate; + double ddrReadRate; + double ddrWriteRate; double l2Rate; double l2ReadRate; double l2WriteRate; @@ -86,6 +88,7 @@ typedef struct tagAiCoreMemoryRates { } AiCoreMemoryRates; typedef struct tagVectorCoreSpec { + double vecFreq; uint64_t vecCalcSize; uint64_t smaskBuffer; uint64_t ubSize; @@ -94,10 +97,15 @@ typedef struct tagVectorCoreSpec { uint64_t ubbankNum; uint64_t ubburstInOneBlock; uint64_t ubbankGroupNum; + uint64_t vectorRegSize; + uint64_t predicateRegSize; + uint64_t addressRegSize; } VectorCoreSpec; typedef struct tagVectorCoreMemoryRates { double ddrRate; + double ddrReadRate; + double ddrWriteRate; double l2Rate; double l2ReadRate; double l2WriteRate; @@ -105,6 +113,11 @@ typedef struct tagVectorCoreMemoryRates { double ubToDdrRate; } VectorCoreMemoryRates; +typedef struct tagCPUCache { + uint32_t AICPUSyncBySW; + uint32_t TSCPUSyncBySW; +} CPUCache; + typedef struct tagPlatformInfo { StrInfo strInfo; SoCInfo socInfo; @@ -113,6 +126,7 @@ typedef struct tagPlatformInfo { map> aiCoreIntrinsicDtypeMap; VectorCoreSpec vectorCoreSpec; VectorCoreMemoryRates vectorCoreMemoryRates; + CPUCache cpucache; map> vectorCoreIntrinsicDtypeMap; } PlatformInfo; diff --git a/inc/external/ge/ge_api_error_codes.h b/inc/external/ge/ge_api_error_codes.h index e7f52724..7b045d54 100644 --- a/inc/external/ge/ge_api_error_codes.h +++ b/inc/external/ge/ge_api_error_codes.h @@ -70,7 +70,7 @@ using Status = uint32_t; // General error code GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success"); -GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); +GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); /*lint !e401*/ } // namespace ge #endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ diff --git a/inc/external/ge/ge_api_types.h b/inc/external/ge/ge_api_types.h index 1632f11c..619812d7 100644 --- a/inc/external/ge/ge_api_types.h +++ b/inc/external/ge/ge_api_types.h @@ -44,8 +44,11 @@ const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; const char *const OPTION_EXEC_DUMP_MODE = "ge.exec.dumpMode"; +const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; +const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; +const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; // profiling flag const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; @@ -170,6 +173,9 @@ const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; // configure whether to use dynamic image size const char *const kDynamicImageSize = "ge.dynamicImageSize"; +// Configure whether to use dynamic dims +const char *const kDynamicDims = "ge.dynamicDims"; + // Configure auto tune mode, this option only take effect while AUTO_TUNE_FLAG is Y, // example: GA|RL, support configure multiple, split by | const std::string AUTO_TUNE_MODE = "ge.autoTuneMode"; @@ -219,6 +225,10 @@ const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; // Configure input fp16 nodes const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; +// Configure debug level, its value should be 0(default), 1 or 2. +// 0: close debug; 1: open TBE compiler; 2: open ccec compiler +const std::string OP_DEBUG_LEVEL = "ge.opDebugLevel"; + // Graph run mode enum GraphRunMode { PREDICTION = 0, TRAIN }; @@ -261,6 +271,7 @@ static const char *const INPUT_SHAPE = "input_shape"; static const char *const OP_NAME_MAP = "op_name_map"; static const char *const DYNAMIC_BATCH_SIZE = kDynamicBatchSize; static const char *const DYNAMIC_IMAGE_SIZE = kDynamicImageSize; +static const char *const DYNAMIC_DIMS = kDynamicDims; static const char *const INSERT_OP_FILE = ge::INSERT_OP_FILE.c_str(); static const char *const PRECISION_MODE = ge::PRECISION_MODE.c_str(); static const char *const EXEC_DISABLE_REUSED_MEMORY = ge::OPTION_EXEC_DISABLE_REUSED_MEMORY; @@ -283,10 +294,11 @@ static const char *const OPTYPELIST_FOR_IMPLMODE = ge::OPTYPELIST_FOR_IMPLMODE.c // for interface: aclgrphBuildModel const std::set ir_builder_suppported_options = { - INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, DYNAMIC_BATCH_SIZE, - DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, - AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, INPUT_FP16_NODES, - LOG_LEVEL}; + INPUT_FORMAT, INPUT_SHAPE, OP_NAME_MAP, + DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, DYNAMIC_DIMS, + INSERT_OP_FILE, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, + AUTO_TUNE_MODE, OUTPUT_TYPE, OUT_NODES, + INPUT_FP16_NODES, LOG_LEVEL}; // for interface: aclgrphBuildInitialize const std::set global_options = {CORE_TYPE, SOC_VERSION, diff --git a/inc/external/graph/attr_value.h b/inc/external/graph/attr_value.h index 32fce04c..af430f9b 100644 --- a/inc/external/graph/attr_value.h +++ b/inc/external/graph/attr_value.h @@ -34,6 +34,7 @@ using std::vector; namespace ge { class AttrValueImpl; +/*lint -e148*/ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { public: using INT = int64_t; @@ -69,5 +70,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { VALUE_SET_GET_DEC(AttrValue::FLOAT) #undef VALUE_SET_GET_DEC }; +/*lint +e148*/ } // namespace ge #endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ diff --git a/inc/external/graph/operator.h b/inc/external/graph/operator.h index 4f837b9d..2dcdb773 100644 --- a/inc/external/graph/operator.h +++ b/inc/external/graph/operator.h @@ -61,6 +61,7 @@ using std::function; using std::shared_ptr; using std::string; +/*lint -e148*/ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { public: friend class OperatorImpl; @@ -88,7 +89,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { explicit Operator(const string &type); - Operator(const string &name, const string &type); + Operator(const string &name, const string &type); // lint !e148 virtual ~Operator() = default; @@ -101,7 +102,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { // Only has one output index = 0 Operator &SetInput(const string &dst_name, const Operator &src_oprt); - Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); + Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 Operator &AddControlInput(const Operator &src_oprt); @@ -123,22 +124,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { TensorDesc GetOutputDesc(uint32_t index) const; - graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); + graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; - graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); + graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; - graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); + graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 - graphStatus InferShapeAndType(); + graphStatus InferShapeAndType(); // lint !e148 void SetInferenceContext(const InferenceContextPtr &inference_context); InferenceContextPtr GetInferenceContext() const; - graphStatus VerifyAllAttr(bool disable_common_verifier = false); + graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 size_t GetInputsSize() const; @@ -251,19 +252,20 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { void RequiredAttrRegister(const string &name); - graphStatus VerifyAll(); + graphStatus VerifyAll(); // lint !e148 // Only has one output index = 0 Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); - Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); + Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, + const string &name); // lint !e148 void SubgraphRegister(const string &ir_name, bool dynamic); void SubgraphCountRegister(const string &ir_name, uint32_t count); void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); private: - Operator &SetInput(const string &dst_name, const OutHandler &out_handler); + Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 OutHandler GetOutput(const string &name) const; @@ -273,6 +275,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { graphStatus GetInputConstDataOut(const string &dst_name, Tensor &data) const; }; +/*lint +e148*/ } // namespace ge #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ diff --git a/inc/external/graph/operator_reg.h b/inc/external/graph/operator_reg.h index dfa21acf..d155f4bd 100644 --- a/inc/external/graph/operator_reg.h +++ b/inc/external/graph/operator_reg.h @@ -343,6 +343,7 @@ class OpReg { auto x_type = op.GetInputDesc(in_name).GetDataType(); \ TensorDesc op_output_desc = op.GetOutputDesc(out_name); \ op_output_desc.SetShape(ge::Shape(x_shape)); \ + op_output_desc.SetOriginShape(ge::Shape(x_shape)); \ op_output_desc.SetDataType(x_type); \ return op.UpdateOutputDesc(out_name, op_output_desc); \ } diff --git a/inc/external/graph/tensor.h b/inc/external/graph/tensor.h index 5174c248..800e1037 100644 --- a/inc/external/graph/tensor.h +++ b/inc/external/graph/tensor.h @@ -126,5 +126,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { friend class TensorAdapter; }; } // namespace ge +/*lint +e148*/ #endif // INC_EXTERNAL_GRAPH_TENSOR_H_ diff --git a/inc/external/graph/types.h b/inc/external/graph/types.h index 4cd9ba91..a1245c9d 100644 --- a/inc/external/graph/types.h +++ b/inc/external/graph/types.h @@ -145,7 +145,8 @@ enum Format { FORMAT_FRACTAL_ZN_LSTM, FORMAT_FRACTAL_Z_G, FORMAT_RESERVED, - FORMAT_ALL + FORMAT_ALL, + FORMAT_NULL }; // for unknown shape op type diff --git a/inc/external/register/register.h b/inc/external/register/register.h index a8421511..e905e8d4 100644 --- a/inc/external/register/register.h +++ b/inc/external/register/register.h @@ -40,6 +40,7 @@ using std::to_string; using std::unique_ptr; using std::vector; +/*lint -e148*/ namespace ge { class Operator; class TensorDesc; @@ -98,6 +99,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { OpRegistrationData &DelInputWithOriginalType(int input_idx, const std::string &ori_type); + OpRegistrationData &InputReorderVector(const vector &input_order); + domi::ImplyType GetImplyType() const; std::string GetOmOptype() const; std::set GetOriginOpTypeSet() const; @@ -130,4 +133,5 @@ namespace ge { using OpRegistrationData = domi::OpRegistrationData; using OpReceiver = domi::OpReceiver; } // namespace ge +/*lint +e148*/ #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ diff --git a/inc/framework/common/debug/ge_log.h b/inc/framework/common/debug/ge_log.h index e2023cb8..6ac00037 100644 --- a/inc/framework/common/debug/ge_log.h +++ b/inc/framework/common/debug/ge_log.h @@ -51,30 +51,6 @@ inline pid_t GetTid() { return tid; } -#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() - -#define GE_TIMESTAMP_END(stage, stage_name) \ - do { \ - uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ - GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ - (endUsec_##stage - startUsec_##stage)); \ - } while (0); - -#define GE_TIMESTAMP_CALLNUM_START(stage) \ - uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ - uint64_t call_num_of##stage = 0; \ - uint64_t time_of##stage = 0 - -#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) - -#define GE_TIMESTAMP_ADD(stage) \ - time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ - call_num_of##stage++ - -#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ - GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ - call_num_of##stage) - #define GE_LOG_ERROR(MOD_NAME, ERROR_CODE, fmt, ...) \ dlog_error(MOD_NAME, "%lu %s: ErrorNo: %d(%s) " fmt, GetTid(), __FUNCTION__, ERROR_CODE, \ ((GE_GET_ERRORNO_STR(ERROR_CODE)).c_str()), ##__VA_ARGS__) diff --git a/inc/framework/common/debug/log.h b/inc/framework/common/debug/log.h index 28c6585e..dbf22ead 100644 --- a/inc/framework/common/debug/log.h +++ b/inc/framework/common/debug/log.h @@ -19,15 +19,12 @@ #include -#include "cce/cce_def.hpp" +#include "runtime/rt.h" #include "common/string_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "ge/ge_api_error_codes.h" -using cce::CC_STATUS_SUCCESS; -using cce::ccStatus_t; - #if !defined(__ANDROID__) && !defined(ANDROID) #define DOMI_LOGE(...) GE_LOG_ERROR(GE_MODULE_NAME, ge::FAILED, __VA_ARGS__) #else @@ -102,17 +99,13 @@ using cce::ccStatus_t; } while (0); // If expr is not true, print the log and return the specified status -#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ - do { \ - bool b = (expr); \ - if (!b) { \ - std::string msg; \ - (void)msg.append(ge::StringUtils::FormatString(__VA_ARGS__)); \ - (void)msg.append( \ - ge::StringUtils::FormatString(" Error Code:0x%X(%s)", _status, GET_ERRORNO_STR(_status).c_str())); \ - DOMI_LOGE("%s", msg.c_str()); \ - return _status; \ - } \ +#define GE_CHK_BOOL_RET_STATUS(expr, _status, ...) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(_status, __VA_ARGS__); \ + return _status; \ + } \ } while (0); // If expr is not true, print the log and return the specified status @@ -132,7 +125,7 @@ using cce::ccStatus_t; DOMI_LOGE(__VA_ARGS__); \ exec_expr; \ } \ - }; + } // If expr is not true, print the log and execute a custom statement #define GE_CHK_BOOL_EXEC_WARN(expr, exec_expr, ...) \ @@ -142,7 +135,7 @@ using cce::ccStatus_t; GELOGW(__VA_ARGS__); \ exec_expr; \ } \ - }; + } // If expr is not true, print the log and execute a custom statement #define GE_CHK_BOOL_EXEC_INFO(expr, exec_expr, ...) \ { \ @@ -151,7 +144,7 @@ using cce::ccStatus_t; GELOGI(__VA_ARGS__); \ exec_expr; \ } \ - }; + } // If expr is not true, print the log and execute a custom statement #define GE_CHK_BOOL_TRUE_EXEC_INFO(expr, exec_expr, ...) \ @@ -161,7 +154,7 @@ using cce::ccStatus_t; GELOGI(__VA_ARGS__); \ exec_expr; \ } \ - }; + } // If expr is true, print logs and execute custom statements #define GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(expr, exec_expr, ...) \ @@ -171,7 +164,7 @@ using cce::ccStatus_t; DOMI_LOGE(__VA_ARGS__); \ exec_expr; \ } \ - }; + } // If expr is true, print the Information log and execute a custom statement #define GE_CHK_TRUE_EXEC_INFO(expr, exec_expr, ...) \ { \ @@ -180,7 +173,7 @@ using cce::ccStatus_t; GELOGI(__VA_ARGS__); \ exec_expr; \ } \ - }; + } // If expr is not SUCCESS, print the log and execute the expression + return #define GE_CHK_BOOL_TRUE_RET_VOID(expr, exec_expr, ...) \ @@ -191,7 +184,7 @@ using cce::ccStatus_t; exec_expr; \ return; \ } \ - }; + } // If expr is not SUCCESS, print the log and execute the expression + return _status #define GE_CHK_BOOL_TRUE_EXEC_RET_STATUS(expr, _status, exec_expr, ...) \ @@ -202,7 +195,7 @@ using cce::ccStatus_t; exec_expr; \ return _status; \ } \ - }; + } // If expr is not true, execute a custom statement #define GE_CHK_BOOL_EXEC_NOLOG(expr, exec_expr) \ @@ -211,7 +204,7 @@ using cce::ccStatus_t; if (!b) { \ exec_expr; \ } \ - }; + } // -----------------runtime related macro definitions------------------------------- // If expr is not RT_ERROR_NONE, print the log @@ -231,7 +224,7 @@ using cce::ccStatus_t; DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ exec_expr; \ } \ - }; + } // If expr is not RT_ERROR_NONE, print the log and return #define GE_CHK_RT_RET(expr) \ @@ -239,27 +232,17 @@ using cce::ccStatus_t; rtError_t _rt_ret = (expr); \ if (_rt_ret != RT_ERROR_NONE) { \ DOMI_LOGE("Call rt api failed, ret: 0x%X", _rt_ret); \ - return ge::RT_FAILED; \ + return RT_ERROR_TO_GE_STATUS(_rt_ret); \ } \ } while (0); -// ------------------------cce related macro definitions---------------------------- -// If expr is not CC_STATUS_SUCCESS, print the log -#define GE_CHK_CCE(expr) \ - do { \ - ccStatus_t _cc_ret = (expr); \ - if (_cc_ret != CC_STATUS_SUCCESS) { \ - DOMI_LOGE("Call cce api failed, ret: 0x%X", _cc_ret); \ - } \ - } while (0); - // If expr is true, execute exec_expr without printing logs #define GE_IF_BOOL_EXEC(expr, exec_expr) \ { \ if (expr) { \ exec_expr; \ } \ - }; + } // If make_shared is abnormal, print the log and execute the statement #define GE_MAKE_SHARED(exec_expr0, exec_expr1) \ diff --git a/inc/framework/common/ge_inner_error_codes.h b/inc/framework/common/ge_inner_error_codes.h index c4a36597..ca727589 100644 --- a/inc/framework/common/ge_inner_error_codes.h +++ b/inc/framework/common/ge_inner_error_codes.h @@ -14,6 +14,7 @@ * limitations under the License. */ +/*lint -e* */ #ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ #define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ @@ -280,8 +281,24 @@ GE_ERRORNO_RUNTIME(GE_RTI_CALL_HCCL_REDUCE_SCATTER_FAILED, 47, "call hccl hcom r // Executor module error code definition GE_ERRORNO_EXECUTOR(GE_EXEC_NOT_INIT, 1, "GE Executor is not yet initialized."); -GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 2, "GE AIPP is not exist."); -GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 3, "GE Dynamic AIPP is not support to query temporarily."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PATH_INVALID, 2, "Model file path is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_KEY_PATH_INVALID, 3, "Key file path of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_ID_INVALID, 4, "Model id is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_DATA_SIZE_INVALID, 5, "Data size of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_PARTITION_NUM_INVALID, 6, "Partition number of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_QUEUE_ID_INVALID, 7, "Queue id of model is invalid."); +GE_ERRORNO_EXECUTOR(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, 8, "Model does not support encryption."); +GE_ERRORNO_EXECUTOR(GE_EXEC_READ_MODEL_FILE_FAILED, 9, "Failed to read model file."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_REPEATED, 10, "The model is loaded repeatedly."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_MODEL_PARTITION_FAILED, 11, "Failed to load model partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, 12, "Failed to load weight partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_TASK_PARTITION_FAILED, 13, "Failed to load task partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_LOAD_KERNEL_PARTITION_FAILED, 14, "Failed to load kernel partition."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, 15, "Failed to allocate feature map memory."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate weight memory."); +GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory."); +GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist."); +GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily."); // Generator module error code definition GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed."); @@ -289,6 +306,8 @@ GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, 2, "Graph mana GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, 3, "Graph manager build graph failed."); GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, 4, "Graph manager finalize failed."); GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_SAVE_MODEL_FAILED, 5, "Graph manager save model failed."); + +#define RT_ERROR_TO_GE_STATUS(RT_ERROR) static_cast(RT_ERROR) } // namespace ge #endif // INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ diff --git a/inc/framework/common/ge_types.h b/inc/framework/common/ge_types.h index 27ae28ee..00bfa301 100644 --- a/inc/framework/common/ge_types.h +++ b/inc/framework/common/ge_types.h @@ -54,9 +54,9 @@ const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; struct DataBuffer { public: void *data; // Data address - uint32_t length; // Data length + uint64_t length; // Data length bool isDataSupportMemShare = false; - DataBuffer(void *dataIn, uint32_t len, bool isSupportMemShare) + DataBuffer(void *dataIn, uint64_t len, bool isSupportMemShare) : data(dataIn), length(len), isDataSupportMemShare(isSupportMemShare) {} DataBuffer() : data(nullptr), length(0), isDataSupportMemShare(false) {} @@ -106,7 +106,7 @@ struct ShapeDescription { // Definition of input and output description information struct InputOutputDescInfo { std::string name; - uint32_t size; + uint64_t size; uint32_t data_type; ShapeDescription shape_info; }; @@ -231,6 +231,7 @@ struct Options { // Profiling info of task struct TaskDescInfo { + std::string model_name; std::string op_name; uint32_t block_dim; uint32_t task_id; @@ -239,6 +240,7 @@ struct TaskDescInfo { // Profiling info of graph struct ComputeGraphDescInfo { + std::string model_name; std::string op_name; std::string op_type; std::vector input_format; diff --git a/inc/framework/common/helper/model_helper.h b/inc/framework/common/helper/model_helper.h index 3c9de891..3671f970 100644 --- a/inc/framework/common/helper/model_helper.h +++ b/inc/framework/common/helper/model_helper.h @@ -44,8 +44,6 @@ class ModelHelper { void SetSaveMode(bool val) { is_offline_ = val; } bool GetSaveMode(void) const { return is_offline_; } - static Status TransModelToGeModel(const ModelPtr& model, GeModelPtr& ge_model); - static Status TransGeModelToModel(const GeModelPtr& geModelPtr, ModelPtr& modelPtr); Status GetBaseNameFromFileName(const std::string& file_name, std::string& base_name); Status GetModelNameFromMergedGraphName(const std::string& graph_name, std::string& model_name); diff --git a/inc/framework/common/string_util.h b/inc/framework/common/string_util.h index b74eddcf..918a3950 100644 --- a/inc/framework/common/string_util.h +++ b/inc/framework/common/string_util.h @@ -36,8 +36,8 @@ class StringUtils { #endif return s; } - - static std::string &Rtrim(std::string &s) { + // lint -esym(551,*) + static std::string &Rtrim(std::string &s) { /*lint !e618*/ #if __cplusplus >= 201103L (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); #else @@ -45,7 +45,7 @@ class StringUtils { #endif return s; } - + // lint -esym(551,*) /// /// @ingroup domi_common /// @brief delete spaces at the beginning and end of a string diff --git a/inc/framework/common/types.h b/inc/framework/common/types.h index e3844a61..db692c36 100644 --- a/inc/framework/common/types.h +++ b/inc/framework/common/types.h @@ -48,6 +48,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_S FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_LAYER; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_FILE_PATH; FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string DUMP_MODE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_AICORE; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ATOMIC; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_DEBUG_ALL; // Supported public properties name FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string PROP_OME_START_TIME; // Start time @@ -335,6 +338,8 @@ REGISTER_OPTYPE_DECLARE(BASICLSTMCELL, "BasicLSTMCell"); REGISTER_OPTYPE_DECLARE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DECLARE(INITDATA, "InitData"); REGISTER_OPTYPE_DECLARE(TRANSSHAPE, "TransShape") +REGISTER_OPTYPE_DECLARE(REFIDENTITY, "RefIdentity"); +REGISTER_OPTYPE_DECLARE(BITCAST, "Bitcast"); // ANN dedicated operator REGISTER_OPTYPE_DECLARE(ANN_MEAN, "AnnMean"); @@ -428,6 +433,8 @@ REGISTER_OPTYPE_DECLARE(HCOMALLREDUCE, "HcomAllReduce"); REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); +REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); REGISTER_OPTYPE_DECLARE(VARISINITIALIZEDOP, "VarIsInitializedOp"); @@ -554,6 +561,16 @@ enum ModelCheckType { UNCHECK // no verification }; +/// +/// @brief dynamic input type +/// +enum DynamicInputType { + FIXED = 0, // default mode + DYNAMIC_BATCH = 1, + DYNAMIC_IMAGE = 2, + DYNAMIC_DIMS = 3 +}; + /// /// @brief magic number of the model file /// @@ -631,6 +648,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_N FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_END_GRAPH; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string NODE_NAME_OP_DEBUG; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_OP_DEBUG; + // convolution node type FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY extern const std::string OP_TYPE_CONVOLUTION; // adds a convolutional node name for the hard AIPP diff --git a/inc/framework/executor/ge_executor.h b/inc/framework/executor/ge_executor.h index 91b50311..129b8613 100644 --- a/inc/framework/executor/ge_executor.h +++ b/inc/framework/executor/ge_executor.h @@ -21,28 +21,31 @@ #include #include +#include "common/dynamic_aipp.h" #include "common/ge_inner_error_codes.h" #include "common/ge_types.h" #include "common/types.h" #include "graph/tensor.h" +#include "graph/ge_tensor.h" #include "runtime/base.h" -#include "common/dynamic_aipp.h" namespace ge { class ModelListenerAdapter; class SingleOp; +class DynamicSingleOp; struct RunModelData { uint32_t index; // Data index uint32_t modelId; - std::vector blobs; // All input/output data buffer - uint32_t timestamp; // Data creation time - uint32_t timeout; // Processing timeout - uint64_t request_id = 0; // Request ID - uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 - uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 - uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 + std::vector blobs; // All input/output data buffer + uint32_t timestamp; // Data creation time + uint32_t timeout; // Processing timeout + uint64_t request_id = 0; // Request ID + uint64_t dynamic_batch_size = 0; // Dynamic batch size scene, set dynamic size, not supported by default:0 + uint64_t dynamic_image_height = 0; // Dynamic image size scene, set image height, not supported by default:0 + uint64_t dynamic_image_width = 0; // Dynamic image size scene, set image width, not supported by default:0 + std::vector dynamic_dims; // Dynamic dims scene, set dynamic dims, not supported by default:empty }; class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { @@ -87,16 +90,52 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { /// ge::Status SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, uint64_t image_width); + + /// + /// @ingroup ge + /// @brief Set dynamic dims info + /// @param [in] model_id: model id allocate from manager + /// @param [in] dynamic_input_addr: dynamic input addr created by user + /// @param [in] length: length of dynamic input addr + /// @param [in] dynamic_dim_num: number of dynamic dimension + /// @param [in] dynamic_dims: array of dynamic dimensions + /// @return execute result + /// + ge::Status SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length, + const std::vector &dynamic_dims); + + /// + /// @ingroup ge + /// @brief Get current dynamic dims info by combined dims + /// @param [in] model_id: model id allocate from manager + /// @param [in] combined_dims: array of combined dimensions + /// @param [out] cur_dynamic_dims: current dynamic dims + /// @return execute result + /// + ge::Status GetCurDynamicDims(uint32_t model_id, const std::vector &combined_dims, + std::vector &cur_dynamic_dims); + /// /// @ingroup ge /// @brief Get dynamic batch_info /// @param [in] model_id /// @param [out] batch_info + /// @param [out] dynamic_type /// @return execute result /// - ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info); + ge::Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type); - ge::Status GetCurShape(const uint32_t model_id, std::vector &batch_info); + /// + /// @ingroup ge + /// @brief Get combined dynamic dims info + /// @param [in] model_id + /// @param [out] batch_info + /// @return execute result + /// + ge::Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info); + + ge::Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); /// /// @ingroup ge @@ -209,6 +248,13 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { static ge::Status ExecuteAsync(SingleOp *executor, const std::vector &inputs, std::vector &outputs); + static ge::Status LoadDynamicSingleOp(const std::string &model_name, const ge::ModelData &modelData, void *stream, + DynamicSingleOp **single_op); + + static ge::Status ExecuteAsync(DynamicSingleOp *executor, const std::vector &input_desc, + const std::vector &inputs, std::vector &output_desc, + std::vector &outputs); + static ge::Status ReleaseSingleOpResource(void *stream); ge::Status GetBatchInfoSize(uint32_t model_id, size_t &shape_count); diff --git a/inc/framework/ge_runtime/model_runner.h b/inc/framework/ge_runtime/model_runner.h index 6e7abcb9..e495dfdf 100644 --- a/inc/framework/ge_runtime/model_runner.h +++ b/inc/framework/ge_runtime/model_runner.h @@ -28,7 +28,7 @@ namespace ge { namespace model_runner { class RuntimeModel; - +using RuntimeInfo = std::tuple; class ModelRunner { public: static ModelRunner &Instance(); @@ -36,8 +36,18 @@ class ModelRunner { bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, std::shared_ptr davinci_model, std::shared_ptr listener); + bool DistributeTask(uint32_t model_id); + + bool LoadModelComplete(uint32_t model_id); + const std::vector &GetTaskIdList(uint32_t model_id) const; + const std::vector &GetStreamIdList(uint32_t model_id) const; + + const std::map> &GetRuntimeInfoMap(uint32_t model_id) const; + + void *GetModelHandle(uint32_t model_id) const; + bool UnloadModel(uint32_t model_id); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index a48ed68b..68d71870 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -21,6 +21,7 @@ #include #include #include +#include #include #include "cce/taskdown_api.h" @@ -52,21 +53,27 @@ class TaskInfo { virtual ~TaskInfo() {} uint32_t stream_id() const { return stream_id_; } TaskInfoType type() const { return type_; } + std::string op_name() const { return op_name_; } + bool dump_flag() const { return dump_flag_; } protected: - TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} + TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag) + : op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {} private: + std::string op_name_; uint32_t stream_id_; TaskInfoType type_; + bool dump_flag_; }; class CceTaskInfo : public TaskInfo { public: - CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, - const std::vector &args, uint32_t args_size, const std::vector &sm_desc, - const std::vector &flow_table, const std::vector &args_offset, bool is_flowtable) - : TaskInfo(stream_id, TaskInfoType::CCE), + CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, + uint32_t block_dim, const std::vector &args, uint32_t args_size, + const std::vector &sm_desc, const std::vector &flow_table, + const std::vector &args_offset, bool is_flowtable) + : TaskInfo(op_name, stream_id, TaskInfoType::CCE, false), ctx_(ctx), stub_func_(stub_func), block_dim_(block_dim), @@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo { class TbeTaskInfo : public TaskInfo { public: - TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector &args, - uint32_t args_size, const std::vector &sm_desc, void *binary, uint32_t binary_size, - const std::vector &meta_data, const std::vector &input_data_addrs, - const std::vector &output_data_addrs, const std::vector &workspace_addrs) - : TaskInfo(stream_id, TaskInfoType::TBE), + TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, + const std::vector &args, uint32_t args_size, const std::vector &sm_desc, void *binary, + uint32_t binary_size, const std::vector &meta_data, const std::vector &input_data_addrs, + const std::vector &output_data_addrs, const std::vector &workspace_addrs, bool dump_flag) + : TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag), stub_func_(stub_func), block_dim_(block_dim), args_(args), @@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo { class AicpuTaskInfo : public TaskInfo { public: - AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, - const std::vector &input_data_addrs, const std::vector &output_data_addrs) - : TaskInfo(stream_id, TaskInfoType::AICPU), + AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name, + const std::string &node_def, const std::vector &input_data_addrs, + const std::vector &output_data_addrs, bool dump_flag) + : TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag), so_name_(so_name), kernel_name_(kernel_name), node_def_(node_def), @@ -177,37 +185,45 @@ class AicpuTaskInfo : public TaskInfo { std::vector output_data_addrs_; }; -class LabelTaskInfo : public TaskInfo { +class LabelSetTaskInfo : public TaskInfo { public: + LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) + : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {} + ~LabelSetTaskInfo() override {} uint32_t label_id() const { return label_id_; } - protected: - LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) - : TaskInfo(stream_id, type), label_id_(label_id) {} - virtual ~LabelTaskInfo() override {} - + private: uint32_t label_id_; }; -class LabelSetTaskInfo : public LabelTaskInfo { +class LabelGotoTaskInfo : public TaskInfo { public: - LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id) - : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {} - ~LabelSetTaskInfo() override {} + LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id) + : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {} + ~LabelGotoTaskInfo() override {} + uint32_t label_id() const { return label_id_; } + + private: + uint32_t label_id_; }; -class LabelSwitchTaskInfo : public LabelTaskInfo { +class LabelSwitchTaskInfo : public TaskInfo { public: - LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id) - : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {} + LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size, + const std::vector &label_list, void *cond) + : TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false), + label_size_(label_size), + label_list_(label_list), + cond_(cond) {} ~LabelSwitchTaskInfo() override {} -}; + uint32_t label_size() { return label_size_; }; + const std::vector &label_list() { return label_list_; }; + void *cond() { return cond_; }; -class LabelGotoTaskInfo : public LabelTaskInfo { - public: - LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id) - : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {} - ~LabelGotoTaskInfo() override {} + private: + uint32_t label_size_; + std::vector label_list_; + void *cond_; }; class EventTaskInfo : public TaskInfo { @@ -215,8 +231,8 @@ class EventTaskInfo : public TaskInfo { uint32_t event_id() const { return event_id_; } protected: - EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) - : TaskInfo(stream_id, type), event_id_(event_id) {} + EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) + : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} virtual ~EventTaskInfo() override {} uint32_t event_id_; @@ -224,39 +240,41 @@ class EventTaskInfo : public TaskInfo { class EventRecordTaskInfo : public EventTaskInfo { public: - EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id) - : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {} + EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) + : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {} ~EventRecordTaskInfo() override {} }; class EventWaitTaskInfo : public EventTaskInfo { public: - EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id) - : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {} + EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id) + : EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {} ~EventWaitTaskInfo() override {} }; class FusionStartTaskInfo : public TaskInfo { public: - explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {} + explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id) + : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {} ~FusionStartTaskInfo() override {} }; class FusionEndTaskInfo : public TaskInfo { public: - explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {} + explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id) + : TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {} ~FusionEndTaskInfo() override {} }; class HcclTaskInfo : public TaskInfo { public: - HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, - void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, + HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, + void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, const std::vector &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, - int64_t op_type, int64_t data_type, std::function hcom_bind_model, - std::function hcom_unbind_model, - std::function, void *)> hcom_distribute_task) - : TaskInfo(stream_id, TaskInfoType::HCCL), + int64_t op_type, int64_t data_type, const std::string &group, + std::function hcom_bind_model, std::function hcom_unbind_model, + std::function, void *)> hcom_distribute_task, bool dump_flag) + : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), hccl_type_(hccl_type), input_data_addr_(input_data_addr), output_data_addr_(output_data_addr), @@ -269,6 +287,7 @@ class HcclTaskInfo : public TaskInfo { root_id_(root_id), op_type_(op_type), data_type_(data_type), + group_(group), hcom_bind_model_(hcom_bind_model), hcom_unbind_model_(hcom_unbind_model), hcom_distribute_task_(hcom_distribute_task) {} @@ -286,6 +305,7 @@ class HcclTaskInfo : public TaskInfo { int64_t root_id() const { return root_id_; } int64_t op_type() const { return op_type_; } int64_t data_type() const { return data_type_; } + const std::string &group() const { return group_; } std::function hcom_bind_model() const { return hcom_bind_model_; } std::function hcom_unbind_model() const { return hcom_unbind_model_; } std::function, void *)> hcom_distribute_task() const { @@ -305,6 +325,7 @@ class HcclTaskInfo : public TaskInfo { int64_t root_id_; int64_t op_type_; int64_t data_type_; + std::string group_; std::function hcom_bind_model_; std::function hcom_unbind_model_; std::function, void *)> hcom_distribute_task_; @@ -312,8 +333,11 @@ class HcclTaskInfo : public TaskInfo { class ProfilerTraceTaskInfo : public TaskInfo { public: - ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) - : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {} + ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat) + : TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false), + log_id_(log_id), + notify_(notify), + flat_(flat) {} ~ProfilerTraceTaskInfo() override {} uint64_t log_id() const { return log_id_; } @@ -328,8 +352,9 @@ class ProfilerTraceTaskInfo : public TaskInfo { class MemcpyAsyncTaskInfo : public TaskInfo { public: - MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) - : TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), + MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src, + uint64_t count, uint32_t kind, bool dump_flag) + : TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag), dst_(dst), dst_max_(dst_max), src_(src), @@ -353,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo { class StreamSwitchTaskInfo : public TaskInfo { public: - StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond, - int64_t data_type) - : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH), + StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr, + void *value_addr, int64_t cond, int64_t data_type) + : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false), true_stream_id_(true_stream_id), input_addr_(input_addr), value_addr_(value_addr), @@ -379,8 +404,8 @@ class StreamSwitchTaskInfo : public TaskInfo { class StreamActiveTaskInfo : public TaskInfo { public: - StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id) - : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {} + StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id) + : TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {} ~StreamActiveTaskInfo() override {} uint32_t active_stream_id() const { return active_stream_id_; } diff --git a/inc/framework/generator/ge_generator.h b/inc/framework/generator/ge_generator.h index f0707c67..d3f472e9 100644 --- a/inc/framework/generator/ge_generator.h +++ b/inc/framework/generator/ge_generator.h @@ -27,6 +27,7 @@ #include "graph/ge_tensor.h" #include "graph/graph.h" #include "graph/op_desc.h" +#include "graph/detail/attributes_holder.h" namespace ge { class GeGenerator { diff --git a/inc/framework/memory/memory_api.h b/inc/framework/memory/memory_api.h new file mode 100644 index 00000000..656e4710 --- /dev/null +++ b/inc/framework/memory/memory_api.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_FRAMEWORK_MEMORY_MEMORY_API_H_ +#define INC_FRAMEWORK_MEMORY_MEMORY_API_H_ + +#include +#include + +#include "ge/ge_api_error_codes.h" +#include "runtime/mem.h" + +namespace ge { +enum MemStorageType { + HBM = 0, + RDMA_HBM, +}; + +struct HostVarInfo { + uint64_t base_addr; + uint64_t var_size; +}; + +/// +/// \param size [in] rdma pool memory size to be allocated. +/// \param mem_type [in] memory type for rdma pool. +/// \return Status result of function +Status InitRdmaPool(size_t size, rtMemType_t mem_type = RT_MEMORY_HBM); + +/// +/// \param var_info [in] host variable addr infos. +/// \param mem_type [in] memory type for rdma pool. +/// \return Status result of function +Status RdmaRemoteRegister(const std::vector &var_info, rtMemType_t mem_type = RT_MEMORY_HBM); + +/// +/// \param var_name [in] var_name name of host variable. +/// \param base_addr [out] base_addr vase addr of host variable. +/// \param var_size [out] var_size memory_size of host variable. +/// \return Status result of function +Status GetVarBaseAddrAndSize(const std::string &var_name, uint64_t &base_addr, uint64_t &var_size); +} // namespace ge +#endif // INC_FRAMEWORK_MEMORY_MEMORY_API_H_ diff --git a/inc/framework/omg/omg.h b/inc/framework/omg/omg.h index c7dbdd5b..6a120439 100644 --- a/inc/framework/omg/omg.h +++ b/inc/framework/omg/omg.h @@ -96,17 +96,12 @@ Status CheckCustomAiCpuOpLib(); Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); -Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); - -Status GetOutputLeaf(ge::NodePtr node, std::vector> &output_nodes_info); - void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name); void UpdateOmgCtxWithParserCtx(); void UpdateParserCtxWithOmgCtx(); - } // namespace ge namespace domi { diff --git a/inc/framework/omg/omg_inner_types.h b/inc/framework/omg/omg_inner_types.h index 70d59c2f..80361232 100644 --- a/inc/framework/omg/omg_inner_types.h +++ b/inc/framework/omg/omg_inner_types.h @@ -120,6 +120,7 @@ struct OmgContext { bool is_dynamic_input = false; std::string dynamic_batch_size; std::string dynamic_image_size; + std::string dynamic_dims; }; } // namespace ge diff --git a/inc/graph/buffer.h b/inc/graph/buffer.h index e6be3daa..ca4355a7 100644 --- a/inc/graph/buffer.h +++ b/inc/graph/buffer.h @@ -57,11 +57,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { // For compatibility inline const std::uint8_t *data() const { return GetData(); } - inline std::uint8_t *data() { return GetData(); } + inline std::uint8_t *data() { return GetData(); } // lint !e659 inline std::size_t size() const { return GetSize(); } inline void clear() { return ClearBuffer(); } - uint8_t operator[](size_t index) const { - if (buffer_ != nullptr && index < buffer_->size()) { + uint8_t operator[](size_t index) const { // lint !e1022 !e1042 + if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 return (uint8_t)(*buffer_)[index]; } return 0xff; diff --git a/inc/graph/compute_graph.h b/inc/graph/compute_graph.h index 4f865f12..8d3db43c 100644 --- a/inc/graph/compute_graph.h +++ b/inc/graph/compute_graph.h @@ -74,6 +74,9 @@ class ComputeGraph : public std::enable_shared_from_this, public A size_t GetAllNodesSize() const; Vistor GetAllNodes() const; + // is_unknown_shape: false, same with GetAllNodes func + // is_unknown_shape: true, same with GetDirectNodes func + Vistor GetNodes(bool is_unknown_shape) const; size_t GetDirectNodesSize() const; Vistor GetDirectNode() const; Vistor GetInputNodes() const; @@ -81,14 +84,18 @@ class ComputeGraph : public std::enable_shared_from_this, public A NodePtr FindNode(const std::string &name) const; NodePtr FindFirstNodeMatchType(const std::string &name) const; + /*lint -e504*/ // AddNode with NodePtr NodePtr AddNode(NodePtr node); NodePtr AddNode(OpDescPtr op); - NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize. + NodePtr AddNode(OpDescPtr op, int64_t id); // for unserialize NodePtr AddNodeFront(NodePtr node); NodePtr AddNodeFront(const OpDescPtr &op); NodePtr AddInputNode(NodePtr node); NodePtr AddOutputNode(NodePtr node); + // insert node with specific pre_node + NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); + NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); graphStatus RemoveNode(const NodePtr &node); graphStatus RemoveInputNode(const NodePtr &node); @@ -133,6 +140,8 @@ class ComputeGraph : public std::enable_shared_from_this, public A bool IsValid() const; void Dump() const; + void Swap(ComputeGraph &graph); + graphStatus IsolateNode(const NodePtr &node); graphStatus Verify(); graphStatus InferShape(); @@ -141,6 +150,7 @@ class ComputeGraph : public std::enable_shared_from_this, public A graphStatus InsertEventNodes(); bool operator==(const ComputeGraph &r_compute_graph) const; + /*lint +e504*/ const std::map, std::vector> &GetShareParamLayer() const { return params_share_map_; } @@ -174,6 +184,10 @@ class ComputeGraph : public std::enable_shared_from_this, public A void SetInputSize(uint32_t size) { input_size_ = size; } uint32_t GetInputSize() const { return input_size_; } + // false: known shape true: unknow shape + bool GetGraphUnknownFlag() const { return is_unknown_shape_graph_; } + void SetGraphUnknownFlag(bool flag) { is_unknown_shape_graph_ = flag; } + /// /// Set is need train iteration. /// If set true, it means this graph need to be run iteration some @@ -249,6 +263,8 @@ class ComputeGraph : public std::enable_shared_from_this, public A bool VectorInputNodePtrIsEqual(const std::vector &r_node_ptr_vector, const std::vector &l_node_ptr_vector) const; + void SetNodesOwner(); + friend class ModelSerializeImp; friend class GraphDebugImp; friend class OnnxUtils; @@ -282,7 +298,8 @@ class ComputeGraph : public std::enable_shared_from_this, public A std::map op_name_map_; uint64_t session_id_ = 0; ge::Format data_format_ = ge::FORMAT_ND; + // unknown graph indicator, default is false, mean known shape + bool is_unknown_shape_graph_ = false; }; } // namespace ge - #endif // INC_GRAPH_COMPUTE_GRAPH_H_ diff --git a/inc/graph/debug/ge_attr_define.h b/inc/graph/debug/ge_attr_define.h index 5db047c0..57e389e8 100644 --- a/inc/graph/debug/ge_attr_define.h +++ b/inc/graph/debug/ge_attr_define.h @@ -14,6 +14,7 @@ * limitations under the License. */ +/*lint -e618*/ #ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ @@ -185,6 +186,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_ORIGIN_SIZE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_INPUT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NODE_CONNECT_OUTPUT; + // to be deleted GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_TO_BE_DELETED; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string PERMUTE_RESHAPE_FUSION; @@ -778,6 +782,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_CORE_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ATC_VERSION; + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OPP_VERSION; + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_MODE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_VALUE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string QUANTIZE_SCALE_OFFSET; @@ -930,12 +938,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PRED_VALUE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_NUM; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_BATCH_LABEL; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_BATCH; // Control flow GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; @@ -979,6 +989,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NEE // For mutil-batch GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERT_BY_MBATCH; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_TYPE; // For inserted op GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INSERTED_BY_GE; @@ -996,7 +1007,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE; -// used for l1 fusion and other fusion in future +// used for lX fusion GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_ID; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_GROUP_KEY; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; @@ -1010,9 +1021,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DATA_DUMP_REF; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L2_FUSION_GROUP_ID; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; + +// for unregistered op +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; + +// op overflow dump +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_FLAG; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OP_DEBUG_MODE; // functional ops attr GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; @@ -1058,6 +1081,31 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOR // for gradient group GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; + +// dynamic shape attrs +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX; + +// atc user def dtype&format +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_DATATYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ATC_USER_DEFINE_FORMAT; + +// for fusion op plugin +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE; + +// graph partition for aicpu +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME; + +// input and output memory type +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VARIABLE_PLACEMENT; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; + +// input_output_offset +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; } // namespace ge #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ +/*lint +e618*/ diff --git a/inc/graph/detail/any_map.h b/inc/graph/detail/any_map.h index c417c6a9..70533ea1 100644 --- a/inc/graph/detail/any_map.h +++ b/inc/graph/detail/any_map.h @@ -38,7 +38,7 @@ class TypeID { bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } private: - explicit TypeID(string type) : type_(std::move(type)) {} + explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 string type_; }; @@ -53,6 +53,8 @@ class AnyMap { bool Has(const string &name) const { return anyValues_.find(name) != anyValues_.end(); } + void Swap(AnyMap &other) { anyValues_.swap(other.anyValues_); } + private: class Placeholder { public: diff --git a/inc/graph/detail/attributes_holder.h b/inc/graph/detail/attributes_holder.h index bb26dec5..49741143 100644 --- a/inc/graph/detail/attributes_holder.h +++ b/inc/graph/detail/attributes_holder.h @@ -50,7 +50,7 @@ class OpDef; class GraphDef; } // namespace proto -using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; +using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; template @@ -95,6 +95,14 @@ class GeIrProtoHelper { } } + void Swap(GeIrProtoHelper &other) { + protoOwner_.swap(other.protoOwner_); + + ProtoType *temp = protoMsg_; + protoMsg_ = other.protoMsg_; + other.protoMsg_ = temp; + } + // protoMsg_ is part of protoOwner_, they have the same runtime ProtoMsgOwner protoOwner_ = nullptr; ProtoType *protoMsg_ = nullptr; @@ -120,6 +128,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { void CopyAttrsFrom(const AttrHolder &holder); + void Swap(AttrHolder &holder) { + requiredAttrs_.swap(holder.requiredAttrs_); + extAttrs_.Swap(holder.extAttrs_); + } + template bool SetExtAttr(const string &name, const T &value) { return extAttrs_.Set(name, value); @@ -134,7 +147,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { protected: graphStatus AddRequiredAttr(const std::string &name); const std::unordered_set GetAllAttrNames() const; - const std::map GetAllAttrs() const; + const std::map GetAllAttrs() const; // lint !e1073 virtual ProtoAttrMapHelper MutableAttrMap() = 0; virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; @@ -149,5 +162,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { AnyMap extAttrs_; }; } // namespace ge - #endif // INC_GRAPH_DETAIL_ATTRIBUTES_HOLDER_H_ diff --git a/inc/graph/detail/model_serialize_imp.h b/inc/graph/detail/model_serialize_imp.h index b8b3916a..ff27335a 100644 --- a/inc/graph/detail/model_serialize_imp.h +++ b/inc/graph/detail/model_serialize_imp.h @@ -67,6 +67,9 @@ class ModelSerializeImp { bool HandleNodeNameRef(); bool UnserializeOpDesc(OpDescPtr &opDesc, proto::OpDef &opDefProto); + void AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, + std::vector &value_in, std::vector &value_out, std::vector &opt); + void OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto); bool UnserializeNode(ComputeGraphPtr &graph, proto::OpDef &opDefProto); diff --git a/inc/graph/ge_attr_value.h b/inc/graph/ge_attr_value.h index b665beba..0c265c20 100644 --- a/inc/graph/ge_attr_value.h +++ b/inc/graph/ge_attr_value.h @@ -310,7 +310,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { VALUE_SET_GET_DEC(GeAttrValue::GRAPH) VALUE_SET_GET_DEC(BYTES) VALUE_SET_GET_DEC(NamedAttrs) - VALUE_SET_GET_DEC(ge::DataType) + VALUE_SET_GET_DEC(ge::DataType) // lint !e665 VALUE_SET_GET_DEC(vector) VALUE_SET_GET_DEC(vector) VALUE_SET_GET_DEC(vector) @@ -320,8 +320,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { VALUE_SET_GET_DEC(vector) VALUE_SET_GET_DEC(vector) VALUE_SET_GET_DEC(vector) - VALUE_SET_GET_DEC(vector>) - VALUE_SET_GET_DEC(vector) + VALUE_SET_GET_DEC(vector>) // lint !e665 + VALUE_SET_GET_DEC(vector) // lint !e665 #undef VALUE_SET_GET_DEC GeIrProtoHelper value_; diff --git a/inc/graph/ge_context.h b/inc/graph/ge_context.h index b1ccd5b9..af6b35bc 100644 --- a/inc/graph/ge_context.h +++ b/inc/graph/ge_context.h @@ -28,6 +28,7 @@ class GEContext { uint32_t DeviceId(); uint64_t TraceId(); void Init(); + void SetSessionId(uint64_t session_id); void SetCtxDeviceId(uint32_t device_id); private: diff --git a/inc/graph/ge_tensor.h b/inc/graph/ge_tensor.h index 29a315d6..834dca0b 100644 --- a/inc/graph/ge_tensor.h +++ b/inc/graph/ge_tensor.h @@ -25,6 +25,7 @@ #include "graph/buffer.h" #include "graph/ge_error_codes.h" #include "graph/types.h" + namespace ge { class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeShape { public: @@ -108,8 +109,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc : public AttrH DataType GetDataType() const; void SetDataType(DataType dt); - void SetOriginDataType(DataType originDataType); DataType GetOriginDataType() const; + void SetOriginDataType(DataType originDataType); + + std::vector GetRefPortIndex() const; + void SetRefPortByIndex(const std::vector &index); GeTensorDesc Clone() const; GeTensorDesc &operator=(const GeTensorDesc &desc); @@ -186,5 +190,4 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensor { GeTensorDesc &DescReference() const; }; } // namespace ge - #endif // INC_GRAPH_GE_TENSOR_H_ diff --git a/inc/graph/model_serialize.h b/inc/graph/model_serialize.h index 3f7d65a9..16529512 100644 --- a/inc/graph/model_serialize.h +++ b/inc/graph/model_serialize.h @@ -49,5 +49,4 @@ class ModelSerialize { friend class GraphDebugImp; }; } // namespace ge - #endif // INC_GRAPH_MODEL_SERIALIZE_H_ diff --git a/inc/graph/node.h b/inc/graph/node.h index 74aaf72f..2629f525 100644 --- a/inc/graph/node.h +++ b/inc/graph/node.h @@ -190,7 +190,7 @@ class Node : public std::enable_shared_from_this { vector out_data_anchors_; InControlAnchorPtr in_control_anchor_; OutControlAnchorPtr out_control_anchor_; - map attrs_; + map attrs_; // lint !e1073 bool has_init_{false}; bool anchor_status_updated_{false}; std::vector send_event_id_list_; diff --git a/inc/graph/op_desc.h b/inc/graph/op_desc.h index faca2d99..1457aa15 100644 --- a/inc/graph/op_desc.h +++ b/inc/graph/op_desc.h @@ -105,6 +105,8 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { GeTensorDescPtr MutableInputDesc(uint32_t index) const; + GeTensorDescPtr MutableInputDesc(const string &name) const; + Vistor GetAllInputsDesc() const; Vistor GetAllInputsDescPtr() const; @@ -127,6 +129,8 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { GeTensorDescPtr MutableOutputDesc(uint32_t index) const; + GeTensorDescPtr MutableOutputDesc(const string &name) const; + uint32_t GetAllOutputsDescSize() const; Vistor GetAllOutputsDesc() const; @@ -149,16 +153,15 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { graphStatus AddDynamicOutputDesc(const string &name, const unsigned int num, bool isPushBack = true); + void RemoveInputDesc(uint32_t index); + void RemoveOutputDesc(uint32_t index); + bool IsOptionalInput(const string &name) const; bool IsOptionalInput(uint32_t index) const; std::map GetAllInputName() const; - void SetAllInputName(const std::map &input_name_idx); - - std::vector GetAllOptionalInputName() const; - std::map GetAllOutputName(); bool UpdateInputName(std::map inputNameIdx); @@ -296,6 +299,8 @@ class OpDesc : public std::enable_shared_from_this, public AttrHolder { std::map subgraph_ir_names_to_type_; vector inputs_desc_{}; + map input_name_idx_{}; + std::unordered_set optional_input_names_{}; vector outputs_desc_{}; map output_name_idx_{}; std::function infer_func_ = nullptr; diff --git a/inc/graph/shape_refiner.h b/inc/graph/shape_refiner.h index 65664615..4f8783a3 100644 --- a/inc/graph/shape_refiner.h +++ b/inc/graph/shape_refiner.h @@ -31,6 +31,7 @@ class ShapeRefiner { static graphStatus InferShapeAndType(const NodePtr &node, bool before_subgraph); static graphStatus InferShapeAndType(const NodePtr &node); static graphStatus InferShapeAndType(const ConstNodePtr &node, Operator &op); + static void ClearContextMap(); private: static void PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase); diff --git a/inc/graph/utils/graph_utils.h b/inc/graph/utils/graph_utils.h index 6c344435..5f627ea4 100644 --- a/inc/graph/utils/graph_utils.h +++ b/inc/graph/utils/graph_utils.h @@ -23,6 +23,8 @@ #include #include #include +#include + #include "graph/anchor.h" #include "graph/node.h" #include "graph/compute_graph.h" @@ -130,7 +132,7 @@ struct NodeIndexIO { IOType io_type_ = kOut; std::string value_; - std::string ToString() const { return value_; } + const std::string &ToString() const { return value_; } }; class GraphUtils { @@ -188,8 +190,8 @@ class GraphUtils { /// @param [in] output_index /// @return graphStatus /// - static graphStatus InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); + static graphStatus InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index = 0, uint32_t output_index = 0); static graphStatus RemoveJustNode(ComputeGraphPtr compute_graph, const NodePtr &node); @@ -303,8 +305,33 @@ class GraphUtils { /// static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); + /// + /// Copy all in-data edges from `src_node` to `dst_node` + /// @param src_node + /// @param dst_node + /// @return + /// + static graphStatus CopyInDataEdges(const NodePtr &src_node, NodePtr &dst_node); + static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); + /// + /// Make a copy of ComputeGraph. + /// @param graph: original graph. + /// @param prefix: node name prefix of new graph. + /// @return ComputeGraphPtr + /// + static ComputeGraphPtr CloneGraph(const ComputeGraphPtr &graph, const string &prefix, + std::vector &input_nodes, std::vector &output_nodes); + + /// + /// Copy tensor attribute to new node. + /// @param [in] dst_desc: cloned node. + /// @param [in] src_node: original node. + /// @return success: GRAPH_SUCESS + /// + static graphStatus CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node); + static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector &node_vec); /// @@ -392,6 +419,16 @@ class GraphUtils { std::map> &symbol_to_anchors, std::map &anchor_to_symbol); + /// + /// Relink all edges for cloned ComputeGraph. + /// @param [in] node: original node. + /// @param [in] prefix: node name prefix of new node. + /// @param [in] all_nodes: all nodes in new graph. + /// @return success: GRAPH_SUCESS + /// + static graphStatus RelinkGraphEdges(const NodePtr &node, const string &prefix, + const std::unordered_map &all_nodes); + /// /// Union ref-mapping /// @param [in] exist_node_info1 @@ -728,5 +765,4 @@ class PartialGraphBuilder : public ComputeGraphBuilder { std::vector exist_nodes_; }; } // namespace ge - #endif // INC_GRAPH_UTILS_GRAPH_UTILS_H_ diff --git a/inc/graph/utils/node_utils.h b/inc/graph/utils/node_utils.h index 6e0e655d..019bb3a7 100644 --- a/inc/graph/utils/node_utils.h +++ b/inc/graph/utils/node_utils.h @@ -63,6 +63,9 @@ class NodeUtils { static void UnlinkAll(const Node &node); static graphStatus UpdatePeerNodeInputDesc(const NodePtr &node_ptr); + static graphStatus AppendInputAnchor(const NodePtr &node, uint32_t index); + static graphStatus RemoveInputAnchor(const NodePtr &node, uint32_t index); + static bool IsInNodesEmpty(const Node &node); static GeTensorDesc GetOutputDesc(const Node &node, uint32_t index); static GeTensorDesc GetInputDesc(const Node &node, uint32_t index); @@ -99,6 +102,13 @@ class NodeUtils { /// static NodePtr GetParentInput(const NodePtr &node); + /// + /// @brief Check is varying_input for while node + /// @param [in] node: Data node for subgraph + /// @return bool + /// + static bool IsWhileVaryingInput(const ge::NodePtr &node); + /// /// @brief Get subgraph input is constant. /// @param [in] node @@ -114,6 +124,24 @@ class NodeUtils { /// static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); + /// + /// @brief Get subgraph input data node by index. + /// @param [in] node + /// @return Node + /// + static vector GetSubgraphDataNodesByIndex(const Node &node, int index); + + /// + /// @brief Get subgraph input data node by index. + /// @param [in] node + /// @return Node + /// + static vector GetSubgraphOutputNodes(const Node &node); + + static NodePtr GetInDataNodeByIndex(const Node &node, int index); + + static vector GetOutDataNodesByIndex(const Node &node, int index); + private: static std::map> map_send_info_; static std::map> map_recv_info_; diff --git a/inc/graph/utils/tensor_adapter.h b/inc/graph/utils/tensor_adapter.h index f9993606..a7355553 100644 --- a/inc/graph/utils/tensor_adapter.h +++ b/inc/graph/utils/tensor_adapter.h @@ -20,6 +20,7 @@ #include #include "graph/ge_tensor.h" #include "graph/tensor.h" + namespace ge { using GeTensorPtr = std::shared_ptr; using ConstGeTensorPtr = std::shared_ptr; diff --git a/inc/graph/utils/tensor_utils.h b/inc/graph/utils/tensor_utils.h index 2fa398db..caa80dcf 100644 --- a/inc/graph/utils/tensor_utils.h +++ b/inc/graph/utils/tensor_utils.h @@ -21,6 +21,7 @@ #include "graph/def_types.h" #include "graph/ge_error_codes.h" #include "graph/ge_tensor.h" + namespace ge { class TensorUtils { public: diff --git a/src/common/graph/CMakeLists.txt b/src/common/graph/CMakeLists.txt index 43f5b597..f041e4b6 100755 --- a/src/common/graph/CMakeLists.txt +++ b/src/common/graph/CMakeLists.txt @@ -71,5 +71,6 @@ target_link_libraries(graph PRIVATE ${PROTOBUF_LIBRARY} ${c_sec} ${slog} + ${error_manager} rt dl) diff --git a/src/common/graph/compute_graph.cc b/src/common/graph/compute_graph.cc index b73cf939..52953fb2 100644 --- a/src/common/graph/compute_graph.cc +++ b/src/common/graph/compute_graph.cc @@ -62,18 +62,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string ComputeGraph::GetName() co GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetName(const string &name) { name_ = name; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesSize() const { - size_t s = nodes_.size(); - for (const auto &sub_graph : sub_graph_) { - s += sub_graph->GetAllNodesSize(); - } - return s; + return GetAllNodes().size(); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetAllNodes() const { - if (sub_graph_.empty()) { - return Vistor(shared_from_this(), nodes_); - } - std::vector> subgraphs; return AllGraphNodes(subgraphs); } @@ -106,6 +98,15 @@ ComputeGraph::Vistor ComputeGraph::AllGraphNodes(std::vector(shared_from_this(), all_nodes); } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetNodes( + bool is_unknown_shape) const { + if (is_unknown_shape) { + return GetDirectNode(); + } else { + return GetAllNodes(); + } +} + size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor ComputeGraph::GetDirectNode() const { @@ -268,7 +269,7 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); return nullptr; } op->SetId(nodes_.size()); @@ -278,9 +279,38 @@ NodePtr ComputeGraph::AddNodeFront(const OpDescPtr &op) { return AddNodeFront(node_ptr); } +NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) { + if (node == nullptr || node->GetOpDesc() == nullptr || pre_node == nullptr) { + GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); + return nullptr; + } + node->GetOpDesc()->SetId(nodes_.size()); + auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); + if (node_iter != nodes_.end()) { + nodes_.insert(node_iter + 1, node); + } else { + GELOGE(GRAPH_FAILED, "Cannot find pre_node in nodes_."); + return nullptr; + } + + return node; +} + +NodePtr ComputeGraph::AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node) { + if (op == nullptr) { + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); + return nullptr; + } + op->SetId(nodes_.size()); + NodePtr node_ptr = shared_ptr(new (std::nothrow) Node(op, shared_from_this())); + GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); + GE_IF_BOOL_EXEC(node_ptr->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init failed."); return nullptr); + return AddNodeAfter(node_ptr, pre_node); +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(NodePtr node) { if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); return nullptr; } node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); @@ -290,7 +320,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(Nod GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpDescPtr op) { if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); return nullptr; } op->SetId(GetDirectNodesSize()); @@ -302,7 +332,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(OpD NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. if (op == nullptr) { - GELOGE(GRAPH_FAILED, "The OpDesc ptr should be not null."); + GELOGE(GRAPH_FAILED, "The OpDesc ptr should not be null."); return nullptr; } op->SetId(id); @@ -315,7 +345,7 @@ NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. NodePtr ComputeGraph::AddInputNode(NodePtr node) { if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); return nullptr; } input_nodes_.push_back(node); @@ -327,7 +357,7 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { NodePtr ComputeGraph::AddOutputNode(NodePtr node) { if (node == nullptr || node->GetOpDesc() == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr or opdesc should be not null."); + GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); return nullptr; } @@ -363,7 +393,7 @@ graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { if (out_anchor->GetOwnerNode()->GetType() == CONSTANT || out_anchor->GetOwnerNode()->GetType() == CONSTANTOP) { GE_CHK_BOOL_RET_STATUS(GraphUtils::RemoveEdge(out_anchor, in_anchor) == GRAPH_SUCCESS, GRAPH_FAILED, "Remove edge from const op failed."); - if (out_anchor->GetOwnerNode()->GetOutDataNodes().size() == 0) { + if (out_anchor->GetOwnerNode()->GetOutNodes().size() == 0) { GELOGI("Remove const op %s.", out_anchor->GetOwnerNode()->GetName().c_str()); auto iter = find(nodes_.begin(), nodes_.end(), out_anchor->GetOwnerNode()); if (iter != nodes_.end()) { @@ -377,7 +407,7 @@ graphStatus ComputeGraph::RemoveConstInput(const NodePtr &node) { GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveNode(const NodePtr &node) { if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); return GRAPH_FAILED; } @@ -406,7 +436,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::RemoveN // Used in sub_graph scenes graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); return GRAPH_FAILED; } @@ -421,7 +451,7 @@ graphStatus ComputeGraph::RemoveInputNode(const NodePtr &node) { // Used in sub_graph scenes graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { if (node == nullptr) { - GELOGE(GRAPH_FAILED, "The node ptr should be not null."); + GELOGE(GRAPH_FAILED, "The node ptr should not be null."); return GRAPH_FAILED; } @@ -442,7 +472,7 @@ graphStatus ComputeGraph::RemoveOutputNode(const NodePtr &node) { std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr sub_graph) { if (sub_graph == nullptr) { - GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); + GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); return nullptr; } sub_graph_.push_back(sub_graph); @@ -452,7 +482,7 @@ std::shared_ptr ComputeGraph::AddSubGraph(std::shared_ptr &sub_graph) { if (sub_graph == nullptr) { - GELOGE(GRAPH_FAILED, "The graph ptr should be not null."); + GELOGE(GRAPH_FAILED, "The graph ptr should not be null."); return GRAPH_FAILED; } @@ -491,12 +521,15 @@ ComputeGraph::AddSubgraph(const std::string &name, const std::shared_ptrparent_graph_.expired()) { - GE_LOGE("The subgraphs can only be added to the root graph"); - return GRAPH_PARAM_INVALID; + GELOGW("The subgraphs should only be added to the root graph"); } if (name != subgraph->GetName()) { GELOGW("The subgraph name %s is different with input %s", subgraph->GetName().c_str(), name.c_str()); } + if (names_to_subgraph_.find(name) != names_to_subgraph_.end()) { + GE_LOGE("The subgraph %s existed", name.c_str()); + return GRAPH_PARAM_INVALID; + } sub_graph_.push_back(subgraph); names_to_subgraph_[name] = subgraph; return GRAPH_SUCCESS; @@ -640,7 +673,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE GELOGW("node or OpDescPtr is nullptr."); continue; } - GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should be not null."); return GRAPH_FAILED); + GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "The node should not be null."); return GRAPH_FAILED); if (node->GetOpDesc()->GetType() == RECV) { auto iter = find(node_vec.begin(), node_vec.end(), node); if (iter == node_vec.end()) { @@ -786,7 +819,8 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::mapAttrHolder::Swap(graph); + + origGraph_.swap(graph.origGraph_); + + name_.swap(graph.name_); + std::swap(graph_id_, graph.graph_id_); + attrs_.Swap(graph.attrs_); + nodes_.swap(graph.nodes_); + all_nodes_infos_.swap(graph.all_nodes_infos_); + target_nodes_info_.swap(graph.target_nodes_info_); + + input_nodes_.swap(graph.input_nodes_); + inputs_order_.swap(graph.inputs_order_); + std::swap(input_size_, graph.input_size_); + out_nodes_map_.swap(graph.out_nodes_map_); + std::swap(output_size_, graph.output_size_); + output_nodes_info_.swap(graph.output_nodes_info_); + + sub_graph_.swap(graph.sub_graph_); + names_to_subgraph_.swap(graph.names_to_subgraph_); + parent_graph_.swap(graph.parent_graph_); + parent_node_.swap(graph.parent_node_); + + // the members followed should not in the ComputeGraph class + std::swap(is_valid_flag_, graph.is_valid_flag_); + std::swap(is_summary_graph_, graph.is_summary_graph_); + std::swap(need_iteration_, graph.need_iteration_); + params_share_map_.swap(graph.params_share_map_); + op_name_map_.swap(graph.op_name_map_); + std::swap(session_id_, graph.session_id_); + std::swap(data_format_, graph.data_format_); + std::swap(is_unknown_shape_graph_, graph.is_unknown_shape_graph_); + + // Update Node owner. + SetNodesOwner(); + graph.SetNodesOwner(); +} + +void ComputeGraph::SetNodesOwner() { + for (const auto &node : nodes_) { + if (node == nullptr) { + continue; + } + node->SetOwnerComputeGraph(shared_from_this()); + } +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::IsolateNode(const NodePtr &node) { GE_CHECK_NOTNULL(node); auto next_nodes = node->GetOutAllNodes(); @@ -1104,9 +1186,11 @@ graphStatus ComputeGraph::RemoveExtraOutEdge(const NodePtr &node) { } graphStatus ComputeGraph::Verify() { + bool is_unknown_graph = GetGraphUnknownFlag(); for (const auto &node_ptr : GetAllNodes()) { GE_CHECK_NOTNULL(node_ptr); GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); + GE_IF_BOOL_EXEC(is_unknown_graph, continue); GE_CHK_BOOL_EXEC(node_ptr->GetOpDesc()->CommonVerify() == GRAPH_SUCCESS, return GRAPH_FAILED, "Verifying %s failed.", node_ptr->GetName().c_str()); } diff --git a/src/common/graph/debug/ge_op_types.h b/src/common/graph/debug/ge_op_types.h index da36f72c..f11ef31e 100644 --- a/src/common/graph/debug/ge_op_types.h +++ b/src/common/graph/debug/ge_op_types.h @@ -34,12 +34,16 @@ GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); GE_REGISTER_OPTYPE(SWITCH, "Switch"); GE_REGISTER_OPTYPE(MERGE, "Merge"); GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); +GE_REGISTER_OPTYPE(ENTER, "Enter"); +GE_REGISTER_OPTYPE(REFENTER, "RefEnter"); GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); GE_REGISTER_OPTYPE(CONSTANT, "Const"); +GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); GE_REGISTER_OPTYPE(INITDATA, "InitData"); +GE_REGISTER_OPTYPE(REFIDENTITY, "RefIdentity"); GE_REGISTER_OPTYPE(ANN_DATA, "AnnData"); GE_REGISTER_OPTYPE(CONSTANTOP, "Constant"); diff --git a/src/common/graph/format_refiner.cc b/src/common/graph/format_refiner.cc index 11a610ce..9cb76539 100644 --- a/src/common/graph/format_refiner.cc +++ b/src/common/graph/format_refiner.cc @@ -41,11 +41,9 @@ using namespace ge; using namespace std; namespace ge { namespace { -static const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; -static bool net_format_is_nd = true; -static Format g_user_set_format = FORMAT_ND; -static bool is_first_infer = true; -static RefRelations reflection_builder; +const std::unordered_set kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; +const string kIsGraphInferred = "_is_graph_inferred"; +RefRelations reflection_builder; } // namespace graphStatus ReflectionProcess(const std::unordered_set &reflection, @@ -72,9 +70,49 @@ graphStatus ReflectionProcess(const std::unordered_set &re return GRAPH_SUCCESS; } -graphStatus FormatRefiner::RefreshConstantOutProcess(const OpDescPtr &op_desc) { +graphStatus BiasAddFormatFixProcess(ge::NodePtr &node_ptr) { + // 5 meas dim num + if (node_ptr->GetType() != "BiasAdd") { + return GRAPH_SUCCESS; + } + std::unordered_map kTfFormatFix = {{"NHWC", FORMAT_NDHWC}, {"NCHW", FORMAT_NCDHW}}; + for (size_t i = 0; i < node_ptr->GetOpDesc()->GetInputsSize(); i++) { + auto in_desc = node_ptr->GetOpDesc()->MutableInputDesc(i); + GE_CHECK_NOTNULL(in_desc); + if (in_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num + continue; + } + auto format = in_desc->GetOriginFormat(); + auto key = TypeUtils::FormatToSerialString(format); + auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; + in_desc->SetOriginFormat(fixed_format); + in_desc->SetFormat(fixed_format); + GELOGD("fix the %zu'th input of node[%s]. Origin format is %s , after fixed it is %s", i, + node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(fixed_format).c_str()); + } + for (size_t i = 0; i < node_ptr->GetOpDesc()->GetOutputsSize(); i++) { + auto out_desc = node_ptr->GetOpDesc()->MutableOutputDesc(i); + GE_CHECK_NOTNULL(out_desc); + if (out_desc->MutableShape().GetDimNum() != 5) { // 5 means dim num + continue; + } + auto format = out_desc->GetOriginFormat(); + auto key = TypeUtils::FormatToSerialString(format); + auto fixed_format = (kTfFormatFix.count(key) == 0) ? format : kTfFormatFix[key]; + out_desc->SetOriginFormat(fixed_format); + out_desc->SetFormat(fixed_format); + GELOGD("fix the %zu'th output of node[%s]. Origin format is %s , after fixed it is %s", i, + node_ptr->GetName().c_str(), TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(fixed_format).c_str()); + } + return GRAPH_SUCCESS; +} + +graphStatus FormatRefiner::RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc) { + GE_CHECK_NOTNULL(graph); GE_CHECK_NOTNULL(op_desc); - if (op_desc->GetType() == CONSTANTOP && is_first_infer == true) { + if (op_desc->GetType() == CONSTANTOP && !IsGraphInferred(graph)) { ConstGeTensorPtr tensor_value; if (!AttrUtils::GetTensor(op_desc, "value", tensor_value)) { GELOGE(GRAPH_FAILED, "Get value failed, node name:%s.", op_desc->GetName().c_str()); @@ -95,7 +133,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std } anchor_points.clear(); // Get all anchor point nodes and switch nodes - for (const auto &node_ptr : graph->GetAllNodes()) { + for (auto &node_ptr : graph->GetAllNodes()) { if (node_ptr == nullptr) { return GRAPH_FAILED; } @@ -103,7 +141,7 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std if (op_desc == nullptr) { return GRAPH_FAILED; } - graphStatus status = RefreshConstantOutProcess(op_desc); + graphStatus status = RefreshConstantOutProcess(graph, op_desc); if (status != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "refresh constant out process failed!"); return GRAPH_FAILED; @@ -135,6 +173,16 @@ graphStatus FormatRefiner::GetAnchorPoints(const ge::ComputeGraphPtr &graph, std if (!node_is_all_nd) { continue; } + // special process for biasAdd op + // In tensorflow, biasAdd's format is alwayse NHWC even though set the arg + // "data_format" to NDHWC or NCDHW.It will destroy our format-infer mechanism + // so here do special process + status = BiasAddFormatFixProcess(node_ptr); + if (status != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "fix biasAdd process failed!"); + return GRAPH_FAILED; + } + GELOGD("Node[%s] is anchor point!", node_ptr->GetName().c_str()); anchor_points.push_back(node_ptr); } @@ -344,14 +392,11 @@ void FormatRefiner::RefreshOriginFormatOfAnchor(std::vector &anchor } } -void FormatRefiner::SetInferOrigineFormatFlag(bool is_first) { is_first_infer = is_first; } - -graphStatus FormatRefiner::DataNodeFormatProcess(std::vector &data_nodes, ge::Format data_format, +graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, + ge::Format data_format, std::unordered_map &node_status) { - bool is_internal_format = TypeUtils::IsInternalFormat(data_format); - bool need_process = (!is_first_infer) && (!is_internal_format) && (data_format != FORMAT_ND); - if (!need_process) { - GELOGI("no necessary to do DataNodeFormatProcess.is_first_infer:%d, data_format:%s", is_first_infer, + if (!(IsGraphInferred(graph) && (!TypeUtils::IsInternalFormat(data_format)) && (data_format != FORMAT_ND))) { + GELOGI("no necessary to do DataNodeFormatProcess. is_graph_inferred:%d, data_format:%s", IsGraphInferred(graph), TypeUtils::FormatToSerialString(data_format).c_str()); return GRAPH_SUCCESS; } @@ -410,8 +455,6 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) std::vector anchor_points; std::vector data_nodes; // global net format - net_format_is_nd = true; - g_user_set_format = FORMAT_ND; if (graph == nullptr) { GELOGE(GRAPH_FAILED, "input graph is null"); @@ -448,10 +491,15 @@ graphStatus FormatRefiner::InferOrigineFormat(const ge::ComputeGraphPtr &graph) /// format for these data nodes. /// Notice: ignore 5D formats auto data_format = graph->GetDataFormat(); - status = DataNodeFormatProcess(data_nodes, data_format, node_status); - // Set infer flag to false - SetInferOrigineFormatFlag(false); + status = DataNodeFormatProcess(graph, data_nodes, data_format, node_status); + + (void)AttrUtils::SetBool(graph, kIsGraphInferred, true); return status; } + +bool FormatRefiner::IsGraphInferred(const ComputeGraphPtr &graph) { + bool is_graph_inferred = false; + return (AttrUtils::GetBool(graph, kIsGraphInferred, is_graph_inferred) && is_graph_inferred); +} } // namespace ge diff --git a/src/common/graph/format_refiner.h b/src/common/graph/format_refiner.h index fa40a034..eca93bae 100644 --- a/src/common/graph/format_refiner.h +++ b/src/common/graph/format_refiner.h @@ -30,10 +30,9 @@ namespace ge { class FormatRefiner { public: static graphStatus InferOrigineFormat(const ge::ComputeGraphPtr &graph); - static void SetInferOrigineFormatFlag(bool is_first = true); private: - static graphStatus RefreshConstantOutProcess(const OpDescPtr &op_desc); + static graphStatus RefreshConstantOutProcess(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); static graphStatus GetAnchorPoints(const ge::ComputeGraphPtr &graph, std::vector &anchor_points, std::vector &data_nodes, std::unordered_map &node_status); @@ -43,8 +42,9 @@ class FormatRefiner { std::unordered_map &node_status); static graphStatus ForwardInferProcess(std::deque &nodes, ge::NodePtr &node, std::unordered_map &node_status); - static graphStatus DataNodeFormatProcess(std::vector &data_nodes, ge::Format data_format, - std::unordered_map &node_status); + static graphStatus DataNodeFormatProcess(const ComputeGraphPtr &graph, std::vector &data_nodes, + ge::Format data_format, std::unordered_map &node_status); + static bool IsGraphInferred(const ComputeGraphPtr &graph); }; } // namespace ge #endif // COMMON_GRAPH_FORMAT_REFINER_H_ diff --git a/src/common/graph/ge_attr_define.cc b/src/common/graph/ge_attr_define.cc index 96638249..f78ca7aa 100644 --- a/src/common/graph/ge_attr_define.cc +++ b/src/common/graph/ge_attr_define.cc @@ -158,6 +158,10 @@ const std::string ATTR_NAME_AUTOMIC_ADD_MEM_SIZE = "automic_add_mem_size"; const std::string ATTR_NAME_DYNAMIC_OUTPUT_DIMS = "_dynamic_output_dims"; const std::string ATTR_NAME_INPUT_ORIGIN_SIZE = "input_origin_size"; +// Identify node connecting to input and output +const std::string ATTR_NAME_NODE_CONNECT_INPUT = "_is_connected_to_data"; +const std::string ATTR_NAME_NODE_CONNECT_OUTPUT = "_is_connected_to_netoutput"; + // To be deleted const std::string ATTR_TO_BE_DELETED = "to_be_deleted"; const std::string PERMUTE_RESHAPE_FUSION = "permute_reshape_fusion"; @@ -725,6 +729,10 @@ const std::string ATTR_MODEL_TASK_INDEX_OP_NAME = "task_index_op_name"; const std::string ATTR_MODEL_CORE_TYPE = "core_type"; +const std::string ATTR_MODEL_ATC_VERSION = "atc_version"; + +const std::string ATTR_MODEL_OPP_VERSION = "opp_version"; + // Public attribute const std::string ATTR_NAME_IMPLY_TYPE = "imply_type"; @@ -901,6 +909,7 @@ const std::string ATTR_NAME_IS_END_OF_INPUTMEM_LIFECYCLE = "is_end_of_inputmem_l const std::string ATTR_NAME_PRED_VALUE = "_pred_value"; const std::string ATTR_NAME_BATCH_NUM = "_batch_num"; const std::string ATTR_NAME_BATCH_LABEL = "_batch_label"; +const std::string ATTR_NAME_COMBINED_BATCH = "_combined_batch"; // Control flow const std::string ATTR_NAME_STREAM_SWITCH_COND = "switch_condition"; @@ -910,6 +919,7 @@ const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; +const std::string ATTR_NAME_COMBINED_DYNAMIC_DIMS = "combined_dynamic_dims"; const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; @@ -934,7 +944,7 @@ const std::string ATTR_NAME_MEMORY_TYPE_WORKSPACE = "memory_type_workspace"; const std::string MODEL_ATTR_SESSION_ID = "session_id"; -// l1 fusion and other fusion in future +// lx fusion const std::string ATTR_NAME_L1_FUSION_GROUP_ID = "_l1_fusion_group_id"; const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; @@ -948,9 +958,17 @@ const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1 const std::string ATTR_NAME_SWITCH_FOR_L1_FUSION = "_enable_l1_fusion"; const std::string ATTR_N_BATCH_SPILT = "_is_n_batch_split"; const std::string ATTR_NO_TASK_AND_DUMP_NEEDED = "_no_task_and_dump_needed"; +const std::string ATTR_DATA_DUMP_REF = "_datadump_ref"; const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION = "_output_offset_for_buffer_fusion"; const std::string ATTR_NAME_L2_FUSION_GROUP_ID = "_l2_fusion_group_id"; const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; +const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; +const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; +const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; + +// Op debug attrs +const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; +const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; // Atomic addr clean attrs const std::string ATOMIC_ATTR_INPUT_INDEX = "atomic_input_index"; @@ -971,6 +989,8 @@ const std::string ATTR_INSERT_BY_MBATCH = "mbatch-inserted-node"; const std::string ATTR_MBATCH_ORIGIN_INPUT_DIMS = "_mbatch_origin_input_dims"; +const std::string ATTR_DYNAMIC_TYPE = "mbatch_dynamic_type"; + // For inserted op const std::string ATTR_INSERTED_BY_GE = "_inserted_by_ge"; @@ -1009,10 +1029,38 @@ const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_ const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_output_offset_list_list"; +// for unregistered op +const std::string ATTR_NAME_UNREGST_OPPATH = "_unregst_oppath"; +const std::string ATTR_NAME_UNREGST_ATTRLIST = "_unregst_attrlist"; + // used for Horovod const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; // used for allreduce tailing optimization const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; + +// dynamic shape attr +const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR = "_alloc_fixed_addr"; +const std::string ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX = "_alloc_fixed_addr_index"; + +// atc user def dtype&format +const std::string ATTR_ATC_USER_DEFINE_DATATYPE = "_user_defined_data_type"; +const std::string ATTR_ATC_USER_DEFINE_FORMAT = "_user_defined_format"; + +// for fusion op plugin +const std::string ATTR_NAME_FUSIONOP_ORIGINAL_TYPE = "_fusionop_original_type"; + +// graph partition for aicpu +const std::string ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME = "pld_front_node_engine_name"; +const std::string ATTR_NAME_END_REAR_NODE_ENGINE_NAME = "end_rear_node_engine_name"; + +// input and output memory type +const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; +const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; +const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; + +// input_output_offset +const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; +const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; } // namespace ge diff --git a/src/common/graph/ge_attr_value.cc b/src/common/graph/ge_attr_value.cc index 3a1dec6d..a8490470 100644 --- a/src/common/graph/ge_attr_value.cc +++ b/src/common/graph/ge_attr_value.cc @@ -33,7 +33,8 @@ using std::vector; namespace ge { NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } -NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} +NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) + : named_attrs_(owner, proto_msg) {} // lint !e1744 void NamedAttrs::SetName(const std::string &name) { auto proto_msg = named_attrs_.GetProtoMsg(); @@ -238,7 +239,7 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) ATTR_VALUE_SET_GET_IMP(vector) ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) +ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 ATTR_VALUE_SET_GET_IMP(vector) ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) ATTR_VALUE_SET_GET_IMP(vector) @@ -252,9 +253,11 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) ATTR_VALUE_SET_GET_IMP(vector) ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) ATTR_VALUE_SET_GET_IMP(vector) +/*lint -e665*/ ATTR_VALUE_SET_GET_IMP(vector>) -ATTR_VALUE_SET_GET_IMP(vector) -ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) +/*lint +e665*/ +ATTR_VALUE_SET_GET_IMP(vector) // lint !e665 +ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 #undef ATTR_VALUE_SET_GET_IMP @@ -782,14 +785,14 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM if (graph_def == nullptr) { GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); graph_def = nullptr; - return false; + return false; // lint !e665 } else { ModelSerializeImp imp; imp.SetProtobufOwner(graph_def); if (!imp.UnserializeGraph(graph, *graph_def)) { GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); return false; - } + } // lint !e514 value = graph; } return true; @@ -809,7 +812,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM if (graph_def == nullptr) { GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); graph_def = nullptr; - return false; + return false; // lint !e665 } else { ComputeGraphPtr graph = nullptr; ModelSerializeImp imp; @@ -817,7 +820,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM if (!imp.UnserializeGraph(graph, *graph_def)) { GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); return false; - } + } // lint !e514 value.push_back(graph); } } @@ -969,7 +972,9 @@ ATTR_UTILS_SET_IMP(Tensor, GeTensor) ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) +/*lint -e665*/ ATTR_UTILS_SET_GET_IMP(ListListInt, vector>) +/*lint +e665*/ ATTR_UTILS_SET_GET_IMP(ListInt, vector) ATTR_UTILS_SET_IMP(ListInt, vector) @@ -984,8 +989,8 @@ ATTR_UTILS_SET_IMP(ListTensor, vector) ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector) ATTR_UTILS_SET_GET_IMP(ListBytes, vector) ATTR_UTILS_SET_GET_IMP(ListGraph, vector) -ATTR_UTILS_SET_GET_IMP(ListDataType, vector) -ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) +ATTR_UTILS_SET_GET_IMP(ListDataType, vector) // lint !e665 +ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, std::initializer_list &&value) { @@ -1154,7 +1159,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(Con } for (const auto &item : bytes_vals) { ModelSerialize serialize; - auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); + auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 value.push_back(op_desc); } return true; @@ -1206,7 +1211,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( op_def = ComGraphMakeShared(); if (op_def == nullptr) { GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); - return nullptr; + return nullptr; // lint !e665 } ModelSerializeImp imp; (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); @@ -1216,27 +1221,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( GE_CHK_BOOL_EXEC(imp.UnserializeOpDesc(op_desc, *op_def), return op_desc, "op_desc unserialize failed"); op_desc->extAttrs_ = org_op_desc->extAttrs_; - if (op_desc->HasAttr("_input_name_idx_key")) { - if (op_desc->DelAttr("_input_name_idx_key") != SUCCESS) { - GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_key failed."); - } - } - - if (op_desc->HasAttr("_input_name_idx_value")) { - if (op_desc->DelAttr("_input_name_idx_value") != SUCCESS) { - GELOGE(GRAPH_FAILED, "DelAttr _input_name_idx_value failed."); - } + // This function may be called by some passes of fusion engine, in this condition, do not need these attribute + if (!op_desc->input_name_idx_.empty()) { + op_desc->input_name_idx_.clear(); } - - if (op_desc->HasAttr("_opt_input")) { - if (op_desc->DelAttr("_opt_input") != SUCCESS) { - GELOGE(GRAPH_FAILED, "DelAttr _opt_input failed."); - } - } - if (!op_desc->output_name_idx_.empty()) { op_desc->output_name_idx_.clear(); } + if (!op_desc->optional_input_names_.empty()) { + op_desc->optional_input_names_.clear(); + } return op_desc; } @@ -1260,6 +1254,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CopyOpDesc(c op_desc->extAttrs_ = org_op_desc->extAttrs_; + op_desc->input_name_idx_.insert(org_op_desc->input_name_idx_.begin(), org_op_desc->input_name_idx_.end()); + op_desc->optional_input_names_.insert(org_op_desc->optional_input_names_.begin(), + org_op_desc->optional_input_names_.end()); op_desc->output_name_idx_.insert(org_op_desc->output_name_idx_.begin(), org_op_desc->output_name_idx_.end()); op_desc->infer_func_ = org_op_desc->infer_func_; diff --git a/src/common/graph/ge_tensor.cc b/src/common/graph/ge_tensor.cc index 8ffbba91..196b8569 100644 --- a/src/common/graph/ge_tensor.cc +++ b/src/common/graph/ge_tensor.cc @@ -220,6 +220,7 @@ const string TENSOR_UTILS_ORIGIN_SHAPE = "origin_shape"; const string TENSOR_UTILS_ORIGIN_FORMAT = "origin_format"; const string TENSOR_UTILS_ORIGIN_DATA_TYPE = "origin_data_type"; const string TENSOR_UTILS_SHAPE_RANGE = "shape_range"; +const string TENSOR_UTILS_REF_PORT_INDEX = "ref_port_index"; GeShape::GeShape(const ProtoMsgOwner &proto_owner, proto::ShapeDef *proto_msg) : shape_def_(proto_owner, proto_msg) {} @@ -567,6 +568,16 @@ DataType GeTensorDesc::GetOriginDataType() const { return TypeUtils::SerialStringToDataType(origin_data_type_str); } +std::vector GeTensorDesc::GetRefPortIndex() const { + vector ref_port_index; + (void)AttrUtils::GetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, ref_port_index); + return ref_port_index; +} + +void GeTensorDesc::SetRefPortByIndex(const std::vector &index) { + (void)AttrUtils::SetListInt(this, TENSOR_UTILS_REF_PORT_INDEX, index); +} + graphStatus GeTensorDesc::IsValid() const { auto dtype = this->GetDataType(); auto format = this->GetFormat(); diff --git a/src/common/graph/graph.cc b/src/common/graph/graph.cc index 09d4fd56..fc30e9d6 100644 --- a/src/common/graph/graph.cc +++ b/src/common/graph/graph.cc @@ -210,7 +210,7 @@ class GraphImpl { graphStatus FindOpByName(const string &name, ge::Operator &op) const { auto it = op_list_.find(name); - GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "Error: there is no op: %s.", name.c_str()); + GE_CHK_BOOL_EXEC(it != op_list_.end(), return GRAPH_FAILED, "there is no op: %s.", name.c_str()); op = it->second; return GRAPH_SUCCESS; } diff --git a/src/common/graph/graph.mk b/src/common/graph/graph.mk index 5eaf7d86..b007dac8 100644 --- a/src/common/graph/graph.mk +++ b/src/common/graph/graph.mk @@ -77,6 +77,7 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -94,10 +95,36 @@ LOCAL_CPPFLAGS += -fexceptions LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) LOCAL_SRC_FILES := \ - ../../out/atc/lib64/stub/graph.cc \ - ../../out/atc/lib64/stub/operator.cc \ - ../../out/atc/lib64/stub/tensor.cc \ - ../../out/atc/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_HOST_SHARED_LIBRARY) + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -O2 +LOCAL_CPPFLAGS += -fexceptions + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/inference_context.cc \ LOCAL_SHARED_LIBRARIES := @@ -122,6 +149,7 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -142,10 +170,39 @@ LOCAL_CFLAGS += -O2 LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) LOCAL_SRC_FILES := \ - ../../out/atc/lib64/stub/graph.cc \ - ../../out/atc/lib64/stub/operator.cc \ - ../../out/atc/lib64/stub/tensor.cc \ - ../../out/atc/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +ifeq ($(device_os),android) +LOCAL_LDFLAGS := -ldl +endif + +LOCAL_MULTILIB := 64 +LOCAL_PROPRIETARY_MODULE := true + +include $(BUILD_SHARED_LIBRARY) + +#compiler for device +include $(CLEAR_VARS) +LOCAL_MODULE := fwk_stub/libgraph + +LOCAL_CFLAGS += -O2 + +LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) +LOCAL_SRC_FILES := \ + ../../out/graph/lib64/stub/attr_value.cc \ + ../../out/graph/lib64/stub/graph.cc \ + ../../out/graph/lib64/stub/operator.cc \ + ../../out/graph/lib64/stub/operator_factory.cc \ + ../../out/graph/lib64/stub/tensor.cc \ + ../../out/graph/lib64/stub/inference_context.cc \ LOCAL_SHARED_LIBRARIES := @@ -174,6 +231,7 @@ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libprotobuf \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -199,6 +257,7 @@ LOCAL_STATIC_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl @@ -222,6 +281,7 @@ LOCAL_STATIC_LIBRARIES := \ LOCAL_SHARED_LIBRARIES := \ libc_sec \ libslog \ + liberror_manager \ LOCAL_LDFLAGS := -lrt -ldl diff --git a/src/common/graph/model_serialize.cc b/src/common/graph/model_serialize.cc index 19cb4538..673bb31b 100644 --- a/src/common/graph/model_serialize.cc +++ b/src/common/graph/model_serialize.cc @@ -88,10 +88,8 @@ bool ModelSerializeImp::SerializeEdge(const NodePtr &node, proto::OpDef *op_def_ } bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto, bool is_dump) { - if (op_desc == nullptr || op_def_proto == nullptr) { - GELOGE(GRAPH_FAILED, "Input Para Invalid"); - return false; - } + GE_CHK_BOOL_EXEC(op_desc != nullptr, return false, "op_desc is null."); + GE_CHK_BOOL_EXEC(op_def_proto != nullptr, return false, "op_def_proto is null."); if (op_desc->op_def_.GetProtoMsg() != nullptr) { *op_def_proto = *op_desc->op_def_.GetProtoMsg(); // Delete unnecessary attr @@ -130,18 +128,40 @@ bool ModelSerializeImp::SerializeOpDesc(const ConstOpDescPtr &op_desc, proto::Op for (const std::string &name : op_desc->GetSubgraphInstanceNames()) { op_def_proto->add_subgraph_name(name); } + OpDescToAttrDef(op_desc, op_def_proto); + } + return true; +} - proto::AttrDef key; - proto::AttrDef value; +void ModelSerializeImp::OpDescToAttrDef(const ConstOpDescPtr &op_desc, proto::OpDef *op_def_proto) { + proto::AttrDef key_in; + proto::AttrDef value_in; + auto op_desc_attr = op_def_proto->mutable_attr(); + if (!op_desc->input_name_idx_.empty()) { + for (auto &item : op_desc->input_name_idx_) { + key_in.mutable_list()->add_s(item.first); + value_in.mutable_list()->add_i(item.second); + } + op_desc_attr->insert({"_input_name_key", key_in}); + op_desc_attr->insert({"_input_name_value", value_in}); + } + proto::AttrDef key_out; + proto::AttrDef value_out; + if (!op_desc->output_name_idx_.empty()) { for (auto &item : op_desc->output_name_idx_) { - key.mutable_list()->add_s(item.first); - value.mutable_list()->add_i(item.second); + key_out.mutable_list()->add_s(item.first); + value_out.mutable_list()->add_i(item.second); } - auto op_desc_attr = op_def_proto->mutable_attr(); - op_desc_attr->insert({"_output_name_key", key}); - op_desc_attr->insert({"_output_name_value", value}); + op_desc_attr->insert({"_output_name_key", key_out}); + op_desc_attr->insert({"_output_name_value", value_out}); + } + proto::AttrDef opt_input; + if (!op_desc->optional_input_names_.empty()) { + for (auto &item : op_desc->optional_input_names_) { + opt_input.mutable_list()->add_s(item); + } + op_desc_attr->insert({"_opt_input", opt_input}); } - return true; } bool ModelSerializeImp::SerializeNode(const NodePtr &node, proto::OpDef *op_def_proto, bool is_dump) { @@ -237,13 +257,70 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Unseriali } } +void ModelSerializeImp::AttrDefToOpDesc(OpDescPtr &op_desc, std::vector &key_in, std::vector &key_out, + std::vector &value_in, std::vector &value_out, + std::vector &opt_input) { + if (!key_in.empty()) { + if (key_in.size() != value_in.size()) { + GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), + value_in.size()); + } else { + for (uint32_t i = 0; i < key_in.size(); ++i) { + op_desc->input_name_idx_.insert(std::pair(key_in.at(i), value_in.at(i))); + } + } + } + if (!key_out.empty()) { + if (key_out.size() != value_out.size()) { + GELOGW("Key and value vector size is different. key_size: %zu, value_size: %zu.", key_out.size(), + value_out.size()); + } else { + for (uint32_t i = 0; i < key_out.size(); ++i) { + op_desc->output_name_idx_.insert(std::pair(key_out.at(i), value_out.at(i))); + } + } + } + if (!opt_input.empty()) { + for (const auto &i : opt_input) { + op_desc->optional_input_names_.insert(i); + } + } +} + bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_def_proto) { - std::vector key; - std::vector value; + std::vector opt_input; + std::vector key_in; + std::vector value_in; + if (op_def_proto.attr().count("_opt_input") > 0) { + auto &name_list = op_def_proto.attr().at("_opt_input").list(); + for (const auto &item_s : name_list.s()) { + opt_input.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_opt_input"); + } + if (op_def_proto.attr().count("_input_name_key") > 0) { + auto &output_name_key_list = op_def_proto.attr().at("_input_name_key").list(); + for (const auto &item_s : output_name_key_list.s()) { + key_in.push_back(item_s); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_input_name_key"); + } + if (op_def_proto.attr().count("_input_name_value") > 0) { + auto &input_name_value_list = op_def_proto.attr().at("_input_name_value").list(); + for (const auto &item_i : input_name_value_list.i()) { + value_in.push_back(static_cast(item_i)); + } + auto op_desc_attr = op_def_proto.mutable_attr(); + op_desc_attr->erase("_input_name_value"); + } + std::vector key_out; + std::vector value_out; if (op_def_proto.attr().count("_output_name_key") > 0) { auto &output_name_key_list = op_def_proto.attr().at("_output_name_key").list(); for (const auto &item_s : output_name_key_list.s()) { - key.push_back(item_s); + key_out.push_back(item_s); } auto op_desc_attr = op_def_proto.mutable_attr(); op_desc_attr->erase("_output_name_key"); @@ -251,7 +328,7 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d if (op_def_proto.attr().count("_output_name_value") > 0) { auto &output_name_value_list = op_def_proto.attr().at("_output_name_value").list(); for (const auto &item_i : output_name_value_list.i()) { - value.push_back(static_cast(item_i)); + value_out.push_back(static_cast(item_i)); } auto op_desc_attr = op_def_proto.mutable_attr(); op_desc_attr->erase("_output_name_value"); @@ -282,15 +359,8 @@ bool ModelSerializeImp::UnserializeOpDesc(OpDescPtr &op_desc, proto::OpDef &op_d op_desc->SetSubgraphInstanceName(graph_index++, name); } - if (key.size() != 0) { - if (key.size() != value.size()) { - GELOGE(GRAPH_FAILED, "twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); - } else { - for (uint32_t i = 0; i < key.size(); ++i) { - op_desc->output_name_idx_.insert(std::pair(key.at(i), value.at(i))); - } - } - } + // insert name index by key and value + AttrDefToOpDesc(op_desc, key_in, key_out, value_in, value_out, opt_input); return true; } @@ -338,13 +408,13 @@ bool ModelSerializeImp::HandleNodeNameRef() { item.dst_node_name.c_str(), item.dst_in_index); return false; } - GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); + GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 } else { // Control edge auto src_anchor = src_node_it->second->GetOutControlAnchor(); auto dst_anchor = item.dst_node->GetInControlAnchor(); if (src_anchor != nullptr && dst_anchor != nullptr) { - GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); + GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 } } } diff --git a/src/common/graph/node.cc b/src/common/graph/node.cc index e0939e7e..b210957d 100644 --- a/src/common/graph/node.cc +++ b/src/common/graph/node.cc @@ -26,6 +26,7 @@ #include "utils/ge_ir_utils.h" #include "utils/node_utils.h" #include "utils/op_desc_utils.h" +#include "common/util/error_manager/error_manager.h" using std::string; using std::vector; @@ -154,7 +155,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool Node::NodeAnchorIsEqual(cons const auto &peer_node = left_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); const auto &r_peer_node = right_anchor->GetPeerAnchors().at(j)->GetOwnerNode(); if (peer_node == nullptr || r_peer_node == nullptr) { - GELOGE(GRAPH_FAILED, "Error: anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", + GELOGE(GRAPH_FAILED, "anchor's peer node is null, node name: %s index[%zu] peer node index[%zu]. ", this->GetName().c_str(), i, j); return false; } @@ -434,8 +435,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Node::Vistor Node::Get GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAnchor(int idx) const { if (idx < 0 || idx >= static_cast(in_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "the node doesn't have %d th in_data_anchor, node %s:%s", idx, GetType().c_str(), - GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "in_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s in_data_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); return nullptr; } else { return in_data_anchors_[idx]; @@ -445,7 +449,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY InDataAnchorPtr Node::GetInDataAn GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int idx) const { // Idx can't be less than -1 or >= in_data_anchors_.size(), -1 means index of control anchor_ if (idx < -1 || idx >= static_cast(in_data_anchors_.size())) { - GELOGW("the node doesn't have %d th in_anchor, node %s:%s", idx, GetType().c_str(), GetName().c_str()); + GELOGW("Op[%s] doesn't have index[%d]'s in_anchor which optype is %s.", GetName().c_str(), idx, GetType().c_str()); return nullptr; } else { // Return control anchor @@ -461,8 +465,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetInAnchor(int i GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int idx) const { // Idx can't be less than -1 or >= out_data_anchors_.size(), -1 means index of control anchor_ if (idx < -1 || idx >= static_cast(out_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_anchor, node %s:%s", idx, GetType().c_str(), - GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19019", {"opname", "index", "anchorname", "optype"}, + { + GetName().c_str(), + std::to_string(idx), + "out_anchor", + GetType().c_str(), + }); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); return nullptr; } else { // Return control anchor @@ -477,8 +488,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AnchorPtr Node::GetOutAnchor(int GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OutDataAnchorPtr Node::GetOutDataAnchor(int idx) const { if (idx < 0 || idx >= static_cast(out_data_anchors_.size())) { - GELOGE(GRAPH_FAILED, "the node doesn't have %d th out_data_anchor, node %s:%s", idx, GetType().c_str(), - GetName().c_str()); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19019", {"opname", "index", "anchorname", "optype"}, + {GetName().c_str(), std::to_string(idx), "out_data_anchor", GetType().c_str()}); + GELOGE(GRAPH_FAILED, "Op[%s] doesn't have index[%d]'s out_data_anchor which optype is %s.", GetName().c_str(), idx, + GetType().c_str()); return nullptr; } else { return out_data_anchors_[idx]; @@ -726,22 +740,27 @@ graphStatus Node::Verify() const { const string aipp_data_type = "AippData"; const string const_type = "Const"; const string variable_type = "Variable"; + bool is_unknown_graph = GetOwnerComputeGraph()->GetGraphUnknownFlag(); GE_CHK_BOOL_EXEC(op_ != nullptr, return GRAPH_FAILED, "original OpDesc is nullptr"); - for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { - if (in_anchor_ptr == nullptr) { - GELOGW("in anchor ptr is null"); - continue; + if (!is_unknown_graph) { + for (const auto &in_anchor_ptr : GetAllInDataAnchors()) { + GE_IF_BOOL_EXEC(in_anchor_ptr == nullptr, GELOGW("in anchor ptr is null"); continue); + bool valid_anchor = op_->GetType() == data_type || op_->GetType() == aipp_data_type || + op_->GetType() == const_type || op_->GetType() == variable_type || + op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || in_anchor_ptr->GetPeerAnchors().size() > 0; + if (!valid_anchor) { + ErrorManager::GetInstance().ATCReportErrMessage("E11019", {"opname", "index"}, + {GetName(), std::to_string(in_anchor_ptr->GetIdx())}); + GELOGE(GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); + return GRAPH_FAILED; + } } - GE_CHK_BOOL_RET_STATUS( - op_->GetType() == data_type || op_->GetType() == aipp_data_type || op_->GetType() == const_type || - op_->GetType() == variable_type || op_->IsOptionalInput(in_anchor_ptr->GetIdx()) || - in_anchor_ptr->GetPeerAnchors().size() > 0, - GRAPH_FAILED, "operator %s's input %d is not linked.", GetName().c_str(), in_anchor_ptr->GetIdx()); } string frameworkop_type = "FrameworkOp"; - if (op_->GetType() != frameworkop_type) { + bool need_update_name = op_->GetType() != frameworkop_type && !is_unknown_graph; + if (need_update_name) { auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_->GetType()); if (node_op.IsEmpty()) { GELOGW("get op from OperatorFactory fail. opType: %s", op_->GetType().c_str()); @@ -761,7 +780,7 @@ graphStatus Node::Verify() const { } node_op.BreakConnect(); } - + GE_IF_BOOL_EXEC(is_unknown_graph, return GRAPH_SUCCESS;); if (op_->CommonVerify() == GRAPH_SUCCESS) { Operator op_proxy = ge::OpDescUtils::CreateOperatorFromNode(shared_from_this()); auto verify_func = op_->GetVerifyFunc(); diff --git a/src/common/graph/op_desc.cc b/src/common/graph/op_desc.cc index adb52162..a7451641 100644 --- a/src/common/graph/op_desc.cc +++ b/src/common/graph/op_desc.cc @@ -19,6 +19,7 @@ #include "debug/ge_util.h" #include "external/graph/operator.h" #include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" #include "graph/ge_attr_value.h" #include "graph/ge_tensor.h" #include "graph/operator_factory_impl.h" @@ -32,6 +33,7 @@ using std::shared_ptr; using std::string; using std::vector; +/*lint -save -e521 -e681 -e732 -e737*/ namespace ge { const std::string ATTR_NAME_ID = "id"; @@ -63,12 +65,6 @@ const std::string ATTR_NAME_IS_INPUT_CONST = "is_input_const"; const std::string ATTR_NAME_OP_INFER_DEPENDS = "_op_infer_depends"; -const std::string ATTR_NAME_OPT_INPUT = "_opt_input"; - -const std::string ATTR_NAME_INPUT_NAME_IDX_KEY = "_input_name_idx_key"; - -const std::string ATTR_NAME_INPUT_NAME_IDX_VALUE = "_input_name_idx_value"; - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::OpDesc() { op_def_.InitDefault(); if (op_def_.GetProtoMsg() != nullptr) { @@ -210,8 +206,7 @@ graphStatus OpDesc::AddInputDesc(uint32_t index, const ge::GeTensorDesc &input_d } graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { - auto input_name_idx = GetAllInputName(); - if (input_name_idx.find(name) != input_name_idx.end()) { + if (input_name_idx_.find(name) != input_name_idx_.end()) { GELOGI("input %s is exist, update it", name.c_str()); graphStatus ret = UpdateInputDesc(name, input_desc); return ret; @@ -223,17 +218,15 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp return GRAPH_FAILED; } inputs_desc_.push_back(in_desc); - (void)input_name_idx.insert(make_pair(name, index)); - SetAllInputName(input_name_idx); + (void)input_name_idx_.insert(make_pair(name, index)); return GRAPH_SUCCESS; } } graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int num, size_t index) { - auto input_name_idx = GetAllInputName(); for (unsigned int i = 0; i < num; i++) { string input_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, + GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, "Add input tensor_desc is existed. name[%s]", input_name.c_str()); std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); @@ -250,24 +243,22 @@ graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int nu (void)inputs_desc_.insert(inputs_desc_.begin() + index + i, in_desc); // Update index in input_name_idx - for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { + for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { if (it->second >= (index + i)) { it->second += 1; } } - (void)input_name_idx.insert(make_pair(input_name, i + index)); + (void)input_name_idx_.insert(make_pair(input_name, i + index)); } - SetAllInputName(input_name_idx); return GRAPH_SUCCESS; } graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { - auto input_name_idx = GetAllInputName(); for (unsigned int i = 0; i < num; i++) { string input_name = name + std::to_string(i); - GE_CHK_BOOL_RET_STATUS((input_name_idx.find(input_name) == input_name_idx.end()), GRAPH_FAILED, + GE_CHK_BOOL_RET_STATUS((input_name_idx_.find(input_name) == input_name_idx_.end()), GRAPH_FAILED, "Add input tensor_desc is existed. name[%s]", input_name.c_str()); std::shared_ptr in_desc = ComGraphMakeShared(GeTensorDesc()); @@ -278,13 +269,12 @@ graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int n (void)inputs_desc_.insert(inputs_desc_.begin(), in_desc); // Update index in input_name_idx - for (auto it = input_name_idx.begin(); it != input_name_idx.end(); ++it) { + for (auto it = input_name_idx_.begin(); it != input_name_idx_.end(); ++it) { it->second += 1; } - (void)input_name_idx.insert(make_pair(input_name, 0)); + (void)input_name_idx_.insert(make_pair(input_name, 0)); } - SetAllInputName(input_name_idx); return GRAPH_SUCCESS; } @@ -315,19 +305,10 @@ graphStatus OpDesc::AddOutputDescForward(const string &name, const unsigned int graphStatus OpDesc::AddOptionalInputDesc(const string &name, const ge::GeTensorDesc &input_desc) { if (OpDesc::AddInputDesc(name, input_desc) == GRAPH_FAILED) return GRAPH_FAILED; - vector optional_input_names; - (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); - optional_input_names.push_back(name); - (void)AttrUtils::SetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); + (void)optional_input_names_.insert(name); return GRAPH_SUCCESS; } -std::vector OpDesc::GetAllOptionalInputName() const { - vector optional_input_names; - (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); - return optional_input_names; -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { GE_CHK_BOOL_RET_STATUS((index < inputs_desc_.size()), GRAPH_FAILED, "The index is invalid. index[%u]", index); @@ -342,12 +323,11 @@ OpDesc::UpdateInputDesc(uint32_t index, const ge::GeTensorDesc &tensor_Desc) { } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescMembersAreEqual(const OpDesc &r_op_desc) const { - return ( - IsEqual(this->GetAllInputName(), r_op_desc.GetAllInputName(), "OpDesc.GetAllInputName()") && - IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && - IsEqual(this->GetAllOptionalInputName(), r_op_desc.GetAllOptionalInputName(), "OpDesc.GetAllOptionalInputName()") && - IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && - IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); + return (IsEqual(this->input_name_idx_, r_op_desc.input_name_idx_, "OpDesc.input_name_idx_") && + IsEqual(this->output_name_idx_, r_op_desc.output_name_idx_, "OpDesc.output_name_idx_") && + IsEqual(this->optional_input_names_, r_op_desc.optional_input_names_, "OpDesc.optional_input_names_") && + IsEqual(this->engine_name_, r_op_desc.engine_name_, "OpDesc.engine_name_") && + IsEqual(this->op_kernel_lib_name_, r_op_desc.op_kernel_lib_name_, "OpDesc.op_kernel_lib_name_")); } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::OpDescAttrsAreEqual(const OpDesc &r_op_desc) const { @@ -421,9 +401,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool OpDesc::operator==(const OpD } graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc &tensor_Desc) { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - if (it == input_name_idx.end()) { + auto it = input_name_idx_.find(name); + if (it == input_name_idx_.end()) { GELOGW("Cann't find the input desc. name[%s]", name.c_str()); return GRAPH_FAILED; } @@ -443,9 +422,8 @@ graphStatus OpDesc::UpdateInputDesc(const string &name, const ge::GeTensorDesc & } bool OpDesc::InputIsSet(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - if (it != input_name_idx.end()) { + auto it = input_name_idx_.find(name); + if (it != input_name_idx_.end()) { GE_IF_BOOL_EXEC(it->second >= inputs_desc_.size(), GELOGE(GRAPH_FAILED, "it->second is invalid."); return false); auto tensor_desc = inputs_desc_[it->second]; GE_IF_BOOL_EXEC(tensor_desc == nullptr, GELOGE(GRAPH_FAILED, "tensor_desc is null."); return false); @@ -463,20 +441,40 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDesc OpDesc::GetInputDesc } GeTensorDesc OpDesc::GetInputDesc(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), GeTensorDesc()); + auto it = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), GeTensorDesc()); GE_CHK_BOOL_RET_STATUS_NOLOG(it->second < inputs_desc_.size(), GeTensorDesc()); return *(inputs_desc_[it->second].get()); } -GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { + GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); + if (inputs_desc_[index] == nullptr) { + return nullptr; + } + if (inputs_desc_[index]->IsValid() != GRAPH_SUCCESS) { + GELOGW("input desc is invalid"); + return nullptr; + } + return inputs_desc_[index]; +} + +GeTensorDescPtr OpDesc::MutableInputDesc(const string &name) const { auto input_name_idx = GetAllInputName(); + auto it = input_name_idx.find(name); + if (it == input_name_idx.end()) { + GELOGW("Failed to get [%s] input desc", name.c_str()); + return nullptr; + } + return MutableInputDesc(it->second); +} + +GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputNames() const { vector names; - if (input_name_idx.empty()) { + if (input_name_idx_.empty()) { return OpDesc::Vistor(shared_from_this(), names); } - for (std::pair input : input_name_idx) { + for (std::pair input : input_name_idx_) { names.push_back(input.first); } return OpDesc::Vistor(shared_from_this(), names); @@ -496,15 +494,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void OpDesc::SetOpEngineName(cons GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY std::string OpDesc::GetOpEngineName() const { return engine_name_; } -GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableInputDesc(uint32_t index) const { - GE_CHK_BOOL_RET_STATUS(index < inputs_desc_.size(), nullptr, "Can't find the input desc %u", index); - if (inputs_desc_[index] == nullptr) { - return nullptr; - } - GE_CHK_BOOL_RET_STATUS(inputs_desc_[index]->IsValid() == GRAPH_SUCCESS, nullptr, "input desc is invalid"); - return inputs_desc_[index]; -} - GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDesc::Vistor OpDesc::GetAllInputsDesc() const { vector temp{}; for (const auto &it : inputs_desc_) { @@ -609,6 +598,15 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOu return outputs_desc_[index]; } +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeTensorDescPtr OpDesc::MutableOutputDesc(const string &name) const { + auto it = output_name_idx_.find(name); + if (it == output_name_idx_.end()) { + GELOGW("Failed to get [%s] output desc", name.c_str()); + return nullptr; + } + return MutableOutputDesc(it->second); +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY uint32_t OpDesc::GetAllOutputsDescSize() const { return static_cast(outputs_desc_.size()); } @@ -652,9 +650,8 @@ OpDesc::GetInputDescPtrDfault(uint32_t index) const { } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetInputDescPtr(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), shared_ptr()); + auto it = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), shared_ptr()); return inputs_desc_[it->second]; } @@ -687,47 +684,26 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int return GRAPH_SUCCESS; } -bool OpDesc::IsOptionalInput(const string &name) const { - vector optional_input_names; - (void)AttrUtils::GetListStr(this, ATTR_NAME_OPT_INPUT, optional_input_names); - for (auto &item : optional_input_names) { - if (item == name) { - return true; - } +void OpDesc::RemoveInputDesc(uint32_t index) { + while (inputs_desc_.size() > index) { + inputs_desc_.pop_back(); } - return false; } -bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } - -std::map OpDesc::GetAllInputName() const { - std::map input_name_idx; - std::vector key; - std::vector value; - (void)AttrUtils::GetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); - (void)AttrUtils::GetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); - - if (key.size() != value.size()) { - GE_LOGE("twe vector size is different. key_size: %zu, value_size: %zu.", key.size(), value.size()); - } else { - for (uint32_t i = 0; i < key.size(); ++i) { - input_name_idx.insert(std::pair(key.at(i), value.at(i))); - } +void OpDesc::RemoveOutputDesc(uint32_t index) { + while (outputs_desc_.size() > index) { + outputs_desc_.pop_back(); } - return input_name_idx; } -void OpDesc::SetAllInputName(const std::map &input_name_idx) { - std::vector key; - std::vector value; - for (auto &item : input_name_idx) { - key.emplace_back(item.first); - value.emplace_back(item.second); - } - (void)AttrUtils::SetListStr(this, ATTR_NAME_INPUT_NAME_IDX_KEY, key); - (void)AttrUtils::SetListInt(this, ATTR_NAME_INPUT_NAME_IDX_VALUE, value); +bool OpDesc::IsOptionalInput(const string &name) const { + return optional_input_names_.find(name) != optional_input_names_.end(); } +bool OpDesc::IsOptionalInput(uint32_t index) const { return IsOptionalInput(GetInputNameByIndex(index)); } + +std::map OpDesc::GetAllInputName() const { return input_name_idx_; } + std::map OpDesc::GetAllOutputName() { return output_name_idx_; } bool OpDesc::UpdateInputName(std::map input_name_idx) { @@ -737,7 +713,6 @@ bool OpDesc::UpdateInputName(std::map input_name_idx) { auto factory_map_size = input_name_idx.size(); // It indicates that some inputs have no optionalname. // The redundant optionalname of factory needs to be deleted and then assigned - auto all_input_name_idx = GetAllInputName(); if (input_map_size < factory_map_size) { GELOGI("UpdateInputName org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); @@ -750,18 +725,17 @@ bool OpDesc::UpdateInputName(std::map input_name_idx) { } if (input_name_idx.size() == input_map_size) { GELOGI("UpdateInputName"); - all_input_name_idx = input_name_idx; + input_name_idx_ = input_name_idx; } else { ret = false; GELOGW("after UpdateInputName factoryName map size : %zu", input_name_idx.size()); } } else if (input_map_size == factory_map_size) { - all_input_name_idx = input_name_idx; + input_name_idx_ = input_name_idx; } else { ret = false; GELOGW("org inputname map size: %zu, factory inputname map size: %zu", input_map_size, factory_map_size); } - SetAllInputName(all_input_name_idx); return ret; } @@ -882,36 +856,41 @@ graphStatus OpDesc::CommonVerify() const { // Checking shape of all inputs vector ishape = GetInputDescPtr(iname)->GetShape().GetDims(); for (int64_t dim : ishape) { - GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", - iname.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dim < -2, ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {GetName(), "input " + iname + " shape", "contains negative or zero dimension"}); + return GRAPH_FAILED, "Op[%s]'s input %s shape contains negative or zero dimension.", GetName().c_str(), + iname.c_str()); } } // Check all attributes defined const auto &all_attributes = GetAllAttrs(); for (const auto &name : GetAllAttrNames()) { - GE_CHK_BOOL_RET_STATUS(all_attributes.find(name) != all_attributes.end(), GRAPH_FAILED, - "operator attribute %s is empty.", name.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + all_attributes.find(name) == all_attributes.end(), + ErrorManager::GetInstance().ATCReportErrMessage("E19014", {"opname", "value", "reason"}, + {GetName(), "attribute " + name, "is empty"}); + return GRAPH_FAILED, "operator attribute %s is empty.", name.c_str()); } return GRAPH_SUCCESS; } GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY string OpDesc::GetInputNameByIndex(uint32_t index) const { - auto input_name_idx = GetAllInputName(); - auto it = input_name_idx.begin(); - for (; it != input_name_idx.end(); ++it) { + auto it = input_name_idx_.begin(); + for (; it != input_name_idx_.end(); ++it) { if (it->second == index) { break; } } - GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx.end(), ""); + GE_CHK_BOOL_RET_STATUS_NOLOG(it != input_name_idx_.end(), ""); return it->first; } int OpDesc::GetInputIndexByName(const string &name) const { - auto input_name_idx = GetAllInputName(); - auto it_find = input_name_idx.find(name); - GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx.end(), -1); + auto it_find = input_name_idx_.find(name); + GE_CHK_BOOL_RET_STATUS_NOLOG(it_find != input_name_idx_.end(), -1); return static_cast(it_find->second); } @@ -1204,12 +1183,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDesc::GetIsInputCo GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDesc::RestoreInputNameIdx(const string &name, const int &index) { - auto input_name_idx = GetAllInputName(); - if (input_name_idx.find(name) != input_name_idx.end()) { + if (input_name_idx_.find(name) != input_name_idx_.end()) { GELOGI("Restore input name index is existed. name[%s]", name.c_str()); } - (void)input_name_idx.insert(make_pair(name, index)); - SetAllInputName(input_name_idx); + (void)input_name_idx_.insert(make_pair(name, index)); return GRAPH_SUCCESS; } diff --git a/src/common/graph/operator.cc b/src/common/graph/operator.cc index 1ac8d41d..03d4221e 100644 --- a/src/common/graph/operator.cc +++ b/src/common/graph/operator.cc @@ -21,7 +21,7 @@ #include #include #include -#include "array_ops.h" +#include "./array_ops.h" #include "debug/ge_log.h" #include "debug/ge_op_types.h" #include "debug/ge_util.h" @@ -36,6 +36,8 @@ #include "graph/op_desc.h" #include "graph/runtime_inference_context.h" #include "graph/usr_types.h" +#include "graph/utils/node_utils.h" +#include "graph/debug/ge_attr_define.h" #include "utils/graph_utils.h" #include "utils/op_desc_utils.h" #include "utils/tensor_adapter.h" @@ -54,11 +56,13 @@ using std::string; using std::to_string; using std::vector; +/*lint -save -e529 -e728*/ +/*lint -e446 -e732*/ +/*lint -e665*/ namespace ge { class OpIO { public: - explicit OpIO(const string &name, int index, const OperatorImplPtr &owner) - : name_(name), index_(index), owner_(owner) {} + OpIO(const string &name, int index, const OperatorImplPtr &owner) : name_(name), index_(index), owner_(owner) {} ~OpIO() = default; @@ -546,56 +550,46 @@ Operator &Operator::AddControlInput(const Operator &src_oprt) { } graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) const { - if (operator_impl_ == nullptr) { - GELOGE(GRAPH_FAILED, "operator impl is nullptr."); - return GRAPH_FAILED; - } - ge::ConstNodePtr node_ptr = operator_impl_->GetNode(); - if (node_ptr) { + GE_CHECK_NOTNULL(operator_impl_); + auto node_ptr = operator_impl_->GetNode(); + if (node_ptr != nullptr) { // For inner compute graph auto op_desc = node_ptr->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(GRAPH_FAILED, "op_desc is nullptr."); - return GRAPH_FAILED; - } + GE_CHECK_NOTNULL(op_desc); auto index = op_desc->GetInputIndexByName(dst_name); auto in_data_anchor = node_ptr->GetInDataAnchor(index); - if (in_data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "in_data_anchor is nullptr."); - return GRAPH_FAILED; - } + GE_CHECK_NOTNULL(in_data_anchor); auto out_data_anchor = in_data_anchor->GetPeerOutAnchor(); - if (out_data_anchor == nullptr) { - GELOGE(GRAPH_FAILED, "out_data_anchor is nullptr."); - return GRAPH_FAILED; - } - std::shared_ptr peer_node_ptr = out_data_anchor->GetOwnerNode(); - if (peer_node_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "peer_node_ptr is nullptr."); - return GRAPH_FAILED; - } - ge::OperatorImplPtr operator_impl_ptr = nullptr; - operator_impl_ptr = ComGraphMakeShared(peer_node_ptr); - if (operator_impl_ptr == nullptr) { - GELOGE(GRAPH_FAILED, "OperatorImpl make shared failed"); - return GRAPH_FAILED; - } - Operator const_op(std::move(operator_impl_ptr)); - if (peer_node_ptr->GetOpDesc() != nullptr) { - const auto &op_descType = peer_node_ptr->GetOpDesc()->GetType(); - if (op_descType == CONSTANTOP) { - return const_op.GetAttr(op::Constant::name_attr_value(), data); - } else if (op_descType == CONSTANT) { - return const_op.GetAttr(op::Const::name_attr_value(), data); + GE_CHECK_NOTNULL(out_data_anchor); + auto peer_node = out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + auto peer_op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(peer_op_desc); + auto peer_op_type = peer_op_desc->GetType(); + if (peer_op_type == CONSTANTOP || peer_op_type == CONSTANT) { + auto const_op_impl = ComGraphMakeShared(peer_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); + } else if (peer_op_type == DATA) { + auto parent_node = NodeUtils::GetParentInput(peer_node); + while ((parent_node != nullptr) && (parent_node->GetType() == DATA)) { + parent_node = NodeUtils::GetParentInput(parent_node); + } + if ((parent_node != nullptr) && + ((parent_node->GetType() == CONSTANT) || (parent_node->GetType() == CONSTANTOP))) { + auto const_op_impl = ComGraphMakeShared(parent_node); + GE_CHECK_NOTNULL(const_op_impl); + Operator const_op(std::move(const_op_impl)); + return const_op.GetAttr(ATTR_NAME_WEIGHTS, data); } } - // Try get from runtime inference context auto session_id = std::to_string(GetContext().SessionId()); RuntimeInferenceContext *runtime_infer_ctx = nullptr; if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); - auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); + auto ret = runtime_infer_ctx->GetTensor(peer_node->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); if (ret == GRAPH_SUCCESS) { return GRAPH_SUCCESS; } @@ -604,6 +598,8 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co // For outer graph return GetInputConstDataOut(dst_name, data); } + auto op_name = operator_impl_->GetName(); + GELOGW("node[%s]'s input[%s]'s peer node is not const", op_name.c_str(), dst_name.c_str()); return GRAPH_FAILED; } graphStatus Operator::GetInputConstDataOut(const string &dst_name, Tensor &data) const { @@ -914,7 +910,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } GELOGW("set attr name %s failed.", name.c_str()); \ } \ return *this; \ - } + } // lint !e665 #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ @@ -927,7 +923,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } return GRAPH_FAILED; \ } \ return GRAPH_SUCCESS; \ - } + } // lint !e665 void Operator::BreakConnect() const { if (operator_impl_ == nullptr) { @@ -948,7 +944,7 @@ void Operator::BreakConnect() const { if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ GELOGW("reg attr name %s failed.", name.c_str()); \ } \ - } + } // lint !e665 OP_ATTR_SET_IMP(int64_t, Int) OP_ATTR_SET_IMP(int32_t, Int) @@ -969,22 +965,22 @@ OP_ATTR_SET_IMP(const vector> &, ListListInt) OP_ATTR_SET_IMP(float, Float) OP_ATTR_GET_IMP(float &, Float) OP_ATTR_SET_IMP(const vector &, ListFloat) -OP_ATTR_GET_IMP(vector &, ListFloat) +OP_ATTR_GET_IMP(vector &, ListFloat) // lint !e665 OP_ATTR_SET_IMP(bool, Bool) OP_ATTR_GET_IMP(bool &, Bool) OP_ATTR_SET_IMP(const vector &, ListBool) -OP_ATTR_GET_IMP(vector &, ListBool) +OP_ATTR_GET_IMP(vector &, ListBool) // lint !e665 OP_ATTR_SET_IMP(const string &, Str) OP_ATTR_GET_IMP(string &, Str) OP_ATTR_SET_IMP(const vector &, ListStr) -OP_ATTR_GET_IMP(vector &, ListStr) +OP_ATTR_GET_IMP(vector &, ListStr) // lint !e665 OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) OP_ATTR_SET_IMP(const vector &, ListNamedAttrs) -OP_ATTR_GET_IMP(vector &, ListNamedAttrs) +OP_ATTR_GET_IMP(vector &, ListNamedAttrs) // lint !e665 OP_ATTR_REG_IMP(int64_t, Int) OP_ATTR_REG_IMP(const vector &, ListInt) @@ -1547,3 +1543,5 @@ void GraphUtils::BreakConnect(const std::map &all_node } } } // namespace ge +/*lint +e446 +e732*/ +/*lint +e665*/ diff --git a/src/common/graph/opsproto/opsproto_manager.cc b/src/common/graph/opsproto/opsproto_manager.cc index c2afc191..4c8c1be5 100644 --- a/src/common/graph/opsproto/opsproto_manager.cc +++ b/src/common/graph/opsproto/opsproto_manager.cc @@ -31,7 +31,9 @@ OpsProtoManager *OpsProtoManager::Instance() { } bool OpsProtoManager::Initialize(const std::map &options) { + /*lint -e1561*/ auto proto_iter = options.find("ge.opsProtoLibPath"); + /*lint +e1561*/ if (proto_iter == options.end()) { GELOGW("ge.opsProtoLibPath option not set, return."); return false; diff --git a/src/common/graph/option/ge_context.cc b/src/common/graph/option/ge_context.cc index f5ebdeee..f5f5e4c9 100644 --- a/src/common/graph/option/ge_context.cc +++ b/src/common/graph/option/ge_context.cc @@ -85,6 +85,8 @@ uint32_t GEContext::DeviceId() { return device_id_; } uint64_t GEContext::TraceId() { return trace_id_; } +void GEContext::SetSessionId(uint64_t session_id) { session_id_ = session_id; } + void GEContext::SetCtxDeviceId(uint32_t device_id) { device_id_ = device_id; } } // namespace ge diff --git a/src/common/graph/ref_relation.cc b/src/common/graph/ref_relation.cc index b3cf37af..7785bc43 100644 --- a/src/common/graph/ref_relation.cc +++ b/src/common/graph/ref_relation.cc @@ -37,7 +37,7 @@ const string kWhile = "While"; const string kIf = "If"; const string kCase = "Case"; -const int kMaxElementNum = 100; +const uint16_t kMaxElementNum = 100; std::unordered_set function_op = {kWhile, kIf, kCase}; } // namespace @@ -170,6 +170,7 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile( // data_nodes has been sorted // for while, input num must be same as output num auto input_num = root_node->GetAllInDataAnchorsSize(); + NodePtr netoutput = nullptr; size_t ref_i = 0; while (ref_i < input_num) { @@ -212,10 +213,44 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile( cell_netoutput_in.in_out = NODE_IN; cell_netoutput_in.in_out_idx = ele.second; ref_i_all_refs.emplace_back(cell_netoutput_in); + netoutput = ele.first; } node_refs.emplace_back(ref_i_all_refs); ref_i++; } + /* There exist scene like the follows, it means data0 data1 netoutput 0'th + * and 1'th tensor should be the same addr. + * Data0 Data1 + * \/ + * /\ + * netoutput + */ + if (netoutput == nullptr) { + return GRAPH_SUCCESS; + } + for (const auto &in_anchor : netoutput->GetAllInDataAnchors()) { + auto peer_out_data_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_out_data_anchor == nullptr) { + continue; + } + auto peer_out_data_node = peer_out_data_anchor->GetOwnerNode(); + if (peer_out_data_node == nullptr || peer_out_data_node->GetOpDesc() == nullptr) { + GELOGW("Node[%s]\'s peer_out_data_node or peer_out_data_node desc is null", (netoutput->GetName()).c_str()); + continue; + } + if (peer_out_data_node->GetType() != DATA) { + continue; + } + auto in_data_anchor_idx = in_anchor->GetIdx(); + auto net_in_desc = netoutput->GetOpDesc()->MutableInputDesc(static_cast(in_data_anchor_idx)); + int ref_d; + int ref_n; + (void)AttrUtils::GetInt(peer_out_data_node->GetOpDesc(), kRefIndex, ref_d); + (void)AttrUtils::GetInt(net_in_desc, kRefIndex, ref_n); + + node_refs[ref_d].insert(node_refs[ref_d].end(), node_refs[ref_n].begin(), node_refs[ref_n].end()); + node_refs[ref_n].insert(node_refs[ref_n].end(), node_refs[ref_d].begin(), node_refs[ref_d].end()); + } return GRAPH_SUCCESS; } @@ -242,6 +277,10 @@ void RefRelations::Impl::GetDataAndNetoutputOfSubGraph(const ge::ComputeGraph &r int sub_graph_idx = 0; for (const auto &name : sub_graph_names) { auto sub_graph = root_graph.GetSubgraph(name); + if (sub_graph == nullptr) { + GELOGW("Can not find the sub graph %s for root graph %s.", name.c_str(), root_graph.GetName().c_str()); + continue; + } for (const auto &sub_graph_node : sub_graph->GetDirectNode()) { auto sub_graph_node_type = sub_graph_node->GetType(); @@ -296,6 +335,9 @@ graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector &data_n data_nodes.pop_back(); int ref_idx = 0; (void)AttrUtils::GetInt(data->GetOpDesc(), kRefIndex, ref_idx); + if (ref_idx >= static_cast(classed_data_nodes.size())) { + return GRAPH_FAILED; + } classed_data_nodes[ref_idx].emplace_back(data); } return GRAPH_SUCCESS; @@ -317,7 +359,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( } int ref_o; if (AttrUtils::GetInt(in_desc, kRefIndex, ref_o)) { - if (ref_o >= kMaxElementNum) { + if (ref_o >= static_cast(classed_netoutput_nodes.size())) { return GRAPH_FAILED; } classed_netoutput_nodes[ref_o].emplace_back( @@ -349,8 +391,9 @@ graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { vector netoutput_nodes; // Get data and netoutput of sub_graph GetDataAndNetoutputOfSubGraph(root_graph, data_nodes, netoutput_nodes, sub_graph_names, node_type); - vector> classed_data_nodes(kMaxElementNum); // according to ref_idx - vector>> classed_netoutput_nodes(kMaxElementNum); // according to ref_idx + size_t max_elem_num = (data_nodes.size() > kMaxElementNum) ? data_nodes.size() : kMaxElementNum; + vector> classed_data_nodes(max_elem_num); // according to ref_idx + vector>> classed_netoutput_nodes(max_elem_num); // according to ref_idx status = ProcessSubgraphDataNodes(data_nodes, classed_data_nodes); if (status != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "classfy data nodes failed!"); diff --git a/src/common/graph/runtime_inference_context.cc b/src/common/graph/runtime_inference_context.cc index 916da564..95068481 100644 --- a/src/common/graph/runtime_inference_context.cc +++ b/src/common/graph/runtime_inference_context.cc @@ -30,6 +30,7 @@ graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id return GRAPH_FAILED; } + std::lock_guard lk(ctx_mu_); auto emplace_ret = contexts_.emplace(context_id, std::move(ctx)); if (!emplace_ret.second) { GELOGE(GRAPH_FAILED, "Old context not destroyed"); diff --git a/src/common/graph/shape_refiner.cc b/src/common/graph/shape_refiner.cc index edf426a5..479ec1cb 100644 --- a/src/common/graph/shape_refiner.cc +++ b/src/common/graph/shape_refiner.cc @@ -37,6 +37,162 @@ namespace ge { namespace { +const uint32_t kWhileBodySubGraphIdx = 1; + +graphStatus ReverseBrushWhileBodySubGraph(const ConstNodePtr &node) { + GELOGD("Enter reverse brush while body subgraph process!"); + + auto sub_graph_body = NodeUtils::GetSubgraph(*node, kWhileBodySubGraphIdx); + if (sub_graph_body == nullptr) { + GELOGE(GRAPH_FAILED, "Get while body graph failed!"); + return GRAPH_FAILED; + } + + for (const auto &node_sub : sub_graph_body->GetAllNodes()) { + for (size_t i = 0; i < node_sub->GetAllInDataAnchorsSize(); i++) { + auto input_desc = node_sub->GetOpDesc()->MutableInputDesc(i); + (void)input_desc->SetUnknownDimNumShape(); + } + for (size_t i = 0; i < node_sub->GetAllOutDataAnchorsSize(); i++) { + auto output_desc = node_sub->GetOpDesc()->MutableOutputDesc(i); + (void)output_desc->SetUnknownDimNumShape(); + } + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdataOutputForMultiBatcch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + // check sub_graph shape. Get max for update. + for (size_t i = 0; i < ref_out_tensors.size(); ++i) { + if (ref_out_tensors[i].empty()) { + continue; + } + + int64_t max_size = 0; + size_t max_shape_index = 0; + auto &ref_out_tensor = ref_out_tensors[i].at(0); + const auto &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (size_t j = 0; j < ref_out_tensors[i].size(); ++j) { + auto &tensor = ref_out_tensors[i].at(j); + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGE(GRAPH_FAILED, "node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", + node->GetName().c_str(), i, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + return GRAPH_FAILED; + } + + int64_t size = 1; + for (auto dim : shape.GetDims()) { + if (INT64_MAX / dim < size) { + GELOGE(PARAM_INVALID, "The shape size overflow"); + return PARAM_INVALID; + } + size *= dim; + } + + if (size > max_size) { + max_size = size; + max_shape_index = j; + } + } + + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensors[i].at(max_shape_index)); + } + + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForBranch(const ConstNodePtr &node, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class branch op process"); + if (node->GetOpDesc()->HasAttr(ATTR_NAME_BATCH_NUM)) { + return UpdataOutputForMultiBatcch(node, ref_out_tensors); + } + + // check sub_graph shape.If not same ,do unknown shape process + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + ge::GeShape &ref_out_tensor_shape = ref_out_tensor.MutableShape(); + for (auto &tensor : ref_out_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype output", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims().size() != ref_out_tensor_shape.GetDims().size()) { + GELOGD("node is %s, i : %d, shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + ref_out_tensor_shape = GeShape(UNKNOWN_RANK); + break; + } + for (size_t j = 0; j < ref_out_tensor_shape.GetDims().size(); j++) { + if (ref_out_tensor_shape.GetDim(j) == shape.GetDim(j)) { + continue; + } + GELOGD("node is %s, i : %d, j: %d ,shape size: %lu, ref_out_tensor_shape size: %lu", node->GetName().c_str(), i, + j, shape.GetShapeSize(), ref_out_tensor_shape.GetShapeSize()); + (void)ref_out_tensor_shape.SetDim(j, UNKNOWN_DIM); + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + return GRAPH_SUCCESS; +} + +graphStatus UpdateParentNodeForWhile(const ConstNodePtr &node, std::vector> &ref_data_tensors, + std::vector> &ref_out_tensors) { + GELOGD("Enter update parent node shape for class while op process"); + if (ref_data_tensors.size() != ref_out_tensors.size()) { + GELOGE(GRAPH_FAILED, "while op [%s] input number[%zu] and output number[%zu] is not same!", node->GetName().c_str(), + ref_data_tensors.size(), ref_out_tensors.size()); + return GRAPH_FAILED; + } + for (size_t i = 0; i < ref_data_tensors.size(); i++) { + if (ref_out_tensors[i].size() != 1) { + GELOGE(GRAPH_FAILED, "while op, every output should only find one output tensor in all graph!"); + return GRAPH_FAILED; + } + } + bool is_need_reverse_brush = false; + // check input and output + for (size_t i = 0; i < ref_out_tensors.size(); i++) { + if (ref_out_tensors[i].empty()) { + continue; + } + auto ref_out_tensor = ref_out_tensors[i].at(0); + auto tmp_shape = ref_out_tensor.MutableShape(); + // ref_i's data and output tensor shape should be same + for (auto &tensor : ref_data_tensors[i]) { + if (ref_out_tensor.GetDataType() != tensor.GetDataType()) { + GELOGE(GRAPH_FAILED, "node[%s] does not support diff dtype or format output.", node->GetName().c_str()); + return GRAPH_FAILED; + } + auto shape = tensor.MutableShape(); + if (shape.GetDims() != tmp_shape.GetDims()) { + ref_out_tensor.SetUnknownDimNumShape(); + is_need_reverse_brush = true; + break; + } + } + (void)node->GetOpDesc()->UpdateOutputDesc(i, ref_out_tensor); + } + // reverse refresh while body shape + if (is_need_reverse_brush) { + return ReverseBrushWhileBodySubGraph(node); + } + return GRAPH_SUCCESS; +} + graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -66,11 +222,14 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { node->GetName().c_str()); return GRAPH_FAILED; } - if (!AttrUtils::GetInt(node_sub->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + if (!AttrUtils::GetInt(data_opdesc, ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { GE_LOGE("Invalid data node on the sub graph %s parent node %s, no ref-index attribute", name.c_str(), node->GetName().c_str()); return GRAPH_FAILED; } + if (data_opdesc->HasAttr(ATTR_MBATCH_ORIGIN_INPUT_DIMS)) { + continue; + } auto input_desc = op_desc->MutableInputDesc(ref_i); if (input_desc == nullptr) { GE_LOGE( @@ -98,6 +257,37 @@ graphStatus UpdateSubGraphDataNodes(const ConstNodePtr &node) { } return GRAPH_SUCCESS; } + +graphStatus FindSubgraphDataAndNetoutput(std::shared_ptr &sub_graph, NodePtr &netoutput, + const ConstNodePtr &node, + std::vector> &ref_data_tensors) { + auto sub_nodes = sub_graph->GetDirectNode(); + for (size_t i = sub_nodes.size(); i > 0; --i) { + auto sub_node = sub_nodes.at(i - 1); + if (sub_node->GetType() == NETOUTPUT) { + netoutput = sub_node; + } + if (sub_node->GetType() == DATA) { + if (sub_node->GetOpDesc() == nullptr) { + return GRAPH_FAILED; + } + + int ref_i; + if (!AttrUtils::GetInt(sub_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, ref_i)) { + GELOGE(GRAPH_FAILED, "subgraph data node[%s] has no parent node!", sub_node->GetName().c_str()); + return GRAPH_FAILED; + } + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllInDataAnchorsSize()) { + GELOGE(GRAPH_FAILED, "data node[%s]'s ref index[%d] is not in range [0, %zu)!", sub_node->GetName().c_str(), + ref_i, node->GetAllInDataAnchorsSize()); + return GRAPH_FAILED; + } + ref_data_tensors[ref_i].emplace_back(sub_node->GetOpDesc()->GetOutputDesc(0)); + } + } + return GRAPH_SUCCESS; +} + graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { auto op_desc = node->GetOpDesc(); auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); @@ -105,7 +295,10 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { return GRAPH_SUCCESS; } + std::vector> ref_data_tensors(node->GetAllInDataAnchorsSize()); + std::vector> ref_out_tensors(node->GetAllOutDataAnchorsSize()); auto root_graph = GraphUtils::FindRootGraph(node->GetOwnerComputeGraph()); + for (const auto &name : sub_graph_names) { if (name.empty()) { GELOGW("The node %s contains empty subgraph instance name", node->GetName().c_str()); @@ -117,13 +310,9 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { return GRAPH_FAILED; } NodePtr netoutput = nullptr; - auto sub_nodes = sub_graph->GetDirectNode(); - for (size_t i = sub_nodes.size(); i > 0; --i) { - auto sub_node = sub_nodes.at(i - 1); - if (sub_node->GetType() == NETOUTPUT) { - netoutput = sub_node; - break; - } + auto ret = FindSubgraphDataAndNetoutput(sub_graph, netoutput, node, ref_data_tensors); + if (ret != GRAPH_SUCCESS) { + return ret; } if (netoutput == nullptr) { GE_LOGE("No NetOutput node on sub graph %s, parent node %s", name.c_str(), node->GetName().c_str()); @@ -150,22 +339,23 @@ graphStatus UpdateParentNodeOutTensor(const ConstNodePtr &node) { continue; } GELOGI("Parent node index of edge desc is %d", ref_i); - auto output_desc = op_desc->MutableOutputDesc(static_cast(ref_i)); - if (output_desc == nullptr) { - GE_LOGE( - "The ref index(%d) on the input %d of netoutput %s on the sub graph %s " - "parent node %s are incompatible, outputs num %u", - ref_i, edge_anchor->GetIdx(), netoutput->GetName().c_str(), name.c_str(), node->GetName().c_str(), - node->GetAllOutDataAnchorsSize()); + if (ref_i < 0 || static_cast(ref_i) >= node->GetAllOutDataAnchorsSize()) { return GRAPH_FAILED; } - op_desc->UpdateOutputDesc(edge_anchor->GetIdx(), *edge_desc); + ref_out_tensors[ref_i].emplace_back(*edge_desc); } } - return GRAPH_SUCCESS; + + if (node->GetType() == WHILE) { + return UpdateParentNodeForWhile(node, ref_data_tensors, ref_out_tensors); + } + return UpdateParentNodeForBranch(node, ref_out_tensors); } } // namespace void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::string &phase) { + if (!IsLogEnable(GE, DLOG_DEBUG)) { + return; + } if (node == nullptr) { GELOGE(GRAPH_FAILED, "node is null"); return; @@ -185,6 +375,18 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; } str += input_desc_str; + + input_desc_str = "input origin shape: "; + for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { + input_desc_str += "["; + for (int64_t dim : input_desc->GetOriginShape().GetDims()) { + input_desc_str += std::to_string(dim) + " "; + } + input_desc_str += "]"; + input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + + TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; + } + str += input_desc_str; } if (op_desc->GetAllOutputsDescSize() != 0) { @@ -202,6 +404,21 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; } str += output_desc_str; + + output_desc_str = "output origin shape: "; + for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { + if (output_desc == nullptr) { + continue; + } + output_desc_str += "["; + for (int64_t dim : output_desc->GetOriginShape().GetDims()) { + output_desc_str += std::to_string(dim) + " "; + } + output_desc_str += "]"; + output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + + TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; + } + str += output_desc_str; } GELOGD("Shape dump [%s], Node name: [%s]. %s", phase.c_str(), node->GetName().c_str(), str.c_str()); } @@ -222,7 +439,6 @@ graphStatus ShapeRefiner::InferShapeAndType(const ConstNodePtr &node, Operator & return ret; } } - // Get infer func and execute ret = op_desc->CallInferFunc(op); if (ret == GRAPH_PARAM_INVALID) { @@ -329,6 +545,9 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map context_map; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferShapeAndType(const NodePtr &node) { return InferShapeAndType(node, true); } @@ -339,19 +558,20 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh GELOGE(GRAPH_FAILED, "Verifying %s failed.", node->GetName().c_str()); return GRAPH_FAILED; } + PrintInOutTensorShape(node, "before_infershape"); + Operator op = OpDescUtils::CreateOperatorFromNode(node); - auto inference_context = CreateInferenceContext(context_map, node); - if (inference_context == nullptr) { - GELOGE(GRAPH_FAILED, "inference context is null"); - return GRAPH_FAILED; + bool is_unknown_graph = node->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + if (!is_unknown_graph) { + auto inference_context = CreateInferenceContext(context_map, node); + if (inference_context == nullptr) { + GELOGE(GRAPH_FAILED, "inference context is null"); + return GRAPH_FAILED; + } + GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); + op.SetInferenceContext(inference_context); } - GELOGD("create context for node:%s, marks %zu", node->GetName().c_str(), inference_context->GetMarks().size()); - - PrintInOutTensorShape(node, "before_infershape"); - - Operator op = OpDescUtils::CreateOperatorFromNode(node); - op.SetInferenceContext(inference_context); graphStatus status = InferShapeAndType(node, op, before_subgraph); if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { (void)ge::NodeUtils::UpdatePeerNodeInputDesc(node); @@ -359,16 +579,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh GELOGE(GRAPH_FAILED, "%s call infer function failed.", node->GetName().c_str()); return GRAPH_FAILED; } - - auto ctx_after_infer = op.GetInferenceContext(); - if (ctx_after_infer != nullptr) { - GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); - if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { - GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); - (void)context_map.emplace(node, ctx_after_infer); + if (!is_unknown_graph) { + auto ctx_after_infer = op.GetInferenceContext(); + if (ctx_after_infer != nullptr) { + GELOGD("[%s] after infershape. mark:%zu", node->GetName().c_str(), ctx_after_infer->GetMarks().size()); + if (!ctx_after_infer->GetOutputHandleShapesAndTypes().empty() || !ctx_after_infer->GetMarks().empty()) { + GELOGD("[%s] set inference context after. mark:%zu", node->GetName().c_str(), + ctx_after_infer->GetMarks().size()); + (void)context_map.emplace(node, ctx_after_infer); + } } } - PrintInOutTensorShape(node, "after_infershape"); return GRAPH_SUCCESS; diff --git a/src/common/graph/stub/Makefile b/src/common/graph/stub/Makefile deleted file mode 100644 index 832adcd5..00000000 --- a/src/common/graph/stub/Makefile +++ /dev/null @@ -1,6 +0,0 @@ -inc_path := $(shell pwd)/inc/external/ -out_path := $(shell pwd)/out/atc/lib64/stub/ -stub_path := $(shell pwd)/common/graph/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/src/common/graph/stub/gen_stubapi.py b/src/common/graph/stub/gen_stubapi.py deleted file mode 100644 index 6185c479..00000000 --- a/src/common/graph/stub/gen_stubapi.py +++ /dev/null @@ -1,573 +0,0 @@ -import os -import re -import sys -import logging - -logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', - level=logging.INFO) - -""" - this attr is used for symbol table visible -""" -GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' - -""" - generate stub func body by return type -""" -RETURN_STATEMENTS = { - 'graphStatus': ' return GRAPH_SUCCESS;', - 'Status': ' return SUCCESS;', - 'Graph': ' return Graph();', - 'Graph&': ' return *this;', - 'Format': ' return Format();', - 'Format&': ' return *this;', - 'Shape': ' return Shape();', - 'Shape&': ' return *this;', - 'TensorDesc': ' return TensorDesc();', - 'TensorDesc&': ' return *this;', - 'Tensor': ' return Tensor();', - 'Tensor&': ' return *this;', - 'Operator': ' return Operator();', - 'Operator&': ' return *this;', - 'Ptr': ' return nullptr;', - 'std::string': ' return "";', - 'std::string&': ' return "";', - 'string': ' return "";', - 'int': ' return 0;', - 'DataType': ' return DT_FLOAT;', - 'InferenceContextPtr': ' return nullptr;', - 'SubgraphBuilder': ' return nullptr;', - 'OperatorImplPtr': ' return nullptr;', - 'OutHandler': ' return nullptr;', - 'std::vector': ' return {};', - 'std::vector': ' return {};', - 'std::map': ' return {};', - 'uint32_t': ' return 0;', - 'int64_t': ' return 0;', - 'uint64_t': ' return 0;', - 'size_t': ' return 0;', - 'float': ' return 0.0f;', - 'bool': ' return false;', -} - -""" - max code len per line in hua_wei software programming specifications -""" -max_code_len_per_line = 100 - -""" - white_list_for_debug, include_dir_key_words is to - determines which header files to generate cc files from - when DEBUG on -""" -white_list_for_debug = ["operator.h", "tensor.h", - "graph.h", "operator_factory.h", - "ge_ir_build.h"] -include_dir_key_words = ["ge", "graph"] -DEBUG = True - - -def need_generate_func(func_line): - """ - :param func_line: - :return: - """ - if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ - or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): - return False - return True - - -def file_endswith_white_list_suffix(file): - """ - :param file: - :return: - """ - if DEBUG: - for suffix in white_list_for_debug: - if file.endswith(suffix): - return True - return False - else: - return True - - -""" - belows are patterns used for analyse .h file -""" -# pattern function -pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after -([a-zA-Z~_] # void int likely -.* -[)] #we find ) -(?!.*{) # we do not want the case int abc() const { return 1;} -.*) -(;.*) #we want to find ; and after for we will replace these later -\n$ -""", re.VERBOSE | re.MULTILINE | re.DOTALL) - -# pattern comment -pattern_comment = re.compile(r'^\s*//') -pattern_comment_2_start = re.compile(r'^\s*/[*]') -pattern_comment_2_end = re.compile(r'[*]/\s*$') -# pattern define -pattern_define = re.compile(r'^\s*#define') -pattern_define_return = re.compile(r'\\\s*$') -# blank line -pattern_blank_line = re.compile(r'^\s*$') -# virtual,explicit,friend,static -pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') -# lead space -pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') -# functions will have patterns such as func ( or func( -# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist -# format like :"operator = ()" -pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') -# template -pattern_template = re.compile(r'^\s*template') -pattern_template_end = re.compile(r'>\s*$') -# namespace -pattern_namespace = re.compile(r'namespace.*{') -# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with -pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: - line, func_name = self.handle_class_member_func(line, template_string) - # Normal functions - else: - line, func_name = self.handle_normal_func(line, template_string) - - need_generate = need_generate_func(line) - # func body - line += self.implement_function(line) - # comment - line = self.gen_comment(start_i) + line - # write to out file - self.write_func_content(line, func_name, need_generate) - # next loop - self.line_index += 1 - - logging.info('Added %s functions', len(self.func_list_exist)) - logging.info('Successfully converted,please see ' + self.output_file) - - def handle_func1(self, line): - """ - :param line: - :return: - """ - find1 = re.search('[(]', line) - if not find1: - self.line_index += 1 - return "continue", line, None - find2 = re.search('[)]', line) - start_i = self.line_index - space_match = pattern_leading_space.search(line) - # deal with - # int abc(int a, - # int b) - if find1 and (not find2): - self.line_index += 1 - line2 = self.input_content[self.line_index] - if space_match: - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): - self.line_index += 1 - line2 = self.input_content[self.line_index] - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - - match_start = pattern_start.search(self.input_content[self.line_index]) - match_end = pattern_end.search(self.input_content[self.line_index]) - if match_start: # like ) { or ) {} int the last line - if not match_end: - self.stack.append('normal_now') - ii = start_i - while ii <= self.line_index: - ii += 1 - self.line_index += 1 - return "continue", line, start_i - logging.info("line[%s]", line) - # ' int abc();'->'int abc()' - (line, match) = pattern_func.subn(r'\2\n', line) - logging.info("line[%s]", line) - # deal with case: - # 'int \n abc(int a, int b)' - if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): - line = self.input_content[start_i - 1] + line - line = line.lstrip() - if not match: - self.line_index += 1 - return "continue", line, start_i - return "pass", line, start_i - - def handle_stack(self, match_start): - """ - :param match_start: - :return: - """ - line = self.input_content[self.line_index] - match_end = pattern_end.search(line) - if match_start: - self.stack.append('normal_now') - if match_end: - top_status = self.stack.pop() - if top_status == 'namespace_now': - self.output_fd.write(line + '\n') - elif top_status == 'class_now': - self.stack_class.pop() - self.stack_template.pop() - if match_start or match_end: - self.line_index += 1 - return "continue" - - if len(self.stack) > 0 and self.stack[-1] == 'normal_now': - self.line_index += 1 - return "continue" - return "pass" - - def handle_class(self, template_string, line, match_start, match_class): - """ - :param template_string: - :param line: - :param match_start: - :param match_class: - :return: - """ - if match_class: # we face a class - self.stack_template.append(template_string) - self.stack.append('class_now') - class_name = match_class.group(3) - - # class template specializations: class A > - if '<' in class_name: - k = line.index('<') - fit = 1 - for ii in range(k + 1, len(line)): - if line[ii] == '<': - fit += 1 - if line[ii] == '>': - fit -= 1 - if fit == 0: - break - class_name += line[k + 1:ii + 1] - logging.info('class_name[%s]', class_name) - self.stack_class.append(class_name) - while not match_start: - self.line_index += 1 - line = self.input_content[self.line_index] - match_start = pattern_start.search(line) - self.line_index += 1 - return "continue" - return "pass" - - def handle_template(self): - line = self.input_content[self.line_index] - match_template = pattern_template.search(line) - template_string = '' - if match_template: - match_template_end = pattern_template_end.search(line) - template_string = line - while not match_template_end: - self.line_index += 1 - line = self.input_content[self.line_index] - template_string += line - match_template_end = pattern_template_end.search(line) - self.line_index += 1 - return template_string - - def handle_namespace(self): - line = self.input_content[self.line_index] - match_namespace = pattern_namespace.search(line) - if match_namespace: # we face namespace - self.output_fd.write(line + '\n') - self.stack.append('namespace_now') - self.line_index += 1 - - def handle_normal_func(self, line, template_string): - template_line = '' - self.stack_template.append(template_string) - if self.stack_template[-1] != '': - template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) - # change '< class T = a, class U = A(3)>' to '' - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = re.sub(r'\s*=.*,', ',', line) - line = re.sub(r'\s*=.*\)', ')', line) - line = template_line + line - self.stack_template.pop() - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def handle_class_member_func(self, line, template_string): - template_line = '' - x = '' - if template_string != '': - template_string = re.sub(r'\s*template', 'template', template_string) - template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) - template_string = re.sub(r'\s*=.*,', ',', template_string) - template_string = re.sub(r'\s*=.*', '', template_string) - if self.stack_template[-1] != '': - if not (re.search(r'<\s*>', stack_template[-1])): - template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) - if not (re.search(r'<.*>', self.stack_class[-1])): - # for x we get like template -> - x = re.sub(r'template\s*<', '<', template_line) # remove template -> - x = re.sub(r'\n', '', x) - x = re.sub(r'\s*=.*,', ',', x) - x = re.sub(r'\s*=.*\>', '>', x) - x = x.rstrip() # remove \n - x = re.sub(r'(class|typename)\s+|(|\s*class)', '', - x) # remove class,typename -> - x = re.sub(r'<\s+', '<', x) - x = re.sub(r'\s+>', '>', x) - x = re.sub(r'\s+,', ',', x) - x = re.sub(r',\s+', ', ', x) - line = re.sub(r'\s*=\s+0', '', line) - line = re.sub(r'\s*=\s+.*,', ',', line) - line = re.sub(r'\s*=\s+.*\)', ')', line) - logging.info("x[%s]\nline[%s]", x, line) - # if the function is long, void ABC::foo() - # breaks into two lines void ABC::\n foo() - temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) - if len(temp_line) > max_code_len_per_line: - line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) - else: - line = temp_line - logging.info("line[%s]", line) - # add template as the above if there is one - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = template_line + template_string + line - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def write_func_content(self, content, func_name, need_generate): - if not (func_name in self.func_list_exist) and need_generate: - self.output_fd.write(content) - self.func_list_exist.append(func_name) - logging.info('add func:[%s]', func_name) - - def gen_comment(self, start_i): - comment_line = '' - # Function comments are on top of function declarations, copy them over - k = start_i - 1 # one line before this func start - if pattern_template.search(self.input_content[k]): - k -= 1 - if pattern_comment_2_end.search(self.input_content[k]): - comment_line = self.input_content[k].lstrip() - while not pattern_comment_2_start.search(self.input_content[k]): - k -= 1 - comment_line = self.input_content[k].lstrip() + comment_line - else: - for j in range(k, 0, -1): - c_line = self.input_content[j] - if pattern_comment.search(c_line): - c_line = re.sub(r'\s*//', '//', c_line) - comment_line = c_line + comment_line - else: - break - return comment_line - - @staticmethod - def implement_function(func): - function_def = '' - function_def += '{\n' - - all_items = func.split() - start = 0 - return_type = all_items[start] - if return_type == "const": - start += 1 - return_type = all_items[start] - if return_type.startswith(('std::map', 'std::set', 'std::vector')): - return_type = "std::map" - if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): - return_type = "Ptr" - if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): - return_type += "&" - if RETURN_STATEMENTS.__contains__(return_type): - function_def += RETURN_STATEMENTS[return_type] - else: - logging.warning("Unhandled return type[%s]", return_type) - - function_def += '\n' - function_def += '}\n' - function_def += '\n' - return function_def - - -def collect_header_files(path): - """ - :param path: - :return: - """ - header_files = [] - shared_includes_content = [] - for root, dirs, files in os.walk(path): - files.sort() - for file in files: - if file.find("git") >= 0: - continue - if not file.endswith('.h'): - continue - file_path = os.path.join(root, file) - file_path = file_path.replace('\\', '/') - header_files.append(file_path) - include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) - shared_includes_content.append(include_str) - return header_files, shared_includes_content - - -def generate_stub_file(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - target_header_files, shared_includes_content = collect_header_files(inc_dir) - for header_file in target_header_files: - if not file_endswith_white_list_suffix(header_file): - continue - cc_file = re.sub('.h*$', '.cc', header_file) - h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) - h_2_cc.h2cc() - - -def gen_code(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - if not inc_dir.endswith('/'): - inc_dir += '/' - if not out_cc_dir.endswith('/'): - out_cc_dir += '/' - for include_dir_key_word in include_dir_key_words: - generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) - - -if __name__ == '__main__': - inc_dir = sys.argv[1] - out_cc_dir = sys.argv[2] - gen_code(inc_dir, out_cc_dir) diff --git a/src/common/graph/tensor.cc b/src/common/graph/tensor.cc index 0d511645..1f30c876 100644 --- a/src/common/graph/tensor.cc +++ b/src/common/graph/tensor.cc @@ -178,16 +178,18 @@ int64_t Shape::GetShapeSize() const { return 0; } -TensorDesc::TensorDesc() { impl = ComGraphMakeShared(); } +TensorDesc::TensorDesc() { + impl = ComGraphMakeShared(); // lint !e665 +} TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { - impl = ComGraphMakeShared(shape, format, dt); + impl = ComGraphMakeShared(shape, format, dt); // lint !e665 SetRealDimCnt(shape.GetDimNum()); } TensorDesc::TensorDesc(const TensorDesc &desc) { // Copy - impl = ComGraphMakeShared(); + impl = ComGraphMakeShared(); // lint !e665 if (desc.impl != nullptr && impl != nullptr) { *impl = *desc.impl; } @@ -358,7 +360,9 @@ void TensorDesc::SetName(const std::string &name) { Tensor::Tensor() { impl = ComGraphMakeShared(); } -Tensor::Tensor(const TensorDesc &tensor_desc) { impl = ComGraphMakeShared(tensor_desc); } +Tensor::Tensor(const TensorDesc &tensor_desc) { + impl = ComGraphMakeShared(tensor_desc); // lint !e665 +} Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) { uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); @@ -380,7 +384,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector &data) } } } - impl = ComGraphMakeShared(tensor_desc, data); + impl = ComGraphMakeShared(tensor_desc, data); // lint !e665 } Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { @@ -402,7 +406,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) } } - impl = ComGraphMakeShared(tensor_desc, data, size); + impl = ComGraphMakeShared(tensor_desc, data, size); // lint !e665 } Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { @@ -425,7 +429,7 @@ Tensor::Tensor(TensorDesc &&tensor_desc, std::vector &&data) { } } } - impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); + impl = ComGraphMakeShared(std::move(tensor_desc), std::move(data)); // lint !e665 } TensorDesc Tensor::GetTensorDesc() const { @@ -639,7 +643,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { GeTensorPtr ge_tensor; if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor.Clone()); + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor.Clone()); // lint !e665 } return ge_tensor; } @@ -655,7 +659,7 @@ Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { GeTensorPtr ge_tensor; if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 } return ge_tensor; } @@ -663,7 +667,7 @@ ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { GeTensorPtr ge_tensor; if (tensor.impl != nullptr) { - ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); + ge_tensor = ComGraphMakeShared(tensor.impl->ge_tensor); // lint !e665 } return ge_tensor; } diff --git a/src/common/graph/utils/graph_utils.cc b/src/common/graph/utils/graph_utils.cc index ca2ebcdc..19c28c63 100644 --- a/src/common/graph/utils/graph_utils.cc +++ b/src/common/graph/utils/graph_utils.cc @@ -38,6 +38,7 @@ #include "utils/ge_ir_utils.h" #include "utils/node_utils.h" #include "debug/ge_op_types.h" +#include "external/ge/ge_api_types.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" @@ -410,8 +411,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertTra /// @return graphStatus /// GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus -GraphUtils::InsertNodeBefore(const OutDataAnchorPtr &src, const std::vector &dsts, - const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { +GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { GE_CHECK_NOTNULL(src); GE_CHECK_NOTNULL(insert_node); @@ -570,7 +571,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons static int max_dumpfile_num = 0; if (max_dumpfile_num == 0) { string opt = "0"; - (void)GetContext().GetOption("ge.maxDumpFileNum", opt); + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); } if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { @@ -670,7 +671,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText if (maxDumpFileSize == 0) { string opt = "0"; // Can not check return value - (void)GetContext().GetOption("ge.maxDumpFileSize", opt); + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); maxDumpFileSize = atol(opt.c_str()); } if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { @@ -740,7 +741,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn static int max_dumpfile_num = 0; if (max_dumpfile_num == 0) { string opt = "0"; - (void)GetContext().GetOption("ge.maxDumpFileNum", opt); + (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); } if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { @@ -920,7 +921,7 @@ graphStatus RelinkDataIO(const NodePtr &node, const std::vector &io_map, In InNodesToOut GetFullConnectIONodes(const NodePtr &node) { InNodesToOut in_nodes_to_out; if (node == nullptr) { - GELOGE(GRAPH_FAILED, "Node is nullptr,node is %s", node->GetName().c_str()); + GELOGE(GRAPH_FAILED, "Node is nullptr"); return in_nodes_to_out; } auto in_nodes_list = node->GetInNodes(); @@ -1308,6 +1309,36 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::MoveOutCt return GRAPH_SUCCESS; } +/// +/// Copy all in-data edges from `src_node` to `dst_node`. +/// @param src_node +/// @param dst_node +/// @return +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::CopyInDataEdges(const NodePtr &src_node, + NodePtr &dst_node) { + if ((src_node == nullptr) || (dst_node == nullptr)) { + GELOGE(GRAPH_FAILED, "Parameter is nullptr"); + return GRAPH_PARAM_INVALID; + } + auto src_data_in_nodes = src_node->GetInDataNodes(); + if (src_data_in_nodes.empty()) { + return GRAPH_SUCCESS; + } + for (const auto &in_data_anchor : src_node->GetAllInDataAnchors()) { + auto input_desc = src_node->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); + auto ret = + GraphUtils::AddEdge(in_data_anchor->GetPeerOutAnchor(), dst_node->GetInDataAnchor(in_data_anchor->GetIdx())); + if (ret != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Failed to add data edge from %s to %s when copy in data edge from %s to %s", + in_data_anchor->GetPeerOutAnchor()->GetOwnerNode()->GetName().c_str(), dst_node->GetName().c_str(), + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + } + return GRAPH_SUCCESS; +} + GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::AppendInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { if (graph->AddInputNode(node) == nullptr) { @@ -1328,6 +1359,153 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR return result; } +/// +/// Make a copy of ComputeGraph. +/// @param graph: original graph. +/// @param prefix: node name prefix of new graph. +/// @return ComputeGraphPtr +/// +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr +GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, std::vector &input_nodes, + std::vector &output_nodes) { + GE_CHK_BOOL_EXEC(graph != nullptr, return nullptr, "Original graph is null"); + ComputeGraphPtr new_graph = ComGraphMakeShared(graph->GetName()); + GE_CHK_BOOL_EXEC(new_graph != nullptr, return nullptr, "Create new graph failed"); + + std::unordered_map all_new_nodes; + for (const auto &n : graph->GetDirectNode()) { + OpDescPtr op_desc = AttrUtils::CopyOpDesc(n->GetOpDesc()); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return nullptr, "Create new node failed"); + + if (CopyTensorAttrs(op_desc, n) != GRAPH_SUCCESS) { + return nullptr; + } + + op_desc->SetName(prefix + n->GetName()); + NodePtr node = new_graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); + all_new_nodes[node->GetName()] = node; + + if (node->GetType() == DATA) { + input_nodes.emplace_back(node); + } else if (node->GetType() == NETOUTPUT) { + output_nodes.emplace_back(node); + } + } + + for (const auto &n : graph->GetDirectNode()) { + if (RelinkGraphEdges(n, prefix, all_new_nodes) != GRAPH_SUCCESS) { + return nullptr; + } + } + + return new_graph; +} + +/// +/// Copy tensor attribute to new node. +/// @param [in] dst_node: cloned node. +/// @param [in] src_node: original node. +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::CopyTensorAttrs(const OpDescPtr &dst_desc, const NodePtr &src_node) { + if (dst_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Input param dst node not valid"); + return GRAPH_FAILED; + } + if (src_node == nullptr || src_node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Input param src node not valid"); + return GRAPH_FAILED; + } + + const auto &src_desc = src_node->GetOpDesc(); + dst_desc->CopyAttrsFrom(*src_desc); + + for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { + auto input_desc = dst_desc->MutableInputDesc(i); + if (input_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Param dst node not valid"); + return GRAPH_FAILED; + } + input_desc->CopyAttrsFrom(src_desc->GetInputDesc(i)); + } + + for (uint32_t i = 0; i < src_node->GetAllOutDataAnchorsSize(); ++i) { + auto output_desc = dst_desc->MutableOutputDesc(i); + if (output_desc == nullptr) { + GELOGE(GRAPH_FAILED, "Param dst node not valid"); + return GRAPH_FAILED; + } + output_desc->CopyAttrsFrom(src_desc->GetOutputDesc(i)); + } + + return GRAPH_SUCCESS; +} + +/// +/// Relink all edges for cloned ComputeGraph. +/// @param [in] node: original node. +/// @param [in] prefix: node name prefix of new node. +/// @param [in] all_nodes: all nodes in new graph. +/// @return success: GRAPH_SUCESS +/// +graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &prefix, + const std::unordered_map &all_nodes) { + if (node == nullptr || node->GetOpDesc() == nullptr) { + GELOGE(GRAPH_FAILED, "Input node not valid"); + return GRAPH_FAILED; + } + + auto it = all_nodes.find(prefix + node->GetName()); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_node = it->second; + + for (const auto &in_anchor : node->GetAllInDataAnchors()) { + GE_CHK_BOOL_EXEC(in_anchor != nullptr, return GRAPH_FAILED, "In data anchor is null"); + const auto &out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGW("Peer out anchor is null: %s", node->GetName().c_str()); + continue; + } + GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); + + it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_out_node = it->second; + + auto rslt = + GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInAnchor(in_anchor->GetIdx())); + GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", + new_out_node->GetName().c_str(), new_node->GetName().c_str()); + } + + if (node->GetInControlAnchor() != nullptr) { + for (const auto &out_anchor : node->GetInControlAnchor()->GetPeerAnchors()) { + GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); + GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); + + it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); + if (it == all_nodes.end()) { + GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); + return GRAPH_FAILED; + } + const auto &new_out_node = it->second; + + auto rslt = GraphUtils::AddEdge(new_out_node->GetOutAnchor(out_anchor->GetIdx()), new_node->GetInControlAnchor()); + GE_CHK_BOOL_EXEC(rslt == GRAPH_SUCCESS, return GRAPH_FAILED, "link failed[%s to %s]", + new_out_node->GetName().c_str(), new_node->GetName().c_str()); + } + } + + return GRAPH_SUCCESS; +} + /// /// Get reference-mapping of all data_anchors in graph /// @param [in] graph @@ -1339,7 +1517,7 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, std::map> &symbol_to_anchors, std::map &anchor_to_symbol) { GE_CHECK_NOTNULL(graph); - for (auto &node : graph->GetAllNodes()) { + for (const auto &node : graph->GetAllNodes()) { // in_data_anchor if (HandleInAnchorMapping(node, symbol_to_anchors, anchor_to_symbol) != GRAPH_SUCCESS) { GE_LOGE("Find ref_mapping for in_data_anchors of node %s failed.", node->GetName().c_str()); @@ -1396,16 +1574,16 @@ graphStatus GraphUtils::HandleInAnchorMapping(const NodePtr &node, return HandleSubgraphInput(node, symbol_to_anchors, anchor_to_symbol); } - std::string type = node->GetType(); + const std::string &type = node->GetType(); if ((type == MERGE) || (type == STREAMMERGE)) { return HandleMergeInput(node, symbol_to_anchors, anchor_to_symbol); } - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { NodeIndexIO cur_node_info(node, in_data_anchor->GetIdx(), kIn); OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { - std::string symbol = cur_node_info.ToString(); + const std::string &symbol = cur_node_info.ToString(); GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); symbol_to_anchors[symbol] = {cur_node_info}; anchor_to_symbol[symbol] = symbol; @@ -1432,7 +1610,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, std::map> &symbol_to_anchors, std::map &anchor_to_symbol) { GE_CHECK_NOTNULL(node); - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { NodeIndexIO cur_node_info(node, out_data_anchor->GetIdx(), kOut); if (anchor_to_symbol.find(cur_node_info.ToString()) != anchor_to_symbol.end()) { continue; @@ -1446,7 +1624,7 @@ graphStatus GraphUtils::HandleOutAnchorMapping(const NodePtr &node, return GRAPH_FAILED; } } else { - std::string symbol = cur_node_info.ToString(); + const std::string &symbol = cur_node_info.ToString(); GELOGD("Add anchor %s, symbol %s.", cur_node_info.ToString().c_str(), symbol.c_str()); symbol_to_anchors.emplace(std::make_pair(symbol, std::list{cur_node_info})); anchor_to_symbol.emplace(std::make_pair(symbol, symbol)); @@ -1506,7 +1684,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, GE_CHECK_NOTNULL(node); std::vector exist_node_infos; std::vector cur_node_infos; - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { std::string next_name; @@ -1529,10 +1707,10 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, size_t anchor_nums = 0; NodeIndexIO max_node_index_io(nullptr, 0, kOut); - for (auto &temp_node_info : exist_node_infos) { + for (const auto &temp_node_info : exist_node_infos) { auto iter1 = anchor_to_symbol.find(temp_node_info.ToString()); if (iter1 != anchor_to_symbol.end()) { - std::string temp_symbol = iter1->second; + const std::string &temp_symbol = iter1->second; auto iter2 = symbol_to_anchors.find(temp_symbol); if (iter2 != symbol_to_anchors.end()) { if (iter2->second.size() > anchor_nums) { @@ -1544,7 +1722,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, } std::string symbol; - for (auto &temp_node_info : exist_node_infos) { + for (const auto &temp_node_info : exist_node_infos) { if ((UnionSymbolMapping(max_node_index_io, temp_node_info, symbol_to_anchors, anchor_to_symbol, symbol) != GRAPH_SUCCESS) || symbol.empty()) { @@ -1556,7 +1734,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, auto iter = symbol_to_anchors.find(symbol); if (iter != symbol_to_anchors.end()) { - for (auto &temp_node_info : cur_node_infos) { + for (const auto &temp_node_info : cur_node_infos) { GELOGD("Add anchor %s, symbol %s.", temp_node_info.ToString().c_str(), symbol.c_str()); iter->second.emplace_back(temp_node_info); anchor_to_symbol.emplace(std::make_pair(temp_node_info.ToString(), symbol)); @@ -1584,7 +1762,7 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - for (auto &in_data_anchor : node->GetAllInDataAnchors()) { + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); @@ -1627,8 +1805,8 @@ graphStatus GraphUtils::HandleSubgraphOutput(const NodePtr &node, graphStatus GraphUtils::UnionSymbolMapping(const NodeIndexIO &exist_node_info1, const NodeIndexIO &exist_node_info2, std::map> &symbol_to_anchors, std::map &anchor_to_symbol, std::string &symbol) { - std::string symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; - std::string symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; + const std::string &symbol1 = anchor_to_symbol[exist_node_info1.ToString()]; + const std::string &symbol2 = anchor_to_symbol[exist_node_info2.ToString()]; if (symbol1 == symbol2) { symbol = symbol1; GELOGI("no need to union."); @@ -1684,7 +1862,7 @@ graphStatus GraphUtils::UpdateRefMapping(const NodeIndexIO &cur_node_info, const return GRAPH_FAILED; } - std::string symbol = iter1->second; + const std::string &symbol = iter1->second; auto iter2 = symbol_to_anchors.find(symbol); if (iter2 == symbol_to_anchors.end()) { GE_LOGE("symbol %s not found.", symbol.c_str()); @@ -1712,7 +1890,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t // pass-through op NodePtr node = out_data_anchor->GetOwnerNode(); - std::string type = node->GetType(); + const std::string &type = node->GetType(); const std::set pass_through_set = {NETOUTPUT, WHILE, _WHILE, STATELESSWHILE}; if ((pass_through_set.count(type) > 0) || (NodeUtils::IsSubgraphInput(node))) { reuse_in_index = output_index; @@ -1755,7 +1933,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t uint32_t reuse_input_index = 0; if (TensorUtils::GetReuseInputIndex(*output_op_desc, reuse_input_index) == GRAPH_SUCCESS) { reuse_in_index = static_cast(reuse_input_index); - GELOGI("ReuseInput name[%s] output[%u] reuse input[%d].", op_desc->GetName().c_str(), output_index, + GELOGI("ReuseInput name[%s] output[%d] reuse input[%d].", op_desc->GetName().c_str(), output_index, reuse_in_index); return true; } @@ -2297,7 +2475,7 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & return; } - std::string name = node->GetName() + "_RetVal"; + std::string name = node->GetName() + "_RetVal_" + std::to_string(index); OpDescPtr ret_val_desc = shared_ptr(new (std::nothrow) OpDesc(name, FRAMEWORKOP)); if (ret_val_desc == nullptr) { error_code = GRAPH_FAILED; diff --git a/src/common/graph/utils/node_utils.cc b/src/common/graph/utils/node_utils.cc index e4fb8b82..35a842e5 100644 --- a/src/common/graph/utils/node_utils.cc +++ b/src/common/graph/utils/node_utils.cc @@ -295,16 +295,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer if (op_desc == nullptr) { return GRAPH_FAILED; } + bool is_unknown_graph = node_ptr->GetOwnerComputeGraph()->GetGraphUnknownFlag(); + if (is_unknown_graph) { + return GRAPH_SUCCESS; + } for (const auto &out_anchor : node_ptr->GetAllOutDataAnchors()) { - GeTensorDesc output_tensor = op_desc->GetOutputDesc(out_anchor->GetIdx()); - ge::TensorUtils::SetRealDimCnt(output_tensor, static_cast(output_tensor.GetShape().GetDims().size())); - output_tensor.SetOriginShape(output_tensor.GetShape()); - output_tensor.SetOriginDataType(output_tensor.GetDataType()); + auto output_tensor = op_desc->MutableOutputDesc(out_anchor->GetIdx()); + ge::TensorUtils::SetRealDimCnt(*output_tensor, static_cast(output_tensor->GetShape().GetDims().size())); + output_tensor->SetOriginShape(output_tensor->GetShape()); + output_tensor->SetOriginDataType(output_tensor->GetDataType()); + GELOGD("node name is %s, origin shape is %ld, origin format is %s, origin data type is %s", - node_ptr->GetName().c_str(), output_tensor.GetOriginShape().GetShapeSize(), - TypeUtils::FormatToSerialString(output_tensor.GetOriginFormat()).c_str(), - TypeUtils::DataTypeToSerialString(output_tensor.GetOriginDataType()).c_str()); - (void)op_desc->UpdateOutputDesc(out_anchor->GetIdx(), output_tensor); + node_ptr->GetName().c_str(), output_tensor->GetOriginShape().GetShapeSize(), + TypeUtils::FormatToSerialString(output_tensor->GetOriginFormat()).c_str(), + TypeUtils::DataTypeToSerialString(output_tensor->GetOriginDataType()).c_str()); + for (const auto &peer_anchor : out_anchor->GetPeerInDataAnchors()) { if (peer_anchor->GetOwnerNode()->GetOpDesc() == nullptr) { GELOGE(GRAPH_FAILED, "peer_anchor opdesc is null"); @@ -316,17 +321,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer continue; } GELOGI("Peer input opdesc name is %s, need to flush: shape size is %zu, datatype is %d, original datatype is %d", - peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor.GetShape().GetDimNum(), - output_tensor.GetDataType(), output_tensor.GetOriginDataType()); - peer_input_desc->SetShape(output_tensor.GetShape()); - peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); - peer_input_desc->SetDataType(output_tensor.GetDataType()); - peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); + peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), output_tensor->GetShape().GetDimNum(), + output_tensor->GetDataType(), output_tensor->GetOriginDataType()); + peer_input_desc->SetOriginShape(output_tensor->GetOriginShape()); + peer_input_desc->SetShape(output_tensor->GetShape()); + peer_input_desc->SetDataType(output_tensor->GetDataType()); + peer_input_desc->SetOriginDataType(output_tensor->GetOriginDataType()); std::vector> shape_range; - (void)output_tensor.GetShapeRange(shape_range); + (void)output_tensor->GetShapeRange(shape_range); peer_input_desc->SetShapeRange(shape_range); ge::TensorUtils::SetRealDimCnt(*peer_input_desc, - static_cast(output_tensor.GetShape().GetDims().size())); + static_cast(output_tensor->GetShape().GetDims().size())); GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", peer_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_input_desc->GetShape().GetDimNum(), peer_input_desc->GetDataType(), peer_input_desc->GetOriginDataType()); @@ -334,6 +339,50 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer } return GRAPH_SUCCESS; } + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInputAnchor(const NodePtr &node, + uint32_t index) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); + return GRAPH_FAILED; + } + + GeTensorDesc data_desc(GeShape(), FORMAT_ND, DT_FLOAT); + OpDescPtr op_desc = node->op_; + for (size_t i = op_desc->GetInputsSize(); i < index; ++i) { + if (op_desc->AddInputDesc(data_desc) != GRAPH_SUCCESS) { + GELOGE(GRAPH_FAILED, "Add input desc failed"); + return GRAPH_FAILED; + } + + auto anchor = ComGraphMakeShared(node, i); + if (anchor == nullptr) { + GELOGE(GRAPH_FAILED, "Current in_data_anchor is null, malloc shared_ptr failed."); + return GRAPH_FAILED; + } + node->in_data_anchors_.push_back(anchor); + } + + return GRAPH_SUCCESS; +} + +GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::RemoveInputAnchor(const NodePtr &node, + uint32_t index) { + if (node == nullptr) { + GELOGE(GRAPH_FAILED, "Nodeptr is nullptr"); + return GRAPH_FAILED; + } + + OpDescPtr op_desc = node->op_; + op_desc->RemoveInputDesc(index); + + while (node->in_data_anchors_.size() > index) { + node->in_data_anchors_.pop_back(); + } + + return GRAPH_SUCCESS; +} + bool NodeUtils::IsInNodesEmpty(const Node &node) { for (const auto &in_anchor : node.in_data_anchors_) { if (in_anchor != nullptr) { @@ -401,10 +450,13 @@ graphStatus NodeUtils::UpdateInputShape(const Node &node, uint32_t index, const graphStatus NodeUtils::GetNodeUnknownShapeStatus(const Node &node, bool &is_unknow) { auto desc = node.GetOpDesc(); GE_CHECK_NOTNULL(desc); - + // check self + is_unknow = OpShapeIsUnknown(desc); + if (is_unknow) { + return GRAPH_SUCCESS; + } auto sub_graph_names = desc->GetSubgraphInstanceNames(); if (sub_graph_names.empty()) { - is_unknow = OpShapeIsUnknown(desc); return GRAPH_SUCCESS; } else { auto owner_graph = node.GetOwnerComputeGraph(); @@ -440,6 +492,7 @@ std::string NodeUtils::GetNodeType(const Node &node) { (void)AttrUtils::GetStr(node.GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type); return type; } + ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { auto op_desc = node.GetOpDesc(); if (op_desc == nullptr) { @@ -492,6 +545,14 @@ bool NodeUtils::IsSubgraphInput(const NodePtr &node) { return false; } if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + bool is_unknown_shape = false; + (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); + if (is_unknown_shape) return false; + } + + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && + kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && + kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { return false; } @@ -513,7 +574,16 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { if (parent_op_desc == nullptr) { return false; } + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { + bool is_unknown_shape = false; + (void)AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); + if (is_unknown_shape) return false; + } + + if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE) && + kCaseOpTypes.count(parent_op_desc->GetType()) == 0 && kWhileOpTypes.count(parent_op_desc->GetType()) == 0 && + kForOpTypes.count(parent_op_desc->GetType()) == 0 && kIfOpTypes.count(parent_op_desc->GetType()) == 0) { return false; } @@ -555,6 +625,53 @@ NodePtr NodeUtils::GetParentInput(const NodePtr &node) { return peer_out_anchor->GetOwnerNode(); } +/// +/// @brief Check is varying_input for while node +/// @param [in] node: Data node for subgraph +/// @return bool +/// +bool NodeUtils::IsWhileVaryingInput(const ge::NodePtr &node) { + if (node == nullptr) { + return false; + } + if (node->GetType() != DATA) { + return false; // not input_node for subgraph + } + + const NodePtr &parent_node = node->GetOwnerComputeGraph()->GetParentNode(); + if (parent_node == nullptr) { + return false; // root graph + } + + if (kWhileOpTypes.count(parent_node->GetType()) == 0) { + return false; // not input_node for while subgraph + } + + uint32_t index_i = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, index_i)) { + GELOGW("Node %s has no attr PARENT_NODE_INDEX.", node->GetName().c_str()); + return false; + } + bool varying_flag = true; + for (const auto &item : node->GetOutDataNodesAndAnchors()) { + if (item.first->GetType() != NETOUTPUT) { + continue; + } + OpDescPtr op_desc = item.first->GetOpDesc(); + uint32_t index_o = 0; + if ((op_desc == nullptr) || + !AttrUtils::GetInt(op_desc->GetInputDesc(item.second->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index_o)) { + continue; // input for while-cond subgraph + } + if (index_i != index_o) { + continue; // varying input for while-body subgraph + } + varying_flag = false; + break; + } + return varying_flag; +} + /// /// @brief Get subgraph input is constant. /// @param [in] node @@ -637,4 +754,86 @@ Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { return GRAPH_SUCCESS; } +/// +/// @brief Get subgraph input data node by index. +/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphDataNodesByIndex(const Node &node, int index) { + vector in_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return in_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGW("Node %s is single node without sub graph.", node.GetName().c_str()); + return in_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + int parent_index = -1; + if (NodeUtils::IsSubgraphInput(node_in_subgraph)) { + (void)AttrUtils::GetInt(node_in_subgraph->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index); + if (parent_index == index) { + in_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + } + return in_data_node_vec; +} +/// +/// @brief Get subgraph input data node by index. +/// @param [in] node +/// @return Node +/// +vector NodeUtils::GetSubgraphOutputNodes(const Node &node) { + vector out_data_node_vec; + auto op_desc = node.GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(op_desc, return out_data_node_vec); + auto subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + GELOGI("Node %s is single node without sub graph.", node.GetName().c_str()); + return out_data_node_vec; + } + auto compute_graph = node.GetOwnerComputeGraph(); + for (const std::string &instance_name : subgraph_names) { + auto subgraph = compute_graph->GetSubgraph(instance_name); + for (const auto &node_in_subgraph : subgraph->GetDirectNode()) { + if (NodeUtils::IsSubgraphOutput(node_in_subgraph)) { + out_data_node_vec.emplace_back(node_in_subgraph); + } + } + } + return out_data_node_vec; +} + +NodePtr NodeUtils::GetInDataNodeByIndex(const Node &node, int index) { + if (node.GetInDataAnchor(index) == nullptr) { + return nullptr; + } + if (node.GetInDataAnchor(index)->GetPeerOutAnchor() == nullptr) { + return nullptr; + } + return node.GetInDataAnchor(index)->GetPeerOutAnchor()->GetOwnerNode(); +} + +vector NodeUtils::GetOutDataNodesByIndex(const Node &node, int index) { + vector out_data_nodes; + auto out_data_anchor = node.GetOutDataAnchor(index); + if (out_data_anchor == nullptr) { + return out_data_nodes; + } + for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + if (peer_in_anchor == nullptr) { + continue; + } + if (peer_in_anchor->GetOwnerNode() == nullptr) { + continue; + } + out_data_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); + } + return out_data_nodes; +} } // namespace ge diff --git a/src/common/graph/utils/op_desc_utils.cc b/src/common/graph/utils/op_desc_utils.cc index 6264ddb9..7a52a7f8 100644 --- a/src/common/graph/utils/op_desc_utils.cc +++ b/src/common/graph/utils/op_desc_utils.cc @@ -28,6 +28,7 @@ using std::vector; +/*lint -e512 -e737 -e752*/ namespace ge { const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; @@ -132,11 +133,11 @@ graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, Quantize GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); - return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); + return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 } graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { - return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); + return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom(quant)); // lint !e732 } GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { @@ -197,24 +198,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: continue; } auto in_node = out_anchor->GetOwnerNode(); - if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { - ret.push_back(in_node); - } else if (in_node->GetType() == DATA) { - const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); - GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); - - const NodePtr &parent_node = graph->GetParentNode(); - if (parent_node == nullptr) { - continue; // Root graph. - } - - if (kWhileOpTypes.count(parent_node->GetType()) > 0) { - continue; // Subgraph of While cond or body. + while (true) { + if (in_node == nullptr) { + break; } - - NodePtr input_node = NodeUtils::GetParentInput(in_node); - if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { - ret.push_back(input_node); + if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { + ret.push_back(in_node); + break; + } else if (in_node->GetType() == DATA) { + if (NodeUtils::IsWhileVaryingInput(in_node)) { + break; + } + in_node = NodeUtils::GetParentInput(in_node); + } else if ((in_node->GetType() == ENTER) || (in_node->GetType() == REFENTER)) { + bool is_constant = false; + (void)AttrUtils::GetBool(in_node->GetOpDesc(), ENTER_ATTR_CONSTANT_FLAG, is_constant); + if (!is_constant) { + break; + } + // Enter node has and only has one input + if (in_node->GetInDataNodes().size() != 1) { + GELOGW("Check number of input_nodes for Enter node %s failed, size=%zu.", node.GetName().c_str(), + in_node->GetInDataNodes().size()); + break; + } + in_node = in_node->GetInDataNodes().at(0); + } else { + break; } } } @@ -245,7 +255,7 @@ size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { continue; } } - return input_num; + return input_num; // lint !e712 } else { GE_IF_BOOL_EXEC( node.GetInDataNodes().size() < GetConstInputs(node).size(), @@ -350,7 +360,7 @@ bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { bool ret = false; if (index < node.GetAllInDataAnchors().size()) { if (NodeUtils::IsAnchorStatusSet(node)) { - ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == ANCHOR_DATA); + ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast(index))) == ANCHOR_DATA); // lint !e712 } else { for (const auto &anchor : node.GetAllInDataAnchors()) { if (anchor->GetIdx() != static_cast(index)) { @@ -435,10 +445,27 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils:: GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector OpDescUtils::MutableWeights(const ge::Node &node) { vector ret; - GE_CHK_BOOL_EXEC(node.GetOpDesc() != nullptr, return ret, "node.GetOpDesc is nullptr!"); + auto op_desc = node.GetOpDesc(); + GE_CHK_BOOL_EXEC(op_desc != nullptr, return ret, "op_desc is nullptr!"); + // Place holder operator, try to get the weight from parent node + // when parent node is const operator + if (node.GetType() == PLACEHOLDER) { + std::string parent_op; + (void)AttrUtils::GetStr(op_desc, "parentOpType", parent_op); + // This if judgment is necessary because the current subgraph optimization is multithreaded + // and the parent node of the PLD operation should be a stable type, such as const + if (parent_op == CONSTANT || parent_op == CONSTANTOP) { + NodePtr parent_node = nullptr; + parent_node = op_desc->TryGetExtAttr("parentNode", parent_node); + if (parent_node != nullptr) { + op_desc = parent_node->GetOpDesc(); + GELOGD("pld[%s] get weight from const[%s]", node.GetName().c_str(), op_desc->GetName().c_str()); + } + } + } // Const operator, take the weight directly - if (node.GetOpDesc()->GetType() == CONSTANT || (node.GetOpDesc()->GetType() == CONSTANTOP)) { - auto weight = MutableWeights(node.GetOpDesc()); + if (op_desc->GetType() == CONSTANT || (op_desc->GetType() == CONSTANTOP)) { + auto weight = MutableWeights(op_desc); if (weight == nullptr) { GELOGI("const op has no weight, op name:%s", node.GetName().c_str()); return ret; @@ -733,3 +760,4 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgr return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); } } // namespace ge +/*lint +e512 +e737 +e752*/ diff --git a/src/common/graph/utils/tensor_utils.cc b/src/common/graph/utils/tensor_utils.cc index 674cab55..26ac8cc8 100644 --- a/src/common/graph/utils/tensor_utils.cc +++ b/src/common/graph/utils/tensor_utils.cc @@ -19,6 +19,7 @@ #include "debug/ge_log.h" #include "framework/common/debug/ge_log.h" +#include "common/util/error_manager/error_manager.h" #include "graph/ge_tensor.h" #include "graph/types.h" #include "graph/utils/type_utils.h" @@ -105,7 +106,10 @@ static graphStatus CalcElementCntByDims(const std::vector &dims, int64_ element_cnt = 1; for (int64_t dim : dims) { if (CheckMultiplyOverflowInt64(element_cnt, dim)) { - GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, as when multiplying %ld and %ld.", element_cnt, dim); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19013", {"function", "var1", "var2"}, + {"CheckMultiplyOverflowInt64", std::to_string(element_cnt), std::to_string(dim)}); + GELOGE(GRAPH_FAILED, "CalcElementCntByDims failed, when multiplying %ld and %ld.", element_cnt, dim); return GRAPH_FAILED; } element_cnt *= dim; @@ -273,7 +277,6 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_FRACTAL_Z: graph_status = CalcElementCntOfFractalZ(dims, data_type, element_cnt); break; - case FORMAT_NC1HWC0_C04: case FORMAT_FRACTAL_NZ: case FORMAT_FRACTAL_ZZ: case FORMAT_NDHWC: @@ -285,6 +288,7 @@ static graphStatus CalcTensorElementCnt(const std::vector &dims, Format case FORMAT_NDC1HWC0: case FORMAT_FRACTAL_Z_C04: case FORMAT_FRACTAL_ZN_LSTM: + case FORMAT_NC1HWC0_C04: graph_status = CalcElementCntByDims(dims, element_cnt); break; default: diff --git a/src/common/graph/utils/type_utils.cc b/src/common/graph/utils/type_utils.cc index e4986931..5215b141 100644 --- a/src/common/graph/utils/type_utils.cc +++ b/src/common/graph/utils/type_utils.cc @@ -147,7 +147,8 @@ static const std::map kStringToFormatMap = { {"FRACTAL_ZN_LSTM", FORMAT_FRACTAL_ZN_LSTM}, {"FRACTAL_Z_G", FORMAT_FRACTAL_Z_G}, {"FORMAT_RESERVED", FORMAT_RESERVED}, - {"ALL", FORMAT_ALL}}; + {"ALL", FORMAT_ALL}, + {"NULL", FORMAT_NULL}}; static const std::map kDataTypeToStringMap = { {DT_UNDEFINED, "DT_UNDEFINED"}, // Used to indicate a DataType field has not been set. diff --git a/src/ge/CMakeLists.txt b/src/ge/CMakeLists.txt index 894eaf1e..a527bc1f 100755 --- a/src/ge/CMakeLists.txt +++ b/src/ge/CMakeLists.txt @@ -60,6 +60,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" @@ -94,14 +95,25 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/output/output.cc" - "graph/manager/*.cc" + "graph/manager/graph_context.cc" + "graph/manager/graph_manager.cc" + "graph/manager/graph_manager_utils.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/graph_caching_allocator.cc" + "graph/manager/graph_var_manager.cc" + "graph/manager/model_manager/event_manager.cc" + "graph/manager/trans_var_data_utils.cc" + "graph/manager/util/debug.cc" + "graph/manager/util/hcom_util.cc" + "graph/manager/util/rt_context_util.cc" + "graph/manager/util/variable_accelerate_ctrl.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/debug.cc" "graph/manager/util/hcom_util.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" "graph/optimize/graph_optimize.cc" + "graph/optimize/mem_rw_conflict_optimize.cc" "graph/optimize/optimizer/allreduce_fusion_pass.cc" "graph/optimize/summary_optimize.cc" "graph/partition/dynamic_shape_partition.cc" @@ -159,8 +171,11 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "hybrid/node_executor/aicpu/aicpu_ext_info.cc" "hybrid/node_executor/aicpu/aicpu_node_executor.cc" "hybrid/node_executor/compiledsubgraph/known_node_executor.cc" + "hybrid/node_executor/controlop/control_op_executor.cc" + "hybrid/node_executor/hccl/hccl_node_executor.cc" "hybrid/node_executor/hostcpu/ge_local_node_executor.cc" "hybrid/node_executor/node_executor.cc" + "hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc" "hybrid/node_executor/task_context.cc" "init/gelib.cc" "model/ge_model.cc" @@ -204,6 +219,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "common/formats/formats.cc" "common/formats/utils/formats_trans_utils.cc" "common/fp16_t.cc" + "common/ge/op_tiling_manager.cc" "common/ge/plugin_manager.cc" "common/helper/model_cache_helper.cc" "common/profiling/profiling_manager.cc" @@ -236,13 +252,19 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc" "graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "graph/load/new_model_manager/task_info/task_info.cc" - "graph/load/output/output.cc" - "graph/manager/*.cc" + "graph/manager/graph_caching_allocator.cc" + "graph/manager/graph_context.cc" + "graph/manager/graph_manager.cc" + "graph/manager/graph_manager_utils.cc" + "graph/manager/graph_mem_allocator.cc" + "graph/manager/trans_var_data_utils.cc" + "graph/manager/graph_var_manager.cc" "graph/manager/model_manager/event_manager.cc" "graph/manager/util/debug.cc" "graph/manager/util/rt_context_util.cc" "graph/manager/util/variable_accelerate_ctrl.cc" "graph/optimize/graph_optimize.cc" + "graph/optimize/mem_rw_conflict_optimize.cc" "graph/optimize/summary_optimize.cc" "graph/partition/dynamic_shape_partition.cc" "graph/partition/engine_place.cc" diff --git a/src/ge/client/ge_api.cc b/src/ge/client/ge_api.cc index ae6a9892..120c144a 100644 --- a/src/ge/client/ge_api.cc +++ b/src/ge/client/ge_api.cc @@ -28,6 +28,7 @@ #include "graph/opsproto_manager.h" #include "graph/utils/type_utils.h" #include "graph/manager/util/rt_context_util.h" +#include "graph/common/ge_call_wrapper.h" #include "register/op_registry.h" #include "common/ge/tbe_plugin_manager.h" @@ -41,8 +42,8 @@ namespace { const int32_t kMaxStrLen = 128; } -static bool kGeInitialized = false; -static std::mutex kGeReleaseMutex; // GEFinalize and ~Session use +static bool g_ge_initialized = false; +static std::mutex g_ge_release_mutex; // GEFinalize and ~Session use namespace ge { void GetOpsProtoPath(std::string &opsproto_path) { @@ -61,31 +62,6 @@ void GetOpsProtoPath(std::string &opsproto_path) { opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); } -Status CheckDumpAndReuseMemory(const std::map &options) { - const int kDecimal = 10; - auto dump_op_env = std::getenv("DUMP_OP"); - int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; - auto disableReuseMemoryIter = options.find("ge.exec.disableReuseMemory"); - if (disableReuseMemoryIter != options.end()) { - if (disableReuseMemoryIter->second == "0") { - GELOGD("ge.exec.disableReuseMemory=0, reuse memory is open"); - if (dump_op_flag) { - GELOGW("Will dump incorrect op data with GE Option ge.exec.disableReuseMemory=0"); - } - } else if (disableReuseMemoryIter->second == "1") { - GELOGD("ge.exec.disableReuseMemory=1, reuse memory is close"); - } else { - GELOGE(PARAM_INVALID, "CheckDumpAndReuseMemory ge.exec.disableReuseMemory is valid"); - return FAILED; - } - } else { - if (dump_op_flag) { - GELOGW("Will dump incorrect op data with default reuse memory"); - } - } - return SUCCESS; -} - Status CheckOptionsValid(const std::map &options) { // check job_id is valid auto job_id_iter = options.find(OPTION_EXEC_JOB_ID); @@ -96,11 +72,6 @@ Status CheckOptionsValid(const std::map &options) { } } - // Check ge.exec.disableReuseMemory and env DUMP_OP - if (CheckDumpAndReuseMemory(options) != SUCCESS) { - return FAILED; - } - return SUCCESS; } @@ -108,7 +79,7 @@ Status CheckOptionsValid(const std::map &options) { Status GEInitialize(const std::map &options) { GELOGT(TRACE_INIT, "GEInitialize start"); // 0.check init status - if (kGeInitialized) { + if (g_ge_initialized) { GELOGW("GEInitialize is called more than once"); return SUCCESS; } @@ -147,9 +118,9 @@ Status GEInitialize(const std::map &options) { } // 7.check return status, return - if (!kGeInitialized) { + if (!g_ge_initialized) { // Initialize success, first time calling initialize - kGeInitialized = true; + g_ge_initialized = true; } GELOGT(TRACE_STOP, "GEInitialize finished"); @@ -160,12 +131,12 @@ Status GEInitialize(const std::map &options) { Status GEFinalize() { GELOGT(TRACE_INIT, "GEFinalize start"); // check init status - if (!kGeInitialized) { + if (!g_ge_initialized) { GELOGW("GEFinalize is called before GEInitialize"); return SUCCESS; } - std::lock_guard lock(kGeReleaseMutex); + std::lock_guard lock(g_ge_release_mutex); // call Finalize Status ret = SUCCESS; Status middle_ret; @@ -187,10 +158,10 @@ Status GEFinalize() { ret = middle_ret; } - if (kGeInitialized && ret == SUCCESS) { + if (g_ge_initialized && ret == SUCCESS) { // Unified destruct rt_context - RtContextUtil::GetInstance().DestroyrtContexts(); - kGeInitialized = false; + RtContextUtil::GetInstance().DestroyAllRtContexts(); + g_ge_initialized = false; } GELOGT(TRACE_STOP, "GEFinalize finished"); @@ -202,7 +173,7 @@ Session::Session(const std::map &options) { GELOGT(TRACE_INIT, "Session Constructor start"); // check init status sessionId_ = 0; - if (!kGeInitialized) { + if (!g_ge_initialized) { GELOGE(GE_CLI_GE_NOT_INITIALIZED); return; } @@ -232,13 +203,13 @@ Session::Session(const std::map &options) { Session::~Session() { GELOGT(TRACE_INIT, "Session Destructor start"); // 0.check init status - if (!kGeInitialized) { + if (!g_ge_initialized) { GELOGW("GE is not yet initialized or is finalized."); return; } Status ret = FAILED; - std::lock_guard lock(kGeReleaseMutex); + std::lock_guard lock(g_ge_release_mutex); try { uint64_t session_id = sessionId_; // call DestroySession diff --git a/src/ge/common/convert/pb2json.cc b/src/ge/common/convert/pb2json.cc index 832a8278..0a5d24ee 100644 --- a/src/ge/common/convert/pb2json.cc +++ b/src/ge/common/convert/pb2json.cc @@ -72,9 +72,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(cons void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, const ProtobufReflection *reflection, const set &black_fields, Json &json, bool enum2str) { - if (field == nullptr || reflection == nullptr) { - return; - } switch (field->type()) { case ProtobufFieldDescriptor::TYPE_MESSAGE: { const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); @@ -118,8 +115,12 @@ void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescr case ProtobufFieldDescriptor::TYPE_FLOAT: char str[kSignificantDigits]; - sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)); - json[field->name()] = str; + if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { + json[field->name()] = str; + } else { + json[field->name()] = reflection->GetFloat(message, field); + } + break; case ProtobufFieldDescriptor::TYPE_STRING: diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.cc b/src/ge/common/formats/format_transfers/datatype_transfer.cc index 0bd4b8e5..08c6889f 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.cc +++ b/src/ge/common/formats/format_transfers/datatype_transfer.cc @@ -29,7 +29,6 @@ namespace ge { namespace formats { - namespace { enum DataTypeTransMode { kTransferWithDatatypeFloatToFloat16, diff --git a/src/ge/common/formats/format_transfers/datatype_transfer.h b/src/ge/common/formats/format_transfers/datatype_transfer.h index 0702592f..4d93fd6c 100644 --- a/src/ge/common/formats/format_transfers/datatype_transfer.h +++ b/src/ge/common/formats/format_transfers/datatype_transfer.h @@ -27,7 +27,6 @@ namespace ge { namespace formats { - struct CastArgs { const uint8_t *data; size_t src_data_size; diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc index dc8e1033..76d8696a 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc @@ -179,6 +179,5 @@ Status FormatTransferDhwcnFractalZ3D::TransShape(Format src_format, const std::v } REGISTER_FORMAT_TRANSFER(FormatTransferDhwcnFractalZ3D, FORMAT_DHWCN, FORMAT_FRACTAL_Z_3D) - } // namespace formats } // namespace ge diff --git a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc index 11e3d270..9de2e3a0 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc @@ -180,6 +180,5 @@ Status FormatTransferDhwncFractalZ3DTranspose::TransShape(Format src_format, con } REGISTER_FORMAT_TRANSFER(FormatTransferDhwncFractalZ3DTranspose, FORMAT_DHWNC, FORMAT_FRACTAL_Z_3D_TRANSPOSE) - } // namespace formats } // namespace ge diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc index ff7b84a4..65798f29 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_nz.cc @@ -56,7 +56,7 @@ Status TransShapeToFracNz(const ShapeVector &src_shape, DataType data_type, Shap dst_shape.clear(); hw_shape.clear(); auto w0 = GetCubeSizeByDataType(data_type); - auto h0 = GetCubeSizeByDataType(data_type); + int64_t h0 = kCubeSize; switch (src_shape.size()) { case 1: dst_shape.push_back(Ceil(src_shape[0], w0)); diff --git a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc index f3d06496..f2ec29da 100644 --- a/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc +++ b/src/ge/common/formats/format_transfers/format_transfer_fractal_z.cc @@ -19,6 +19,7 @@ #include #include +#include "common/debug/log.h" #include "common/formats/utils/formats_definitions.h" #include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" @@ -107,8 +108,8 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { int64_t hw = h * w; int64_t chw = c * hw; - int64_t hwc0 = hw * c0; int64_t nchw = n * chw; + int64_t hwc0 = hw * c0; // horizontal fractal matrix count (N) int64_t hf_cnt = Ceil(n, static_cast(kNiSize)); @@ -119,18 +120,15 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; int size = GetSizeByDataType(args.src_data_type); int64_t dst_size = total_ele_cnt * size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } + GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast(dst_size); return SUCCESS;); std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } + return OUT_OF_MEMORY;); for (int64_t vfi = 0; vfi < vf_cnt; vfi++) { // vertical fractal matrix base index @@ -156,12 +154,20 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { auto protected_size = dst_size - offset < static_cast(SECUREC_MEM_MAX_LEN) ? dst_size - offset : static_cast(SECUREC_MEM_MAX_LEN); - errno_t ret; + errno_t ret = EOK; if (need_pad_zero) { ret = memset_s(dst.get() + offset, static_cast(protected_size), 0, static_cast(size)); } else { - ret = memcpy_s(dst.get() + offset, static_cast(protected_size), args.data + src_offset * size, - static_cast(size)); + if (protected_size < size) { + GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + protected_size, size); + return INTERNAL_ERROR; + } + char *dst_data = reinterpret_cast(dst.get() + offset); + const char *src_data = reinterpret_cast(args.data + src_offset * size); + for (int64_t index = 0; index < size; index++) { + *dst_data++ = *src_data++; + } } if (ret != EOK) { GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d pad mode %d", offset, @@ -199,18 +205,15 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } + GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast(dst_size); return SUCCESS;); std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } + return OUT_OF_MEMORY;); for (int64_t c1i = 0; c1i < c1; c1i++) { for (int64_t hi = 0; hi < h; hi++) { @@ -223,14 +226,22 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); - errno_t ret; + errno_t ret = EOK; if (pad_zero) { ret = memset_s(dst.get() + dst_offset, static_cast(protected_size), 0, static_cast(data_size)); } else { + if (protected_size < data_size) { + GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + protected_size, data_size); + return INTERNAL_ERROR; + } int64_t src_idx = hi * wcn + wi * cn + (c1i * c0 + c0i) * n + n1n0i; - ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), - args.data + src_idx * data_size, static_cast(data_size)); + char *dst_data = reinterpret_cast(dst.get() + dst_offset); + const char *src_data = reinterpret_cast(args.data + src_idx * data_size); + for (int64_t index = 0; index < data_size; index++) { + *dst_data++ = *src_data++; + } } if (ret != EOK) { GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", @@ -269,18 +280,15 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { dst_size *= dim; } dst_size *= data_size; - if (dst_size == 0) { - result.length = static_cast(dst_size); - return SUCCESS; - } + GE_CHK_BOOL_EXEC_NOLOG(dst_size != 0, result.length = static_cast(dst_size); return SUCCESS;); std::shared_ptr dst(new (std::nothrow) uint8_t[dst_size], std::default_delete()); - if (dst == nullptr) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + dst == nullptr, GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", TypeUtils::FormatToSerialString(args.src_format).c_str(), TypeUtils::FormatToSerialString(args.dst_format).c_str(), dst_size); - return OUT_OF_MEMORY; - } + return OUT_OF_MEMORY;); for (int64_t c1i = 0; c1i < c1; c1i++) { for (int64_t hi = 0; hi < h; hi++) { @@ -293,14 +301,22 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { ? dst_size - dst_offset : static_cast(SECUREC_MEM_MAX_LEN); auto pad_zero = ((c1i * c0 + c0i) >= c) || (n1n0i >= n); - errno_t ret; + errno_t ret = EOK; if (pad_zero) { ret = memset_s(dst.get() + dst_offset, static_cast(protected_size), 0, static_cast(data_size)); } else { + if (protected_size < data_size) { + GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory, protected_size is %ld and size is %ld", + protected_size, data_size); + return INTERNAL_ERROR; + } int64_t src_idx = n1n0i * hwc + hi * wc + wi * c + (c1i * c0 + c0i); - ret = memcpy_s(dst.get() + dst_offset, static_cast(protected_size), - args.data + src_idx * data_size, static_cast(data_size)); + char *dst_data = reinterpret_cast(dst.get() + dst_offset); + const char *src_data = reinterpret_cast(args.data + src_idx * data_size); + for (int64_t index = 0; index < data_size; index++) { + *dst_data++ = *src_data++; + } } if (ret != EOK) { GELOGE(INTERNAL_ERROR, "Failed to operate the dst memory at offset %ld, error-code %d, pad mode %d", @@ -337,16 +353,16 @@ Status FormatTransferFractalZ::TransFormat(const TransArgs &args, TransResult &r return PARAM_INVALID; } - if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { - return TransFormatFromNchwToFz(args, result); + if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { + return TransFormatNhwcToFz(args, result); } if (args.src_format == FORMAT_HWCN && args.dst_format == FORMAT_FRACTAL_Z) { return TransFormatHwcnToFz(args, result); } - if (args.src_format == FORMAT_NHWC && args.dst_format == FORMAT_FRACTAL_Z) { - return TransFormatNhwcToFz(args, result); + if (args.src_format == FORMAT_NCHW && args.dst_format == FORMAT_FRACTAL_Z) { + return TransFormatFromNchwToFz(args, result); } return UNSUPPORTED; @@ -358,14 +374,14 @@ Status FormatTransferFractalZ::TransShape(Format src_format, const std::vector 0 ? SUCCESS : UNSUPPORTED; } @@ -109,7 +108,7 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { return NOT_CHANGED; } - /* prepare for padding in chw*/ + // prepare for padding in chw int64_t tmp = h * w * c; int64_t n_o = Ceil(n, static_cast(c0)); int64_t c_o = c0; @@ -309,6 +308,5 @@ Status FormatTransferNchwToFZC04::TransShape(Format src_format, const std::vecto } REGISTER_FORMAT_TRANSFER(FormatTransferNchwToFZC04, FORMAT_NCHW, FORMAT_FRACTAL_Z_C04) - } // namespace formats } // namespace ge diff --git a/src/ge/common/formats/utils/formats_definitions.h b/src/ge/common/formats/utils/formats_definitions.h index d889c33c..2faa60e1 100644 --- a/src/ge/common/formats/utils/formats_definitions.h +++ b/src/ge/common/formats/utils/formats_definitions.h @@ -19,7 +19,6 @@ namespace ge { namespace formats { - static const int kCubeSize = 16; static const int kNiSize = 16; static const int64_t kShapeItemNumMAX = 1024UL * 1024UL * 1024UL * 1024UL; @@ -47,7 +46,6 @@ enum FracZDimIndex { kFracZHWC1, kFracZN0, kFracZNi, kFracZC0, kFracZDimsNum }; enum DhwcnDimIndex { kDhwcnD, kDhwcnH, kDhwcnW, kDhwcnC, kDhwcnN, kDhwcnDimsNum }; enum DhwncDimIndex { kDhwncD, kDhwncH, kDhwncW, kDhwncN, kDhwncC, kDhwncDimsNum }; - } // namespace formats } // namespace ge #endif // GE_COMMON_FORMATS_UTILS_FORMATS_DEFINITIONS_H_ diff --git a/src/ge/common/formats/utils/formats_trans_utils.h b/src/ge/common/formats/utils/formats_trans_utils.h index a8fbd09b..8b6f0604 100644 --- a/src/ge/common/formats/utils/formats_trans_utils.h +++ b/src/ge/common/formats/utils/formats_trans_utils.h @@ -21,7 +21,6 @@ #include #include #include - #include "external/graph/types.h" #include "graph/ge_tensor.h" @@ -69,7 +68,6 @@ T Ceil(T n1, T n2) { } return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; } - } // namespace formats } // namespace ge #endif // GE_COMMON_FORMATS_UTILS_FORMATS_TRANS_UTILS_H_ diff --git a/src/ge/common/fp16_t.h b/src/ge/common/fp16_t.h index 34908b95..0fda2cd2 100644 --- a/src/ge/common/fp16_t.h +++ b/src/ge/common/fp16_t.h @@ -600,5 +600,5 @@ int16_t GetManBitLength(T man) { } return len; } -}; // namespace ge +} // namespace ge #endif // GE_COMMON_FP16_T_H_ diff --git a/src/ge/common/ge/op_tiling_manager.cc b/src/ge/common/ge/op_tiling_manager.cc new file mode 100644 index 00000000..7fb7a8fc --- /dev/null +++ b/src/ge/common/ge/op_tiling_manager.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "common/ge/op_tiling_manager.h" +#include "framework/common/debug/log.h" +#include + +namespace { +const char *const kEnvName = "ASCEND_OPP_PATH"; +const std::string kDefaultPath = "/usr/local/Ascend/opp"; +const std::string kDefaultBuiltInTilingPath = "/op_impl/built-in/liboptiling.so"; +const std::string kDefaultCustomTilingPath = "/op_impl/custom/liboptiling.so"; +const uint8_t kPrefixIndex = 9; +} // namespace + +namespace ge { +void OpTilingManager::ClearHandles() noexcept { + for (const auto &handle : handles_) { + if (dlclose(handle.second) != 0) { + GELOGE(FAILED, "Failed to close handle of %s: %s", handle.first.c_str(), dlerror()); + } + } + handles_.clear(); +} + +OpTilingManager::~OpTilingManager() { ClearHandles(); } + +std::string OpTilingManager::GetPath() { + const char *opp_path_env = std::getenv(kEnvName); + std::string opp_path = kDefaultPath; + if (opp_path_env != nullptr) { + char resolved_path[PATH_MAX]; + if (realpath(opp_path_env, resolved_path) == NULL) { + GELOGE(PARAM_INVALID, "Failed load tiling lib as env 'ASCEND_OPP_PATH'(%s) is invalid path.", opp_path_env); + return std::string(); + } + opp_path = resolved_path; + } + return opp_path; +} + +void OpTilingManager::LoadSo() { + std::string opp_path = GetPath(); + if (opp_path.empty()) { + GELOGW("Skip load tiling lib."); + return; + } + std::string built_in_tiling_lib = opp_path + kDefaultBuiltInTilingPath; + std::string custom_tiling_lib = opp_path + kDefaultCustomTilingPath; + std::string built_in_name = kDefaultBuiltInTilingPath.substr(kPrefixIndex); + std::string custom_name = kDefaultCustomTilingPath.substr(kPrefixIndex); + + void *handle_bi = dlopen(built_in_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_bi == nullptr) { + GELOGW("Failed to dlopen %s!", dlerror()); + } else { + handles_[built_in_name] = handle_bi; + } + + void *handle_ct = dlopen(custom_tiling_lib.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_ct == nullptr) { + GELOGW("Failed to dlopen %s!", dlerror()); + } else { + handles_[custom_name] = handle_ct; + } +} + +} // namespace ge diff --git a/src/ge/common/ge/op_tiling_manager.h b/src/ge/common/ge/op_tiling_manager.h new file mode 100644 index 00000000..320e1411 --- /dev/null +++ b/src/ge/common/ge/op_tiling_manager.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_GE_OP_TILING_MANAGER_H_ +#define GE_COMMON_GE_OP_TILING_MANAGER_H_ + +#include + +namespace ge { +using SoToHandleMap = std::map; + +class OpTilingManager { + public: + OpTilingManager() = default; + ~OpTilingManager(); + void LoadSo(); + + private: + static std::string GetPath(); + void ClearHandles() noexcept; + SoToHandleMap handles_; +}; +} // namespace ge + +#endif // GE_COMMON_GE_OP_TILING_MANAGER_H_ diff --git a/src/ge/common/ge/tbe_plugin_manager.cc b/src/ge/common/ge/tbe_plugin_manager.cc index e02b9422..d651ced1 100644 --- a/src/ge/common/ge/tbe_plugin_manager.cc +++ b/src/ge/common/ge/tbe_plugin_manager.cc @@ -182,7 +182,7 @@ void TBEPluginManager::GetCustomOpPath(std::string &customop_path) { } void TBEPluginManager::LoadCustomOpLib() { - LoadPluginSo(); + LoadPluginSo(options_); std::vector registration_datas = domi::OpRegistry::Instance()->registrationDatas; GELOGI("The size of registration_datas is: %zu", registration_datas.size()); @@ -193,10 +193,13 @@ void TBEPluginManager::LoadCustomOpLib() { } } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo() { +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPluginSo( + const std::map &options) { vector file_list; string caffe_parser_path; std::string plugin_path; + + options_ = options; GetCustomOpPath(plugin_path); // Whether there are files in the plugin so path diff --git a/src/ge/common/ge/tbe_plugin_manager.h b/src/ge/common/ge/tbe_plugin_manager.h index 82264ae8..2a55e450 100644 --- a/src/ge/common/ge/tbe_plugin_manager.h +++ b/src/ge/common/ge/tbe_plugin_manager.h @@ -48,7 +48,7 @@ class TBEPluginManager { static void InitPreparation(const std::map &options); - void LoadPluginSo(); + void LoadPluginSo(const std::map &options); private: TBEPluginManager() = default; diff --git a/src/ge/common/ge_common.mk b/src/ge/common/ge_common.mk index e913c8f5..e99ff654 100644 --- a/src/ge/common/ge_common.mk +++ b/src/ge/common/ge_common.mk @@ -36,6 +36,7 @@ GE_COMMON_LOCAL_SRC_FILES := \ properties_manager.cc \ types.cc\ model_parser/base.cc \ + model_parser/graph_parser_util.cc \ tbe_kernel_store.cc \ op/attr_value_util.cc \ op/ge_op_utils.cc \ diff --git a/src/ge/common/helper/model_helper.cc b/src/ge/common/helper/model_helper.cc index 2f95cbb1..19614566 100644 --- a/src/ge/common/helper/model_helper.cc +++ b/src/ge/common/helper/model_helper.cc @@ -17,6 +17,7 @@ #include "framework/common/helper/model_helper.h" #include "common/ge/ge_util.h" +#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/log.h" #include "framework/common/util.h" #include "framework/common/debug/ge_log.h" @@ -89,10 +90,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod } } auto ge_model_weight = ge_model->GetWeight(); - GELOGI("WEIGHTS_DATA size is %zu , %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); - if (SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA, ge_model_weight.GetData(), - ge_model_weight.GetSize()) != SUCCESS) { - GELOGW("Add weight partition failed"); // weight is not necessary + GELOGI("WEIGHTS_DATA size is %zu, %p", ge_model_weight.GetSize(), ge_model_weight.GetData()); + // weight is not necessary + if (ge_model_weight.GetSize() > 0) { + GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::WEIGHTS_DATA, + ge_model_weight.GetData(), ge_model_weight.GetSize()), + "Add weight partition failed"); } TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); @@ -238,44 +241,48 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(const ge::ModelData &model_data) { if (model_data.model_data == nullptr || model_data.model_len == 0) { - GELOGE(FAILED, "Model_data is nullptr, or model_data_size is 0"); - return FAILED; + GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0"); + return GE_EXEC_MODEL_DATA_SIZE_INVALID; } if (is_assign_model_) { - GELOGE(FAILED, "Model helper has already loaded!"); - return FAILED; + GELOGE(GE_EXEC_LOAD_MODEL_REPEATED, "Model helper has already loaded!"); + return GE_EXEC_LOAD_MODEL_REPEATED; } if (ReleaseLocalModelData() != SUCCESS) { - GELOGE(FAILED, "ReleaseLocalModelData failed."); - return FAILED; + GELOGE(INTERNAL_ERROR, "ReleaseLocalModelData failed."); + return INTERNAL_ERROR; } + Status status = ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_); if (ge::DavinciModelParser::ParseModelContent(model_data, model_addr_tmp_, model_len_tmp_) != SUCCESS) { - GELOGE(FAILED, "Parse model content failed!"); - return FAILED; + GELOGE(status, "Parse model content failed!"); + return status; } file_header_ = reinterpret_cast(model_data.model_data); OmFileLoadHelper om_load_helper; - if (om_load_helper.Init(model_addr_tmp_, model_len_tmp_) != SUCCESS) { - GELOGE(FAILED, "Om_load_helper init failed"); + status = om_load_helper.Init(model_addr_tmp_, model_len_tmp_); + if (status != SUCCESS) { + GELOGE(status, "Om_load_helper init failed"); model_addr_tmp_ = nullptr; - return FAILED; + return status; } auto partition_table = reinterpret_cast(model_addr_tmp_); if (partition_table->num == kOriginalOmPartitionNum) { - GELOGE(FAILED, "om model is error,please use executable om model"); - return FAILED; + model_addr_tmp_ = nullptr; + GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "om model is error,please use executable om model"); + return GE_EXEC_MODEL_PARTITION_NUM_INVALID; } // Encrypt model need to del temp model/no encrypt model don't need to del model model_addr_tmp_ = nullptr; - if (GenerateGeModel(om_load_helper) != SUCCESS) { - GELOGE(FAILED, "GenerateGeModel failed"); - return FAILED; + status = GenerateGeModel(om_load_helper); + if (status != SUCCESS) { + GELOGE(status, "GenerateGeModel failed"); + return status; } is_assign_model_ = true; @@ -287,19 +294,19 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) { GE_CHECK_NOTNULL(model_); Status ret = LoadModelData(om_load_helper); if (ret != SUCCESS) { - return ret; + return GE_EXEC_LOAD_MODEL_PARTITION_FAILED; } ret = LoadWeights(om_load_helper); if (ret != SUCCESS) { - return ret; + return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED; } ret = LoadTask(om_load_helper); if (ret != SUCCESS) { - return ret; + return GE_EXEC_LOAD_TASK_PARTITION_FAILED; } ret = LoadTBEKernelStore(om_load_helper); if (ret != SUCCESS) { - return ret; + return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; } return SUCCESS; } @@ -390,107 +397,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeMo return out_model; } -// Transit func for model to ge_model. It will be removed when load and build support ge_model in future -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransModelToGeModel(const ModelPtr &model, - GeModelPtr &ge_model) { - if (model == nullptr) { - GELOGE(FAILED, "Model is null"); - return FAILED; - } - ge_model = ge::MakeShared(); - GE_CHECK_NOTNULL(ge_model); - ge_model->SetGraph(model->GetGraph()); - ge_model->SetName(model->GetName()); - ge_model->SetVersion(model->GetVersion()); - ge_model->SetPlatformVersion(model->GetPlatformVersion()); - ge_model->SetAttr(model->MutableAttrMap()); - - // Copy weight info - auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); - // ge::Buffer weight; - ge::Buffer weight; - (void)ge::AttrUtils::GetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, weight); - ge_model->SetWeight(weight); - // Copy task info - if (model->HasAttr(MODEL_ATTR_TASKS)) { - ge::Buffer task_buffer; - GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetZeroCopyBytes(model, MODEL_ATTR_TASKS, task_buffer), FAILED, - "Get bytes failed."); - - std::shared_ptr task = ge::MakeShared(); - GE_CHECK_NOTNULL(task); - GE_IF_BOOL_EXEC(task_buffer.GetData() == nullptr, GELOGE(FAILED, "Get data fail"); return FAILED); - GE_IF_BOOL_EXEC(task_buffer.GetSize() == 0, GELOGE(FAILED, "Get size fail"); return FAILED); - - GE_CHK_BOOL_EXEC(ReadProtoFromArray(task_buffer.GetData(), static_cast(task_buffer.GetSize()), task.get()), - return INTERNAL_ERROR, "ReadProtoFromArray failed."); - - ge_model->SetModelTaskDef(task); - } - // Copy tbe kernel info - // TBEKernelStore kernel_store; - TBEKernelStore kernel_store; - if (compute_graph != nullptr && compute_graph->GetDirectNodesSize() != 0) { - for (const ge::NodePtr &n : compute_graph->GetDirectNode()) { - auto node_op_desc = n->GetOpDesc(); - GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); - TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); - GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); - kernel_store.AddTBEKernel(tbe_kernel); - GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); - } - } - if (!kernel_store.Build()) { - GELOGE(FAILED, "TBE Kernels store build failed!"); - return FAILED; - } - ge_model->SetTBEKernelStore(kernel_store); - - return SUCCESS; -} - -// trasit func for ge_model to Model. will be removed when load and build support ge_model in future -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::TransGeModelToModel(const GeModelPtr &ge_model, - ModelPtr &model) { - if (ge_model == nullptr) { - GELOGE(FAILED, "Ge_model is null"); - return FAILED; - } - model = ge::MakeShared(); - GE_CHECK_NOTNULL(model); - model->SetGraph(ge_model->GetGraph()); - model->SetName(ge_model->GetName()); - model->SetVersion(ge_model->GetVersion()); - model->SetPlatformVersion(ge_model->GetPlatformVersion()); - model->SetAttr(ge_model->MutableAttrMap()); - // Copy weight info - auto compute_graph = ge::GraphUtils::GetComputeGraph(model->GetGraph()); - bool ret = ge::AttrUtils::SetZeroCopyBytes(compute_graph, ge::ATTR_NAME_WEIGHTS_DATA, ge_model->GetWeight()); - if (!ret) { - GELOGE(FAILED, "Copy weight buffer failed!"); - return FAILED; - } - // Copy task info - std::shared_ptr model_task = ge_model->GetModelTaskDefPtr(); - - if (model_task != nullptr) { - int size = model_task->ByteSize(); - ge::Buffer buffer(static_cast(size)); - if (buffer.GetSize() == 0) { - GELOGE(MEMALLOC_FAILED, "alloc model attr task buffer failed!"); - return MEMALLOC_FAILED; - } - // no need to check value - (void)model_task->SerializePartialToArray(buffer.GetData(), size); - ret = ge::AttrUtils::SetZeroCopyBytes(model, MODEL_ATTR_TASKS, std::move(buffer)); - if (!ret) { - GELOGE(FAILED, "Copy task buffer failed!"); - return FAILED; - } - } - return SUCCESS; -} - Status ModelHelper::ReleaseLocalModelData() noexcept { Status result = SUCCESS; if (model_addr_tmp_ != nullptr) { diff --git a/src/ge/common/helper/om_file_helper.cc b/src/ge/common/helper/om_file_helper.cc index 0d58fe71..f25e2af3 100644 --- a/src/ge/common/helper/om_file_helper.cc +++ b/src/ge/common/helper/om_file_helper.cc @@ -41,8 +41,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(c FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(uint8_t *model_data, const uint32_t model_data_size) { - if (LoadModelPartitionTable(model_data, model_data_size) != SUCCESS) { - return FAILED; + Status status = LoadModelPartitionTable(model_data, model_data_size); + if (status != SUCCESS) { + return status; } is_inited_ = true; return SUCCESS; @@ -66,7 +67,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod } if (!found) { - if (type != ModelPartitionType::TBE_KERNELS) { + if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA) { GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast(type)); return FAILED; } @@ -83,7 +84,9 @@ Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { // Model length too small if (model.model_len < (sizeof(ModelFileHeader) + sizeof(ModelPartitionTable))) { - GELOGE(PARAM_INVALID, "Invalid model. length < sizeof(ModelFileHeader) + sizeof(ModelPartitionTable)."); + GELOGE(PARAM_INVALID, + "Invalid model. length[%u] < sizeof(ModelFileHeader)[%zu] + sizeof(ModelPartitionTable)[%zu].", + model.model_len, sizeof(ModelFileHeader), sizeof(ModelPartitionTable)); return PARAM_INVALID; } @@ -93,9 +96,9 @@ Status OmFileLoadHelper::CheckModelValid(const ge::ModelData &model) const { if ((model_header->length != model.model_len - sizeof(ModelFileHeader)) || (MODEL_FILE_MAGIC_NUM != model_header->magic)) { GELOGE(PARAM_INVALID, - "Invalid model. file_header->length(%u) + sizeof(ModelFileHeader)(%zu) != model->model_len(%u) || " - "MODEL_FILE_MAGIC_NUM != file_header->magic", - model_header->length, sizeof(ModelFileHeader), model.model_len); + "Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || " + "MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]", + model_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, model_header->magic); return PARAM_INVALID; } return SUCCESS; @@ -112,16 +115,16 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint // Original model partition include graph-info if ((partition_table->num != PARTITION_SIZE) && (partition_table->num != (PARTITION_SIZE - 1)) && (partition_table->num != 1)) { - GELOGE(PARAM_INVALID, "Invalid partition_table->num:%u", partition_table->num); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "Invalid partition_table->num:%u", partition_table->num); + return GE_EXEC_MODEL_PARTITION_NUM_INVALID; } size_t mem_offset = SIZE_OF_MODEL_PARTITION_TABLE(*partition_table); GELOGI("ModelPartitionTable num :%u, ModelFileHeader length :%zu, ModelPartitionTable length :%zu", partition_table->num, sizeof(ModelFileHeader), mem_offset); if (model_data_size <= mem_offset) { - GELOGE(PARAM_INVALID, "invalid model data, partition_table->num:%u, model data size %u", partition_table->num, - model_data_size); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u", + partition_table->num, model_data_size); + return GE_EXEC_MODEL_DATA_SIZE_INVALID; } for (uint32_t i = 0; i < partition_table->num; i++) { ModelPartition partition; @@ -131,9 +134,9 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint context_.partition_datas_.push_back(partition); if (partition.size > model_data_size || mem_offset > model_data_size - partition.size) { - GELOGE(PARAM_INVALID, "The partition size %zu is greater than the model data size %u.", + GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %zu is greater than the model data size %u.", partition.size + mem_offset, model_data_size); - return PARAM_INVALID; + return GE_EXEC_MODEL_DATA_SIZE_INVALID; } mem_offset += partition.size; GELOGI("Partition, type:%d, size:%u", static_cast(partition.type), partition.size); diff --git a/src/ge/common/math/fp16_math.h b/src/ge/common/math/fp16_math.h index 5bc9ac6d..c3a4eb28 100644 --- a/src/ge/common/math/fp16_math.h +++ b/src/ge/common/math/fp16_math.h @@ -92,5 +92,5 @@ fp16_t max(fp16_t fp1, fp16_t fp2); /// @brief Calculate the minimum fp16_t of fp1 and fp2 /// @return Returns minimum fp16_t of fp1 and fp2 fp16_t min(fp16_t fp1, fp16_t fp2); -}; // namespace ge +} // namespace ge #endif // GE_COMMON_MATH_FP16_MATH_H_ \ No newline at end of file diff --git a/src/ge/common/math_util.h b/src/ge/common/math_util.h index 5e783e81..a12be9e0 100644 --- a/src/ge/common/math_util.h +++ b/src/ge/common/math_util.h @@ -27,7 +27,6 @@ #include "mmpa/mmpa_api.h" namespace ge { - /** * @ingroup domi_calibration * @brief Initializes an input array to a specified value @@ -67,7 +66,6 @@ Status NnSet(const int32_t n, const Dtype alpha, Dtype *output) { } return SUCCESS; } - } // end namespace ge #endif // GE_COMMON_MATH_UTIL_H_ diff --git a/src/ge/common/model_parser/base.cc b/src/ge/common/model_parser/base.cc index fb6a647f..3b6b9407 100644 --- a/src/ge/common/model_parser/base.cc +++ b/src/ge/common/model_parser/base.cc @@ -35,15 +35,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro ge::ModelData &model_data) { std::string real_path = RealPath(model_path); if (real_path.empty()) { - GELOGE(PARAM_INVALID, "Model file path '%s' is invalid", model_path); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_PATH_INVALID, "Model file path '%s' is invalid", model_path); + return GE_EXEC_MODEL_PATH_INVALID; } - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(model_path) == -1, return FAILED, "File size not valid."); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(model_path) == -1, return GE_EXEC_READ_MODEL_FILE_FAILED, + "File size not valid."); std::ifstream fs(real_path.c_str(), std::ifstream::binary); - GE_CHK_BOOL_RET_STATUS(fs.is_open(), FAILED, "Open file failed! path:%s", model_path); + GE_CHK_BOOL_RET_STATUS(fs.is_open(), GE_EXEC_READ_MODEL_FILE_FAILED, "Open file failed! path:%s", model_path); // get length of file: (void)fs.seekg(0, std::ifstream::end); @@ -55,7 +56,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro char *data = new (std::nothrow) char[len]; if (data == nullptr) { - GELOGE(MEMALLOC_FAILED, "Load model From file failed, bad memory allocation occur. (need:%ld)", len); + GELOGE(MEMALLOC_FAILED, "Load model From file failed, bad memory allocation occur. (need:%u)", len); return MEMALLOC_FAILED; } @@ -79,31 +80,33 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::ParseMo GE_CHECK_NOTNULL(model.model_data); // Model length too small - GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), PARAM_INVALID, - "Invalid model. length < sizeof(ModelFileHeader)."); + GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), GE_EXEC_MODEL_DATA_SIZE_INVALID, + "Invalid model. Model data size %u must be greater than or equal to %zu.", model.model_len, + sizeof(ModelFileHeader)); // Get file header auto file_header = reinterpret_cast(model.model_data); // Determine whether the file length and magic number match GE_CHK_BOOL_RET_STATUS( file_header->length == model.model_len - sizeof(ModelFileHeader) && file_header->magic == MODEL_FILE_MAGIC_NUM, - PARAM_INVALID, - "Invalid model. file_header->length + sizeof(ModelFileHeader) != model->model_len || MODEL_FILE_MAGIC_NUM != " - "file_header->magic"); + GE_EXEC_MODEL_DATA_SIZE_INVALID, + "Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || " + "MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]", + file_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, file_header->magic); Status res = SUCCESS; // Get data address uint8_t *data = reinterpret_cast(model.model_data) + sizeof(ModelFileHeader); if (file_header->is_encrypt == ModelEncryptType::UNENCRYPTED) { // Unencrypted model - GE_CHK_BOOL_RET_STATUS(model.key.empty(), PARAM_INVALID, + GE_CHK_BOOL_RET_STATUS(model.key.empty(), GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, "Invalid param. model is unencrypted, but key is not empty."); model_data = data; model_len = file_header->length; GELOGI("Model_len is %u, model_file_head_len is %zu.", model_len, sizeof(ModelFileHeader)); } else { - GELOGE(PARAM_INVALID, "Invalid model. ModelEncryptType not supported."); - res = PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION, "Invalid model. ModelEncryptType not supported."); + res = GE_EXEC_MODEL_NOT_SUPPORT_ENCRYPTION; } return res; diff --git a/src/ge/common/model_parser/graph_parser_util.cc b/src/ge/common/model_parser/graph_parser_util.cc new file mode 100644 index 00000000..19f505c1 --- /dev/null +++ b/src/ge/common/model_parser/graph_parser_util.cc @@ -0,0 +1,501 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph_parser_util.h" +#include +#include "common/auth/file_saver.h" +#include "common/convert/pb2json.h" +#include "common/debug/log.h" +#include "common/debug/memory_dumper.h" +#include "common/model_parser/base.h" +#include "common/model_saver.h" +#include "common/properties_manager.h" +#include "common/string_util.h" +#include "common/types.h" +#include "common/util.h" +#include "common/util/error_manager/error_manager.h" +#include "external/register/register_types.h" +#include "framework/common/debug/ge_log.h" +#include "framework/omg/parser/parser_inner_ctx.h" +#include "graph/compute_graph.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/optimize/common/params.h" +#include "graph/utils/type_utils.h" +#include "omg/omg_inner_types.h" +#include "omg/parser/model_parser.h" +#include "omg/parser/parser_factory.h" +#include "omg/parser/weights_parser.h" +#include "parser/common/pre_checker.h" +#include "proto/ge_ir.pb.h" +#include "register/op_registry.h" + +namespace ge { +namespace { +// The function is incomplete. Currently, only l2_optimize, off_optimize is supported. +const char *const kInputShapeSample1 = "\"input_name1:n1,c1,h1,w1\""; +const char *const kInputShapeSample2 = "\"input_name1:1,3,224,224\""; +const char *const kSplitError1 = "size not equal to 2 split by \":\""; +const char *const kEmptyError = "can not be empty"; +const char *const kFloatNumError = "exist float number"; +const char *const kDigitError = "is not digit"; +const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; + +vector SplitInputShape(const std::string &input_shape) { + vector shape_pair_vec; + size_t pos = input_shape.rfind(":"); + if (pos != std::string::npos) { + shape_pair_vec.emplace_back(input_shape.substr(0, pos)); + shape_pair_vec.emplace_back(input_shape.substr(pos + 1, input_shape.size() - pos)); + } + return shape_pair_vec; +} + +static std::map output_type_str_to_datatype = { + {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; + +static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_param) { + if ((s == "true") || (s == "false")) { + return true; + } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s}); + GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); + return false; + } +} + +bool CheckDigitStr(std::string &str) { + for (char c : str) { + if (!isdigit(c)) { + GELOGE(domi::FAILED, "value[%s] is not positive integer", str.c_str()); + return false; + } + } + return true; +} + +Status StringToInt(std::string &str, int32_t &value) { + try { + if (!CheckDigitStr(str)) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s ", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", str, "is not positive integer"}); + return PARAM_INVALID; + } + value = stoi(str); + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch invalid_argument.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"output_type", str}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of digit string: %s, catch out_of_range.", str.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"output_type", str}); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { + std::vector> user_out_nodes = domi::GetContext().user_out_nodes; + std::set out_nodes_info; + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + // out_nodes set should include output_type and output_format + std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); + out_nodes_info.emplace(tmp); + } + for (uint32_t i = 0; i < out_type_vec.size(); ++i) { + if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", out_type_vec[i], kOutputTypeError}); + GELOGE(domi::FAILED, "Invalid value for --output_type[%s], %s.", out_type_vec[i].c_str(), kOutputTypeError); + return domi::FAILED; + } + } + return domi::SUCCESS; +} + +Status ParseOutputType(const std::string &output_type, std::map> &out_type_index_map, + std::map> &out_type_dt_map) { + if (output_type.find(':') == std::string::npos) { + GELOGI("output_type is not multiple nodes, means all out nodes"); + auto it = output_type_str_to_datatype.find(output_type); + if (it == output_type_str_to_datatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", output_type, kOutputTypeSupport}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); + return domi::FAILED; + } + return domi::SUCCESS; + } + std::vector out_type_vec; + vector nodes_v = StringUtils::Split(output_type, ';'); + for (const string &node : nodes_v) { + vector node_index_type_v = StringUtils::Split(node, ':'); + if (node_index_type_v.size() != 3) { // The size must be 3. + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", node, kOutputTypeSample}); + GELOGE(PARAM_INVALID, "Invalid value for --output_type[%s], %s.", node.c_str(), kOutputTypeSample); + return domi::FAILED; + } + ge::DataType tmp_dt; + std::string node_name = StringUtils::Trim(node_index_type_v[0]); + std::string index_str = StringUtils::Trim(node_index_type_v[1]); + int32_t index; + if (StringToInt(index_str, index) != SUCCESS) { + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s.", index_str.c_str()); + return domi::FAILED; + } + std::string dt_value = StringUtils::Trim(node_index_type_v[2]); + auto it = output_type_str_to_datatype.find(dt_value); + if (it == output_type_str_to_datatype.end()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", dt_value, kOutputTypeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", dt_value.c_str(), kOutputTypeSupport); + return domi::FAILED; + } else { + tmp_dt = it->second; + } + out_type_vec.push_back(node_name + ":" + index_str); + auto it_index = out_type_index_map.find(node_name); + if (it_index == out_type_index_map.end()) { + vector tmp_vec; + tmp_vec.push_back(index); + out_type_index_map.emplace(node_name, tmp_vec); + } else { + it_index->second.push_back(index); + } + + auto it_dt = out_type_dt_map.find(node_name); + if (it_dt == out_type_dt_map.end()) { + vector tmp_vec; + tmp_vec.push_back(tmp_dt); + out_type_dt_map.emplace(node_name, tmp_vec); + } else { + it_dt->second.push_back(tmp_dt); + } + } + return VerifyOutputTypeAndOutNodes(out_type_vec); +} + +Status CheckOutNode(ge::OpDescPtr op_desc, int32_t index) { + int32_t out_size = op_desc->GetOutputsSize(); + if (index < 0 || index >= out_size) { + GELOGE(domi::FAILED, + "out_node [%s] output index:%d must be smaller " + "than node output size:%d and can not be negative!", + op_desc->GetName().c_str(), index, out_size); + std::string fail_reason = "output index:" + to_string(index) + + " must be smaller than output size:" + to_string(out_size) + " and can not be negative!"; + ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"parameter", "value", "reason"}, + {"out_nodes", op_desc->GetName(), fail_reason}); + return domi::FAILED; + } + return domi::SUCCESS; +} + +Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info) { + ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); + if (tmpDescPtr == nullptr) { + GELOGE(domi::FAILED, "Get outnode op desc fail."); + return domi::FAILED; + } + size_t size = tmpDescPtr->GetOutputsSize(); + if (node->GetType() != NETOUTPUT) { + for (size_t index = 0; index < size; ++index) { + output_nodes_info.push_back(std::make_pair(node, index)); + } + } else { + const auto in_anchors = node->GetAllInDataAnchors(); + for (auto in_anchor : in_anchors) { + auto out_anchor = in_anchor->GetPeerOutAnchor(); + if (out_anchor == nullptr) { + GELOGE(domi::FAILED, "Get leaf node op desc fail."); + return domi::FAILED; + } + auto out_node = out_anchor->GetOwnerNode(); + output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); + } + } + return SUCCESS; +} + +void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, + std::vector &output_nodes_name) { + output_nodes_name.clear(); + if (domi::GetContext().out_top_names.empty()) { + // tf process, no top name. + for (const auto output_node_info : output_nodes_info) { + std::string node_name = output_node_info.first->GetName(); + int32_t index = output_node_info.second; + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + return; + } + // caffe process, need add top name after node_name:index + for (size_t i = 0; i < output_nodes_info.size(); ++i) { + std::string node_name = output_nodes_info[i].first->GetName(); + int32_t index = output_nodes_info[i].second; + if (i < domi::GetContext().out_top_names.size()) { + output_nodes_name.push_back(node_name + ":" + std::to_string(index) + ":" + domi::GetContext().out_top_names[i]); + } else { + GELOGW("Get top name of node [%s] fail.", node_name.c_str()); + output_nodes_name.push_back(node_name + ":" + std::to_string(index)); + } + } +} +} // namespace + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { + if (is_output_fp16.empty()) { + return SUCCESS; + } + + vector &output_formats = domi::GetContext().output_formats; + output_formats.clear(); + vector node_format_vec = StringUtils::Split(is_output_fp16, ','); + for (auto &is_fp16 : node_format_vec) { + StringUtils::Trim(is_fp16); + if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { + GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", + is_output_fp16.c_str()); + return PARAM_INVALID; + } + if (is_fp16 == "false") { + output_formats.push_back(DOMI_TENSOR_ND); + } else if (is_fp16 == "true") { + output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); + } + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SetOutputNodeInfo(ge::Graph &graph, + const std::string &output_type, + const std::string &output) { + ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); + GE_CHECK_NOTNULL(compute_graph); + + std::vector> user_out_nodes = domi::GetContext().user_out_nodes; + std::vector output_formats = domi::GetContext().output_formats; + std::vector> output_nodes_info; + std::vector output_nodes_name; + std::map> out_type_index_map; + std::map> out_type_dt_map; + if (!output_type.empty()) { + if (ParseOutputType(output_type, out_type_index_map, out_type_dt_map) != SUCCESS) { + GELOGE(domi::FAILED, "Parse output_type failed."); + return domi::FAILED; + } + } + + // User declared outputs + for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { + ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); + if (out_node == nullptr) { + GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + auto op_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (CheckOutNode(op_desc, user_out_nodes[i].second) != SUCCESS) { + GELOGE(domi::FAILED, "Check out node (%s) fail.", user_out_nodes[i].first.c_str()); + return domi::FAILED; + } + if (i < output_formats.size()) { + if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { + GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); + if (!ge::AttrUtils::SetBool(op_desc, "output_set_fp16_nc1hwc0", true)) { + GELOGW("The output node [%s] set NC1HWC0 failed", user_out_nodes[i].first.c_str()); + } + } + } + auto it_index = out_type_index_map.find(user_out_nodes[i].first); + auto it_dt = out_type_dt_map.find(user_out_nodes[i].first); + if ((it_index != out_type_index_map.end()) && (it_dt != out_type_dt_map.end())) { + GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); + (void)ge::AttrUtils::SetListDataType(op_desc, "_output_dt_list", it_dt->second); + (void)ge::AttrUtils::SetListInt(op_desc, "_output_dt_index", it_index->second); + } + output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); + } + // default output node (leaf) + if (user_out_nodes.empty()) { + for (ge::NodePtr node : compute_graph->GetDirectNode()) { + if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) { + Status ret = GetOutputLeaf(node, output_nodes_info); + GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); + } + } + } + GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); + compute_graph->SetGraphOutNodesInfo(output_nodes_info); + domi::GetContext().net_out_nodes = output_nodes_name; + return domi::SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ParseInputShape( + const string &input_shape, unordered_map> &shape_map, + vector>> &user_shape_map, bool is_dynamic_input) { + vector shape_vec = StringUtils::Split(input_shape, ';'); + const int DEFAULT_SHAPE_PAIR_SIZE = 2; + for (const auto &shape : shape_vec) { + vector shape_pair_vec = SplitInputShape(shape); + if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kSplitError1, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kSplitError1, kInputShapeSample1); + return false; + } + if (shape_pair_vec[1].empty()) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kEmptyError, kInputShapeSample1}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kEmptyError, kInputShapeSample1); + return false; + } + + vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); + vector shape_values; + for (auto &shape_value_str : shape_value_strs) { + // stoul: The method may throw an exception: invalid_argument/out_of_range + if (std::string::npos != shape_value_str.find('.')) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kFloatNumError, kInputShapeSample2}); + GELOGW("Parse input parameter [--input_shape]'s shape[%s] failed, reason: %s, correct sample is %s.", + shape.c_str(), kFloatNumError, kInputShapeSample2); + return false; + } + + long left_result = 0; + try { + left_result = stol(StringUtils::Trim(shape_value_str)); + if (!shape_value_str.empty() && (shape_value_str.front() == '-')) { + // The value maybe dynamic shape [-1], need substr it and verify isdigit. + shape_value_str = shape_value_str.substr(1); + } + for (char c : shape_value_str) { + if (!isdigit(c)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"shape", "reason", "sample"}, + {shape, kDigitError, kInputShapeSample2}); + GELOGE(PARAM_INVALID, "--input_shape's shape value[%s] is not digit", shape_value_str.c_str()); + return false; + } + } + } catch (const std::out_of_range &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, + {"input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]’s value[%s] cause out of range execption!", shape_value_str.c_str()); + return false; + } catch (const std::invalid_argument &) { + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, + {"input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]’s value[%s] cause invalid argument!", shape_value_str.c_str()); + return false; + } catch (...) { + ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "value"}, + {"input_shape", shape_value_str}); + GELOGW("Input parameter[--input_shape]’s value[%s] cause unkown execption!", shape_value_str.c_str()); + return false; + } + int64_t result = left_result; + // - 1 is not currently supported + if (!is_dynamic_input && result <= 0) { + ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape", "result"}, {shape, std::to_string(result)}); + GELOGW( + "Input parameter[--input_shape]’s shape value[%s] is invalid, " + "expect positive integer, but value is %ld.", + shape.c_str(), result); + return false; + } + shape_values.push_back(result); + } + + shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); + } + + return true; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOutputNodes(const string &out_nodes) { + try { + // parse output node + if (!out_nodes.empty()) { + domi::GetContext().out_nodes_map.clear(); + domi::GetContext().user_out_nodes.clear(); + + vector nodes_v = StringUtils::Split(out_nodes, ';'); + for (const string &node : nodes_v) { + vector key_value_v = StringUtils::Split(node, ':'); + if (key_value_v.size() != 2) { // The size must be 2. + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--out_nodes", node, "the correct format is \"node_name1:0;node_name1:1;node_name2:0\""}); + GELOGE(PARAM_INVALID, + "The input format of --out_nodes is invalid, the correct format is " + "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", + node.c_str()); + return PARAM_INVALID; + } + auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); + // stoi: The method may throw an exception: invalid_argument/out_of_range + if (!CheckDigitStr(key_value_v[1])) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--out_nodes", out_nodes, "is not positive integer"}); + GELOGE(PARAM_INVALID, "This str must be digit string, while the actual input is %s", out_nodes.c_str()); + return PARAM_INVALID; + } + int32_t index = stoi(StringUtils::Trim(key_value_v[1])); + if (iter != domi::GetContext().out_nodes_map.end()) { + iter->second.emplace_back(index); + } else { + std::vector index_v; + index_v.emplace_back(index); + domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); + } + domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); + } + } + } catch (std::invalid_argument &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "value"}, {"out_nodes", out_nodes}); + return PARAM_INVALID; + } catch (std::out_of_range &) { + GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "value"}, {"out_nodes", out_nodes}); + return PARAM_INVALID; + } + return SUCCESS; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ParseOpConf(const char *op_conf) { + if (op_conf != nullptr && *op_conf != '\0') { + // divided by ":" + PropertiesManager::Instance().SetPropertyDelimiter(OP_CONF_DELIMITER); + // Parsing the op_conf configuration item file + if (!PropertiesManager::Instance().Init(op_conf)) { + GELOGE(FAILED, "op_name_map init failed!"); + return FAILED; + } + // Return map and put it into ATC global variable + domi::GetContext().op_conf_map = PropertiesManager::Instance().GetPropertyMap(); + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/common/model_parser/graph_parser_util.h b/src/ge/common/model_parser/graph_parser_util.h new file mode 100644 index 00000000..b38642c2 --- /dev/null +++ b/src/ge/common/model_parser/graph_parser_util.h @@ -0,0 +1,62 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_COMMON_GRAPH_PARSER_UTIL_H_ +#define GE_COMMON_GRAPH_PARSER_UTIL_H_ + +#include +#include +#include +#include +#include "framework/common/types.h" +#include "framework/omg/omg_inner_types.h" +#include "proto/ge_ir.pb.h" +#include "proto/om.pb.h" + +#include "graph/compute_graph.h" +#include "graph/graph.h" +#include "graph/model.h" +#include "runtime/kernel.h" + +using domi::Status; +using std::pair; +using std::string; +using std::unordered_map; +using std::vector; + +namespace ge { +Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); + +Status ParseOutputFp16NodesFormat(const string &is_output_fp16); + +Status ParseOutputNodes(const string &out_nodes); + +bool ParseInputShape(const string &input_shape, unordered_map> &shape_map, + vector>> &user_shape_map, bool is_dynamic_input); + +Status ParseOpConf(const char *op_conf); +} // namespace ge + +namespace domi { +/** + * @ingroup domi_omg + * @brief get omg context + * @return reference of OmgContext + */ +ge::OmgContext &GetContext(); +} // namespace domi + +#endif // GE_COMMON_GRAPH_PARSER_UTIL_H_ diff --git a/src/ge/common/model_saver.cc b/src/ge/common/model_saver.cc index 11d9e804..821fde60 100644 --- a/src/ge/common/model_saver.cc +++ b/src/ge/common/model_saver.cc @@ -60,8 +60,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi mode_t mode = S_IRUSR | S_IWUSR; int32_t fd = mmOpen2(real_path, O_RDWR | O_CREAT | O_TRUNC, mode); if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { - ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"filepath", "errMsg"}, {file_path, strerror(errno)}); - GELOGE(FAILED, "Open file failed. file path : %s, %s", file_path, strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file_path, strerror(errno)}); + GELOGE(FAILED, "Open file[%s] failed. %s", file_path, strerror(errno)); return FAILED; } const char *model_char = model_str.c_str(); @@ -69,8 +69,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelSaver::SaveJsonToFi // Write data to file mmSsize_t mmpa_ret = mmWrite(fd, const_cast((const void *)model_char), len); if (mmpa_ret == EN_ERROR || mmpa_ret == EN_INVALID_PARAM) { - ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"mmpa_ret", "errMsg"}, - {std::to_string(mmpa_ret), strerror(errno)}); + ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {file_path, strerror(errno)}); // Need to both print the error info of mmWrite and mmClose, so return ret after mmClose GELOGE(FAILED, "Write to file failed. errno = %d, %s", mmpa_ret, strerror(errno)); ret = FAILED; diff --git a/src/ge/common/profiling/profiling_manager.cc b/src/ge/common/profiling/profiling_manager.cc index ecbbf5f2..364f8298 100644 --- a/src/ge/common/profiling/profiling_manager.cc +++ b/src/ge/common/profiling/profiling_manager.cc @@ -363,16 +363,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin std::string data; for (const auto &task : task_desc_info) { + std::string model_name = task.model_name; std::string op_name = task.op_name; uint32_t block_dim = task.block_dim; uint32_t task_id = task.task_id; uint32_t stream_id = task.stream_id; - data = op_name.append(" ").append(std::to_string(block_dim) - .append(" ") - .append(std::to_string(task_id)) - .append(" ") - .append(std::to_string(stream_id)) - .append("\n")); + data = model_name.append(" ").append(op_name).append(" ").append(std::to_string(block_dim) + .append(" ") + .append(std::to_string(task_id)) + .append(" ") + .append(std::to_string(stream_id)) + .append("\n")); Msprof::Engine::ReporterData reporter_data{}; reporter_data.deviceId = device_id; @@ -403,7 +404,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::Profilin std::string data; for (const auto &graph : compute_graph_desc_info) { - data.append("op_name:").append(graph.op_name).append(" op_type:").append(graph.op_type); + data.append("model_name:") + .append(graph.model_name) + .append(" op_name:") + .append(graph.op_name) + .append(" op_type:") + .append(graph.op_type); for (size_t i = 0; i < graph.input_format.size(); ++i) { data.append(" input_id:") .append(std::to_string(i)) diff --git a/src/ge/common/properties_manager.cc b/src/ge/common/properties_manager.cc index cf1ada05..0c2b1db6 100644 --- a/src/ge/common/properties_manager.cc +++ b/src/ge/common/properties_manager.cc @@ -20,15 +20,204 @@ #include #include +#include "common/ge/ge_util.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/debug/log.h" #include "framework/common/ge_types.h" #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" #include "graph/utils/attr_utils.h" namespace ge { +namespace { +const string kEnableFlag = "1"; + +const uint32_t kAicoreOverflow = (0x1 << 0); +const uint32_t kAtomicOverflow = (0x1 << 1); +const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); +} // namespace + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties::DumpProperties(const DumpProperties &other) { + CopyFrom(other); +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &DumpProperties::operator=( + const DumpProperties &other) { + CopyFrom(other); + return *this; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::InitByOptions() { + enable_dump_.clear(); + enable_dump_debug_.clear(); + dump_path_.clear(); + dump_step_.clear(); + dump_mode_.clear(); + is_op_debug_ = false; + op_debug_mode_ = 0; + + string enable_dump; + (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP, enable_dump); + enable_dump_ = enable_dump; + + string enable_dump_debug; + (void)GetContext().GetOption(OPTION_EXEC_ENABLE_DUMP_DEBUG, enable_dump_debug); + enable_dump_debug_ = enable_dump_debug; + + if ((enable_dump_ == kEnableFlag) || (enable_dump_debug_ == kEnableFlag)) { + string dump_path; + if (GetContext().GetOption(OPTION_EXEC_DUMP_PATH, dump_path) == GRAPH_SUCCESS) { + if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { + dump_path = dump_path + "/"; + } + dump_path = dump_path + CurrentTimeInStr() + "/"; + GELOGI("Get dump path %s successfully", dump_path.c_str()); + SetDumpPath(dump_path); + } else { + GELOGW("DUMP_PATH is not set"); + } + } + + if (enable_dump_ == kEnableFlag) { + string dump_step; + if (GetContext().GetOption(OPTION_EXEC_DUMP_STEP, dump_step) == GRAPH_SUCCESS) { + GELOGD("Get dump step %s successfully", dump_step.c_str()); + SetDumpStep(dump_step); + } + string dump_mode; + if (GetContext().GetOption(OPTION_EXEC_DUMP_MODE, dump_mode) == GRAPH_SUCCESS) { + GELOGD("Get dump mode %s successfully", dump_mode.c_str()); + SetDumpMode(dump_mode); + } + AddPropertyValue(DUMP_ALL_MODEL, {}); + } + + SetDumpDebugOptions(); +} + +// The following is the new dump scenario of the fusion operator +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::AddPropertyValue( + const std::string &model, const std::set &layers) { + for (const std::string &layer : layers) { + GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); + } + + model_dump_properties_map_[model] = layers; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::DeletePropertyValue(const std::string &model) { + auto iter = model_dump_properties_map_.find(model); + if (iter != model_dump_properties_map_.end()) { + model_dump_properties_map_.erase(iter); + } +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetAllDumpModel() const { + std::set model_list; + for (auto &iter : model_dump_properties_map_) { + model_list.insert(iter.first); + } + + return model_list; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set DumpProperties::GetPropertyValue( + const std::string &model) const { + auto iter = model_dump_properties_map_.find(model); + if (iter != model_dump_properties_map_.end()) { + return iter->second; + } + return {}; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool DumpProperties::IsLayerNeedDump( + const std::string &model, const std::string &om_name, const std::string &op_name) const { + // if dump all + if (model_dump_properties_map_.find(DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { + return true; + } + + // if this model need dump + auto om_name_iter = model_dump_properties_map_.find(om_name); + auto model_name_iter = model_dump_properties_map_.find(model); + if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { + // if no dump layer info, dump all layer in this model + auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; + if (model_iter->second.empty()) { + return true; + } + + return model_iter->second.find(op_name) != model_iter->second.end(); + } + + GELOGD("Model %s is not seated to be dump.", model.c_str()); + return false; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpPath(const std::string &path) { + dump_path_ = path; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpPath() const { return dump_path_; } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpStep(const std::string &step) { + dump_step_ = step; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpStep() const { return dump_step_; } + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpMode(const std::string &mode) { + dump_mode_ = mode; +} + +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string DumpProperties::GetDumpMode() const { return dump_mode_; } + +void DumpProperties::CopyFrom(const DumpProperties &other) { + if (&other != this) { + enable_dump_ = other.enable_dump_; + enable_dump_debug_ = other.enable_dump_debug_; + dump_path_ = other.dump_path_; + dump_step_ = other.dump_step_; + dump_mode_ = other.dump_mode_; + + model_dump_properties_map_ = other.model_dump_properties_map_; + is_op_debug_ = other.is_op_debug_; + op_debug_mode_ = other.op_debug_mode_; + } +} + +void DumpProperties::SetDumpDebugOptions() { + if (enable_dump_debug_ == kEnableFlag) { + string dump_debug_mode; + if (GetContext().GetOption(OPTION_EXEC_DUMP_DEBUG_MODE, dump_debug_mode) == GRAPH_SUCCESS) { + GELOGD("Get dump debug mode %s successfully", dump_debug_mode.c_str()); + } else { + GELOGW("Dump debug mode is not set."); + return; + } + + if (dump_debug_mode == OP_DEBUG_AICORE) { + GELOGD("ge.exec.dumpDebugMode=aicore_overflow, op debug is open."); + is_op_debug_ = true; + op_debug_mode_ = kAicoreOverflow; + } else if (dump_debug_mode == OP_DEBUG_ATOMIC) { + GELOGD("ge.exec.dumpDebugMode=atomic_overflow, op debug is open."); + is_op_debug_ = true; + op_debug_mode_ = kAtomicOverflow; + } else if (dump_debug_mode == OP_DEBUG_ALL) { + GELOGD("ge.exec.dumpDebugMode=all, op debug is open."); + is_op_debug_ = true; + op_debug_mode_ = kAllOverflow; + } else { + GELOGW("ge.exec.dumpDebugMode is invalid."); + } + } else { + GELOGI("ge.exec.enableDumpDebug is false or is not set."); + } +} + PropertiesManager::PropertiesManager() : is_inited_(false), delimiter("=") {} PropertiesManager::~PropertiesManager() {} @@ -159,131 +348,22 @@ PropertiesManager::GetPropertyMap() { // Set separator FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetPropertyDelimiter(const std::string &de) { + std::lock_guard lock(mutex_); delimiter = de; } -// The following is the new dump scenario of the fusion operator -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::AddDumpPropertyValue( - const std::string &model, const std::set &layers) { - for (const std::string &layer : layers) { - GELOGI("This model %s config to dump layer %s", model.c_str(), layer.c_str()); - } - - std::lock_guard lock(dump_mutex_); - model_dump_properties_map_[model] = layers; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::DeleteDumpPropertyValue( - const std::string &model) { - std::lock_guard lock(dump_mutex_); - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - model_dump_properties_map_.erase(iter); - } -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::ClearDumpPropertyValue() { - std::lock_guard lock(dump_mutex_); - model_dump_properties_map_.clear(); -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set PropertiesManager::GetAllDumpModel() { - std::set model_list; - std::lock_guard lock(dump_mutex_); - for (auto &iter : model_dump_properties_map_) { - model_list.insert(iter.first); - } - - return model_list; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::set PropertiesManager::GetDumpPropertyValue( - const std::string &model) { - std::lock_guard lock(dump_mutex_); - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - return iter->second; - } - return {}; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::IsLayerNeedDump(const std::string &model, - const std::string &om_name, - const std::string &op_name) { - std::lock_guard lock(dump_mutex_); - // if dump all - if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { - return true; - } - - // if this model need dump - auto om_name_iter = model_dump_properties_map_.find(om_name); - auto model_name_iter = model_dump_properties_map_.find(model); - if (om_name_iter != model_dump_properties_map_.end() || model_name_iter != model_dump_properties_map_.end()) { - // if no dump layer info, dump all layer in this model - auto model_iter = om_name_iter != model_dump_properties_map_.end() ? om_name_iter : model_name_iter; - if (model_iter->second.empty()) { - return true; - } - - return model_iter->second.find(op_name) != model_iter->second.end(); - } - - GELOGD("Model %s is not seated to be dump.", model.c_str()); - return false; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY DumpProperties &PropertiesManager::GetDumpProperties( + uint64_t session_id) { + std::lock_guard lock(mutex_); + // If session_id is not found in dump_properties_map_, operator[] will insert one. + return dump_properties_map_[session_id]; } -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool PropertiesManager::QueryModelDumpStatus( - const std::string &model) { - std::lock_guard lock(dump_mutex_); - auto iter = model_dump_properties_map_.find(model); - if (iter != model_dump_properties_map_.end()) { - return true; - } else if (model_dump_properties_map_.find(ge::DUMP_ALL_MODEL) != model_dump_properties_map_.end()) { - return true; +FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::RemoveDumpProperties(uint64_t session_id) { + std::lock_guard lock(mutex_); + auto iter = dump_properties_map_.find(session_id); + if (iter != dump_properties_map_.end()) { + dump_properties_map_.erase(iter); } - return false; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputModel( - const std::string &output_mode) { - std::lock_guard lock(dump_mutex_); - this->output_mode_ = output_mode; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputModel() { - std::lock_guard lock(dump_mutex_); - return this->output_mode_; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpOutputPath( - const std::string &output_path) { - std::lock_guard lock(dump_mutex_); - this->output_path_ = output_path; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpOutputPath() { - std::lock_guard lock(dump_mutex_); - return this->output_path_; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpStep(const std::string &dump_step) { - std::lock_guard lock(dump_mutex_); - this->dump_step_ = dump_step; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpStep() { - std::lock_guard lock(dump_mutex_); - return this->dump_step_; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void PropertiesManager::SetDumpMode(const std::string &dump_mode) { - std::lock_guard lock(dump_mutex_); - this->dump_mode_ = dump_mode; -} - -FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string PropertiesManager::GetDumpMode() { - std::lock_guard lock(dump_mutex_); - return this->dump_mode_; } } // namespace ge diff --git a/src/ge/common/properties_manager.h b/src/ge/common/properties_manager.h index 7cbb5949..3b1547f5 100644 --- a/src/ge/common/properties_manager.h +++ b/src/ge/common/properties_manager.h @@ -32,6 +32,50 @@ static const char *USE_FUSION __attribute__((unused)) = "FMK_USE_FUSION"; static const char *TIMESTAT_ENABLE __attribute__((unused)) = "DAVINCI_TIMESTAT_ENABLE"; static const char *ANNDROID_DEBUG __attribute__((unused)) = "ANNDROID_DEBUG"; +class DumpProperties { + public: + DumpProperties() = default; + ~DumpProperties() = default; + DumpProperties(const DumpProperties &dump); + DumpProperties &operator=(const DumpProperties &dump); + + void InitByOptions(); + + void AddPropertyValue(const std::string &model, const std::set &layers); + void DeletePropertyValue(const std::string &model); + + std::set GetAllDumpModel() const; + std::set GetPropertyValue(const std::string &model) const; + bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name) const; + + void SetDumpPath(const std::string &path); + std::string GetDumpPath() const; + + void SetDumpStep(const std::string &step); + std::string GetDumpStep() const; + + void SetDumpMode(const std::string &mode); + std::string GetDumpMode() const; + + bool IsOpDebugOpen() const { return is_op_debug_; } + uint32_t GetOpDebugMode() const { return op_debug_mode_; } + + private: + void CopyFrom(const DumpProperties &other); + void SetDumpDebugOptions(); + + string enable_dump_; + string enable_dump_debug_; + + std::string dump_path_; + std::string dump_step_; + std::string dump_mode_; + std::map> model_dump_properties_map_; + + bool is_op_debug_ = false; + uint32_t op_debug_mode_ = 0; +}; + class PropertiesManager { public: // Singleton @@ -81,21 +125,8 @@ class PropertiesManager { */ void SetPropertyDelimiter(const std::string &de); - void AddDumpPropertyValue(const std::string &model, const std::set &layers); - std::set GetAllDumpModel(); - std::set GetDumpPropertyValue(const std::string &model); - bool IsLayerNeedDump(const std::string &model, const std::string &om_name, const std::string &op_name); - void DeleteDumpPropertyValue(const std::string &model); - void ClearDumpPropertyValue(); - bool QueryModelDumpStatus(const std::string &model); - void SetDumpOutputModel(const std::string &output_model); - std::string GetDumpOutputModel(); - void SetDumpOutputPath(const std::string &output_path); - std::string GetDumpOutputPath(); - void SetDumpStep(const std::string &dump_step); - std::string GetDumpStep(); - void SetDumpMode(const std::string &dump_mode); - std::string GetDumpMode(); + DumpProperties &GetDumpProperties(uint64_t session_id); + void RemoveDumpProperties(uint64_t session_id); private: // Private construct, destructor @@ -119,12 +150,7 @@ class PropertiesManager { std::map properties_map_; std::mutex mutex_; - std::string output_mode_; - std::string output_path_; - std::string dump_step_; - std::string dump_mode_; - std::map> model_dump_properties_map_; // model_dump_layers_map_ - std::mutex dump_mutex_; + std::map dump_properties_map_; }; } // namespace ge diff --git a/src/ge/common/tbe_kernel_store.h b/src/ge/common/tbe_kernel_store.h index da231358..51d69af2 100644 --- a/src/ge/common/tbe_kernel_store.h +++ b/src/ge/common/tbe_kernel_store.h @@ -28,7 +28,6 @@ #include "graph/op_kernel_bin.h" namespace ge { - using TBEKernel = ge::OpKernelBin; using TBEKernelPtr = std::shared_ptr; diff --git a/src/ge/common/types.cc b/src/ge/common/types.cc index 97761dea..2de75ff6 100644 --- a/src/ge/common/types.cc +++ b/src/ge/common/types.cc @@ -26,6 +26,11 @@ const std::string DUMP_LAYER = "layer"; const std::string DUMP_FILE_PATH = "path"; const std::string DUMP_MODE = "dump_mode"; +// op debug mode +const std::string OP_DEBUG_AICORE = "aicore_overflow"; +const std::string OP_DEBUG_ATOMIC = "atomic_overflow"; +const std::string OP_DEBUG_ALL = "all"; + const int DEFAULT_FORMAT = static_cast(ge::FORMAT_NCHW); // Supported public property names const std::string PROP_OME_START_TIME = "ome_start_time"; // start time @@ -277,8 +282,8 @@ REGISTER_OPTYPE_DEFINE(GETSPAN, "GetSpan"); REGISTER_OPTYPE_DEFINE(STOPGRADIENT, "StopGradient"); REGISTER_OPTYPE_DEFINE(PREVENTGRADIENT, "PreventGradient"); REGISTER_OPTYPE_DEFINE(GUARANTEECONST, "GuaranteeConst"); -REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs") -REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs") +REGISTER_OPTYPE_DEFINE(BROADCASTGRADIENTARGS, "BroadcastGradientArgs"); +REGISTER_OPTYPE_DEFINE(BROADCASTARGS, "BroadcastArgs"); REGISTER_OPTYPE_DEFINE(CONFUSIONMATRIX, "ConfusionMatrix"); REGISTER_OPTYPE_DEFINE(RANK, "Rank"); REGISTER_OPTYPE_DEFINE(PLACEHOLDER, "PlaceHolder"); @@ -286,6 +291,8 @@ REGISTER_OPTYPE_DEFINE(END, "End"); REGISTER_OPTYPE_DEFINE(BASICLSTMCELL, "BasicLSTMCell"); REGISTER_OPTYPE_DEFINE(GETNEXT, "GetNext"); REGISTER_OPTYPE_DEFINE(INITDATA, "InitData"); +REGISTER_OPTYPE_DEFINE(REFIDENTITY, "RefIdentity"); +REGISTER_OPTYPE_DEFINE(BITCAST, "Bitcast"); /***************Ann special operator*************************/ REGISTER_OPTYPE_DEFINE(ANN_MEAN, "AnnMean"); @@ -376,6 +383,8 @@ REGISTER_OPTYPE_DEFINE(HCOMALLREDUCE, "HcomAllReduce"); REGISTER_OPTYPE_DEFINE(HCOMREDUCESCATTER, "HcomReduceScatter"); REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend"); REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive"); +REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); +REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); REGISTER_OPTYPE_DEFINE(VARISINITIALIZEDOP, "VarIsInitializedOp"); @@ -479,72 +488,72 @@ const uint64_t ALLOC_MEMORY_MAX_SIZE = 536870912; // Max size of 512M. #endif /// -///@brief Magic number of model file +/// @brief Magic number of model file /// const uint32_t MODEL_FILE_MAGIC_NUM = 0x444F4D49; // magic number /// -///@brief Model head length +/// @brief Model head length /// const uint32_t MODEL_FILE_HEAD_LEN = 256; /// -///@ingroup domi_omg -///@brief Input node type +/// @ingroup domi_omg +/// @brief Input node type /// const std::string INPUT_TYPE = "Input"; /// -///@ingroup domi_omg -///@brief AIPP label, label AIPP conv operator +/// @ingroup domi_omg +/// @brief AIPP label, label AIPP conv operator /// const std::string AIPP_CONV_FLAG = "Aipp_Conv_Flag"; /// -///@ingroup domi_omg -///@brief AIPP label, label aipp data operator +/// @ingroup domi_omg +/// @brief AIPP label, label aipp data operator /// const std::string AIPP_DATA_FLAG = "Aipp_Data_Flag"; /// -///@ingroup domi_omg -///@brief Record the w dimension of model input corresponding to dynamic AIPP +/// @ingroup domi_omg +/// @brief Record the w dimension of model input corresponding to dynamic AIPP /// const std::string AIPP_RELATED_DATA_DIM_W = "aipp_related_data_dim_w"; /// -///@ingroup domi_omg -///@brief Record the H dimension of model input corresponding to dynamic AIPP +/// @ingroup domi_omg +/// @brief Record the H dimension of model input corresponding to dynamic AIPP /// const std::string AIPP_RELATED_DATA_DIM_H = "aipp_related_data_dim_h"; /// -///@ingroup domi_omg -///@brief The tag of the data operator. Mark this input to the dynamic AIPP operator +/// @ingroup domi_omg +/// @brief The tag of the data operator. Mark this input to the dynamic AIPP operator /// const std::string INPUT_TO_DYNAMIC_AIPP = "input_to_dynamic_aipp"; /// -///@ingroup domi_omg -///@brief DATA node type +/// @ingroup domi_omg +/// @brief DATA node type /// const std::string DATA_TYPE = "Data"; /// -///@ingroup domi_omg -///@brief DATA node type +/// @ingroup domi_omg +/// @brief DATA node type /// const std::string AIPP_DATA_TYPE = "AippData"; /// -///@ingroup domi_omg -///@brief Frame operator type +/// @ingroup domi_omg +/// @brief Frame operator type /// const std::string FRAMEWORK_OP_TYPE = "FrameworkOp"; /// -///@ingroup domi_omg -///@brief Data node type +/// @ingroup domi_omg +/// @brief Data node type /// const std::string ANN_DATA_TYPE = "AnnData"; const std::string ANN_NETOUTPUT_TYPE = "AnnNetOutput"; @@ -552,136 +561,139 @@ const std::string ANN_DEPTHCONV_TYPE = "AnnDepthConv"; const std::string ANN_CONV_TYPE = "AnnConvolution"; const std::string ANN_FC_TYPE = "AnnFullConnection"; /// -///@ingroup domi_omg -///@brief Convolution node type +/// @ingroup domi_omg +/// @brief Convolution node type /// const std::string NODE_NAME_NET_OUTPUT = "Node_Output"; const std::string NODE_NAME_END_GRAPH = "Node_EndGraph"; +const std::string NODE_NAME_OP_DEBUG = "Node_OpDebug"; +const std::string OP_TYPE_OP_DEBUG = "Opdebug"; + /// -///@ingroup domi_omg -///@brief Convolution node type +/// @ingroup domi_omg +/// @brief Convolution node type /// const std::string OP_TYPE_CONVOLUTION = "Convolution"; /// -///@ingroup domi_omg -///@brief Add convolution node name to AIPP +/// @ingroup domi_omg +/// @brief Add convolution node name to AIPP /// const std::string AIPP_CONV_OP_NAME = "aipp_conv_op"; /// -///@ingroup domi_omg -///@brief Operator configuration item separator +/// @ingroup domi_omg +/// @brief Operator configuration item separator /// const std::string OP_CONF_DELIMITER = ":"; /// -///@ingroup domi_omg -///@brief attr value name +/// @ingroup domi_omg +/// @brief attr value name /// const std::string ATTR_NAME_VALUE1 = "value1"; /// -///@ingroup domi_omg -///@brief attr value name, 6d_2_4d C +/// @ingroup domi_omg +/// @brief attr value name, 6d_2_4d C /// const std::string ATTR_NAME_INPUT_CVALUE = "input_cvalue"; /// -///@ingroup domi_omg -///@brief alpha default value +/// @ingroup domi_omg +/// @brief alpha default value /// const float ALPHA_DEFAULT_VALUE = 1.0; /// -///@ingroup domi_omg -///@brief beta default value +/// @ingroup domi_omg +/// @brief beta default value /// const float BETA_DEFAULT_VALUE = 0.0; /// -///@ingroup domi_omg -///@brief coef default value +/// @ingroup domi_omg +/// @brief coef default value /// const float COEF_DEFAULT_VALUE = 0.0; /// -///@ingroup domi_omg -///@brief Relu6 coef value +/// @ingroup domi_omg +/// @brief Relu6 coef value /// const float RELU6_COEF = 6.0; /// -///@ingroup domi_omg -///@brief stride default value +/// @ingroup domi_omg +/// @brief stride default value /// const uint32_t STRIDE_DEFAULT_VALUE = 1; /// -///@ingroup domi_omg -///@brief pad default value +/// @ingroup domi_omg +/// @brief pad default value /// const uint32_t PAD_DEFAULT_VALUE = 0; /// -///@ingroup domi_omg -///@brief dilation default value +/// @ingroup domi_omg +/// @brief dilation default value /// const int DILATION_DEFAULT_VALUE = 1; /// -///@ingroup domi_omg -///@brief kernel default value +/// @ingroup domi_omg +/// @brief kernel default value /// const uint32_t KERNEL_DEFAULT_VALUE = 0; /// -///@ingroup domi_omg -///@brief defaule convolution group size +/// @ingroup domi_omg +/// @brief defaule convolution group size /// const uint32_t DEFAULT_CONV_GROUP = 1; /// -///@ingroup domi_omg -///@brief Default deconvolution adj +/// @ingroup domi_omg +/// @brief Default deconvolution adj /// const uint32_t DEFAULT_DECONV_ADJ = 0; /// -///@ingroup domi_omg -///@brief Represents value 1 +/// @ingroup domi_omg +/// @brief Represents value 1 /// const uint32_t NUM_ONE = 1; /// -///@ingroup domi_omg -///@brief spatial dim size default value +/// @ingroup domi_omg +/// @brief spatial dim size default value /// const int32_t SPATIAL_DIM_DEFAULT_SIZE = 2; /// -///@ingroup domi_omg -///@brief dim extended default value +/// @ingroup domi_omg +/// @brief dim extended default value /// const int32_t DIM_DEFAULT_VALUE = 1; /// -///@ingroup domi_omg -///@brief The first weight list in opdef is filter +/// @ingroup domi_omg +/// @brief The first weight list in opdef is filter /// const int32_t WEIGHT_FILTER_INDEX = 0; /// -///@ingroup domi_omg -///@brief The second weight list in opdef is bias +/// @ingroup domi_omg +/// @brief The second weight list in opdef is bias /// const int32_t WEIGHT_BIAS_INDEX = 1; const int32_t TENSOR_ND_SUPPORT_SIZE = 8; /// -///@ingroup domi_omg -///@brief NCHW index default value +/// @ingroup domi_omg +/// @brief NCHW index default value /// const uint32_t NCHW_DIM_N = 0; const uint32_t NCHW_DIM_C = 1; @@ -689,8 +701,8 @@ const uint32_t NCHW_DIM_H = 2; const uint32_t NCHW_DIM_W = 3; /// -///@ingroup domi_omg -///@brief KCHW index default value +/// @ingroup domi_omg +/// @brief KCHW index default value /// const uint32_t KCHW_DIM_K = 0; const uint32_t KCHW_DIM_C = 1; @@ -698,8 +710,8 @@ const uint32_t KCHW_DIM_H = 2; const uint32_t KCHW_DIM_W = 3; /// -///@ingroup domi_omg -///@brief HWCK index default value +/// @ingroup domi_omg +/// @brief HWCK index default value /// const uint32_t HWCK_DIM_H = 0; const uint32_t HWCK_DIM_W = 1; @@ -707,8 +719,8 @@ const uint32_t HWCK_DIM_C = 2; const uint32_t HWCK_DIM_K = 3; /// -///@ingroup domi_omg -///@brief NHWC index default value +/// @ingroup domi_omg +/// @brief NHWC index default value /// const uint32_t NHWC_DIM_N = 0; const uint32_t NHWC_DIM_H = 1; @@ -716,8 +728,8 @@ const uint32_t NHWC_DIM_W = 2; const uint32_t NHWC_DIM_C = 3; /// -///@ingroup domi_omg -///@brief CHWN index default value +/// @ingroup domi_omg +/// @brief CHWN index default value /// const uint32_t CHWN_DIM_N = 3; const uint32_t CHWN_DIM_C = 0; @@ -725,23 +737,23 @@ const uint32_t CHWN_DIM_H = 1; const uint32_t CHWN_DIM_W = 2; /// -///@ingroup domi_omg -///@brief CHW index default value +/// @ingroup domi_omg +/// @brief CHW index default value /// const uint32_t CHW_DIM_C = 0; const uint32_t CHW_DIM_H = 1; const uint32_t CHW_DIM_W = 2; /// -///@ingroup domi_omg -///@brief HWC index default value +/// @ingroup domi_omg +/// @brief HWC index default value /// const uint32_t HWC_DIM_H = 0; const uint32_t HWC_DIM_W = 1; const uint32_t HWC_DIM_C = 2; /// -///@ingroup domi_omg -///@brief Pad index default value +/// @ingroup domi_omg +/// @brief Pad index default value /// const uint32_t PAD_H_HEAD = 0; const uint32_t PAD_H_TAIL = 1; @@ -749,35 +761,35 @@ const uint32_t PAD_W_HEAD = 2; const uint32_t PAD_W_TAIL = 3; /// -///@ingroup domi_omg -///@brief window index default value +/// @ingroup domi_omg +/// @brief window index default value /// const uint32_t WINDOW_H = 0; const uint32_t WINDOW_W = 1; /// -///@ingroup domi_omg -///@brief stride index default value +/// @ingroup domi_omg +/// @brief stride index default value /// const uint32_t STRIDE_H = 0; const uint32_t STRIDE_W = 1; /// -///@ingroup domi_omg -///@brief dilation index default value +/// @ingroup domi_omg +/// @brief dilation index default value /// const uint32_t DILATION_H = 0; const uint32_t DILATION_W = 1; /// -///@ingroup domi_omg -///@brief the num of XRBG channel +/// @ingroup domi_omg +/// @brief the num of XRBG channel /// const uint32_t XRGB_CHN_NUM = 4; /// -///@ingroup domi_omg -///@brief global pooling default value +/// @ingroup domi_omg +/// @brief global pooling default value /// const bool DEFAULT_GLOBAL_POOLING = false; @@ -801,4 +813,4 @@ const uint32_t STREAM_SWITCH_INPUT_NUM = 2; const std::string NODE_NAME_GLOBAL_STEP = "ge_global_step"; const std::string NODE_NAME_GLOBAL_STEP_ASSIGNADD = "global_step_assignadd"; -}; // namespace ge +} // namespace ge diff --git a/src/ge/common/util.cc b/src/ge/common/util.cc index 50ed2f33..a52978af 100644 --- a/src/ge/common/util.cc +++ b/src/ge/common/util.cc @@ -20,12 +20,12 @@ #include #include +#include #include #include #include #include #include -#include #include "external/ge/ge_api_error_codes.h" #include "common/util/error_manager/error_manager.h" @@ -56,6 +56,8 @@ const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M /// The maximum length of the file. /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 const int kMaxFileSizeLimit = INT_MAX; +const int kMaxBuffSize = 256; +const char *const kPathValidReason = "The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; } // namespace namespace ge { @@ -77,7 +79,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"realpath"}, {file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {file, "ifstream is_open failed"}); GELOGE(ge::FAILED, "Open real path[%s] failed.", file); return false; } @@ -90,7 +92,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co fs.close(); if (!ret) { - ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"filepath"}, {file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"file"}, {file}); GELOGE(ge::FAILED, "Parse file[%s] failed.", file); return ret; } @@ -114,17 +116,18 @@ long GetFileLength(const std::string &input_file) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); unsigned long long file_length = 0; - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, - ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"filepath"}, {input_file}); - return -1, "Open file[%s] failed", input_file.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, + ErrorManager::GetInstance().ATCReportErrMessage("E19001", {"file", "errmsg"}, {input_file, strerror(errno)}); + return -1, "Open file[%s] failed. %s", input_file.c_str(), strerror(errno)); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), - ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"filepath"}, {input_file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19015", {"filepath"}, {input_file}); return -1, "File[%s] size is 0, not valid.", input_file.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( - "E10039", {"filepath", "filesize", "maxlen"}, + "E19016", {"filepath", "filesize", "maxlen"}, {input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, kMaxFileSizeLimit); return static_cast(file_length); @@ -219,7 +222,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: if (ret != 0) { if (errno != EEXIST) { ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); return ret; } } @@ -230,7 +233,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: if (ret != 0) { if (errno != EEXIST) { ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); - GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); + GELOGW("Can not create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); return ret; } } @@ -258,16 +261,16 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch "incorrect parameter. nullptr == file || nullptr == message"); std::string real_path = RealPath(file); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), - ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"filepath"}, {file}); - return false, "Get path[%s]'s real path failed", file); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), ErrorManager::GetInstance().ATCReportErrMessage( + "E19000", {"path", "errmsg"}, {file, strerror(errno)}); + return false, "Path[%s]'s realpath is empty, errmsg[%s]", file, strerror(errno)); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); std::ifstream fs(real_path.c_str(), std::ifstream::in); if (!fs.is_open()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"realpth", "protofile"}, {real_path, file}); + ErrorManager::GetInstance().ATCReportErrMessage("E19017", {"realpth", "protofile"}, {real_path, file}); GELOGE(ge::FAILED, "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), file); return false; @@ -275,7 +278,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch google::protobuf::io::IstreamInputStream input(&fs); bool ret = google::protobuf::TextFormat::Parse(&input, message); - GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"protofile"}, {file}); + GE_IF_BOOL_EXEC(!ret, ErrorManager::GetInstance().ATCReportErrMessage("E19018", {"protofile"}, {file}); GELOGE(ret, "Parse file[%s] through [google::protobuf::TextFormat::Parse] failed, " "please check whether the file is a valid protobuf format file.", @@ -339,17 +342,13 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char strlen(path) >= PATH_MAX, ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); return "", "Path[%s] len is too long, it must be less than %d", path, PATH_MAX); - // PATH_MAX is the system's own macro, indicating the maximum file path length supported - std::shared_ptr resolved_path(new (std::nothrow) char[PATH_MAX](), std::default_delete()); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(resolved_path == nullptr, return "", "Path[%s] new string object len[%d] failed.", - path, PATH_MAX); // Nullptr is returned when the path does not exist or there is no permission // Return absolute path when path is accessible std::string res; - if (realpath(path, resolved_path.get()) != nullptr) { - res = resolved_path.get(); + char resolved_path[PATH_MAX] = {0}; + if (realpath(path, resolved_path) != nullptr) { + res = resolved_path; } return res; @@ -360,36 +359,34 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const // The specified path is empty std::map args_map; if (file_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); - GELOGW("Input parameter's value is empty."); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); + GELOGW("Input parameter %s is empty.", file_path.c_str()); return false; } std::string real_path = RealPath(file_path.c_str()); // Unable to get absolute path (does not exist or does not have permission to access) if (real_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); + ErrorManager::GetInstance().ATCReportErrMessage("E19000", {"path", "errmsg"}, {file_path, strerror(errno)}); GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } // A regular matching expression to verify the validity of the input file path // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ - // Path section:Support upper and lower case letters, numbers dots(.) chinese and underscores - // File name section:Support upper and lower case letters, numbers, underscores chinese and dots(.) + // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores + // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( !ValidateStr(real_path, mode), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); - return false, - "Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' " - "and chinese character.", - atc_param.c_str(), real_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, real_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), real_path.c_str(), kPathValidReason); // The absolute path points to a file that is not readable if (access(real_path.c_str(), R_OK) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); - GELOGW("Read path[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"file", "errmsg"}, {file_path.c_str(), strerror(errno)}); + GELOGW("Read file[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); return false; } @@ -400,34 +397,35 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const const std::string &atc_param) { // The specified path is empty if (file_path.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {atc_param}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {atc_param}); GELOGW("Input parameter's value is empty."); return false; } + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + strlen(file_path.c_str()) >= PATH_MAX, ErrorManager::GetInstance().ATCReportErrMessage( + "E19002", {"filepath", "size"}, {file_path, std::to_string(PATH_MAX)}); + return "", "Path[%s] len is too long, it must be less than %d", file_path.c_str(), PATH_MAX); + + // A regular matching expression to verify the validity of the input file path + // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ + // Path section: Support upper and lower case letters, numbers dots(.) chinese and underscores + // File name section: Support upper and lower case letters, numbers, underscores chinese and dots(.) + std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; + + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + !ValidateStr(file_path, mode), + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {atc_param, file_path, kPathValidReason}); + return false, "Invalid value for %s[%s], %s.", atc_param.c_str(), file_path.c_str(), kPathValidReason); + std::string real_path = RealPath(file_path.c_str()); // Can get absolute path (file exists) if (!real_path.empty()) { - // A regular matching expression to verify the validity of the input file path - // ^(/|./|(../)+|)([.]?[\u4e00-\u9fa5A-Za-z0-9_-]+/)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$ - // Path section:Support upper and lower case letters, numbers dots(.) chinese and underscores - // File name section:Support upper and lower case letters, numbers, underscores chinese and dots(.) - std::string mode = "^(/+|./+|(../+)+|)(../|([.]?[\u4e00-\u9fa5A-Za-z0-9_.-]+)/+)*[\u4e00-\u9fa5A-Za-z0-9_+.-]+$"; - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - !ValidateStr(real_path, mode), - ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "path"}, {atc_param, real_path}); - return false, - "Input parameter[--%s]'s value[%s] is illegal. The path can only contains 'a-z' 'A-Z' '0-9' '-' '.' '_' " - "and chinese character.", - atc_param.c_str(), real_path.c_str()); - // File is not readable or writable if (access(real_path.c_str(), W_OK | F_OK) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"realpath", "path", "errmsg"}, - {real_path, file_path, strerror(errno)}); - GELOGW("Write file[%s] failed, input path is %s, errmsg[%s]", real_path.c_str(), file_path.c_str(), - strerror(errno)); + ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"file", "errmsg"}, {real_path, strerror(errno)}); + GELOGW("Write file[%s] failed, errmsg[%s]", real_path.c_str(), strerror(errno)); return false; } } else { @@ -445,8 +443,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string prefix_path = std::string(file_path).substr(0, static_cast(path_split_pos)); // Determine whether the specified path is valid by creating the path if (CreateDirectory(prefix_path) != 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"path"}, {file_path}); - GELOGW("Can not create prefix path for path[%s].", file_path.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {file_path}); + GELOGW("Can not create directory[%s].", file_path.c_str()); return false; } } @@ -456,17 +454,26 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const } FMK_FUNC_HOST_VISIBILITY bool ValidateStr(const std::string &str, const std::string &mode) { -#ifndef OS_CENTOS - std::regex reg(mode); + char ebuff[kMaxBuffSize]; + regex_t reg; + int cflags = REG_EXTENDED | REG_NOSUB; + int ret = regcomp(®, mode.c_str(), cflags); + if (ret) { + regerror(ret, ®, ebuff, kMaxBuffSize); + GELOGE(ge::PARAM_INVALID, "regcomp failed, reason: %s", ebuff); + regfree(®); + return false; + } - // Matching string part - std::smatch match; + ret = regexec(®, str.c_str(), 0, nullptr, 0); + if (ret) { + regerror(ret, ®, ebuff, kMaxBuffSize); + GELOGE(ge::PARAM_INVALID, "regexec failed, reason: %s", ebuff); + regfree(®); + return false; + } - bool res = regex_match(str, match, reg); - res = regex_search(str, std::regex("[`!@#$%^&*()|{}':;',\\[\\]<>?]")); - return !(res) && (str.size() == match.str().size()); -#else + regfree(®); return true; -#endif } } // namespace ge diff --git a/src/ge/engine_manager/dnnengine_manager.cc b/src/ge/engine_manager/dnnengine_manager.cc index c8843c09..ad36ebb5 100644 --- a/src/ge/engine_manager/dnnengine_manager.cc +++ b/src/ge/engine_manager/dnnengine_manager.cc @@ -24,6 +24,7 @@ #include "common/debug/log.h" #include "common/ge/ge_util.h" +#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "init/gelib.h" @@ -161,6 +162,10 @@ bool DNNEngineManager::IsEngineRegistered(const std::string &name) { return false; } +void DNNEngineManager::InitPerformanceStaistic() { checksupport_cost_.clear(); } + +const map &DNNEngineManager::GetCheckSupportCost() const { return checksupport_cost_; } + std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); return ""); @@ -176,13 +181,12 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { GELOGI("DNNEngineManager: Can not get op info by op type %s", op_desc->GetType().c_str()); return ""; } - string ge_core_type; + std::string ge_core_type; Status ret = ge::GetContext().GetOption(ge::CORE_TYPE, ge_core_type); - if (ret != SUCCESS) { - GELOGD("get the option CORE_TYPE fail, set it to default value VECTOR_ENGINE"); - } - string exclude_core_Type = (ge_core_type == kVectorCore) ? kAIcoreEngine : kVectorEngine; + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGD("get the option CORE_TYPE fail, set it to default value VECTOR_ENGINE")); + std::string exclude_core_Type = (ge_core_type == kVectorCore) ? kAIcoreEngine : kVectorEngine; GELOGD("engine type will exclude: %s", exclude_core_Type.c_str()); + std::map unsupported_reasons; for (const auto &it : op_infos) { if (it.engine == exclude_core_Type) { @@ -194,15 +198,20 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { if (kernel_info_store != kernel_map.end()) { std::string unsupported_reason; // It will be replaced by engine' checksupport + uint64_t start_time = GetCurrentTimestap(); if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { + checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; op_desc->SetOpEngineName(it.engine); op_desc->SetOpKernelLibName(kernel_name); - GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), + GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s to op_desc %s", kernel_name.c_str(), it.engine.c_str(), op_desc->GetName().c_str()); return it.engine; } else { + checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; bool is_custom_op = false; if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) { + ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"}, + {kernel_name, op_desc->GetType(), op_desc->GetName()}); GELOGE(FAILED, "The custom operator registered by the user does not support the logic function delivered by this " "network. Check support failed, kernel_name is %s, op type is %s, op name is %s", @@ -212,6 +221,9 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { unsupported_reasons.emplace(kernel_name, unsupported_reason); GELOGI("DNNEngineManager:Check support failed, kernel_name is %s, op type is %s, op name is %s", kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (!op_desc->HasAttr("_is_ge_op")) { + ErrorManager::GetInstance().ATCReportErrMessage("W11001", {"opname"}, {op_desc->GetName()}); + } } } else { GELOGW( @@ -221,9 +233,13 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { } } for (const auto &it : unsupported_reasons) { + ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"}, + {op_desc->GetType(), it.first, it.second}); GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s", op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); } + ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"}, + {op_desc->GetName(), op_desc->GetType()}); GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s", op_desc->GetName().c_str(), op_desc->GetType().c_str()); return ""; @@ -357,7 +373,7 @@ Status DNNEngineManager::ParserEngineMessage(const json engines_json, const std: } Status DNNEngineManager::ReadJsonFile(const std::string &file_path, JsonHandle handle) { - GELOGI("Begin to read json file"); + GELOGD("Begin to read json file"); if (file_path.empty()) { GELOGE(FAILED, "Json path %s is not valid", file_path.c_str()); return FAILED; @@ -384,14 +400,20 @@ Status DNNEngineManager::ReadJsonFile(const std::string &file_path, JsonHandle h return FAILED; } - ifs >> *json_file; + try { + ifs >> *json_file; + } catch (const json::exception &e) { + GELOGE(FAILED, "Read json file failed"); + ifs.close(); + return FAILED; + } ifs.close(); - GELOGI("Read json file success"); + GELOGD("Read json file success"); return SUCCESS; } Status DNNEngineManager::CheckJsonFile() { - GELOGI("Begin to check json file"); + GELOGD("Begin to check json file"); for (auto &it : engines_map_) { std::string engine_name = it.first; int count = 0; @@ -411,7 +433,7 @@ Status DNNEngineManager::CheckJsonFile() { return FAILED; } } - GELOGI("Check json file success"); + GELOGD("Check json file success"); return SUCCESS; } } // namespace ge diff --git a/src/ge/engine_manager/dnnengine_manager.h b/src/ge/engine_manager/dnnengine_manager.h index ab813398..15628ecf 100644 --- a/src/ge/engine_manager/dnnengine_manager.h +++ b/src/ge/engine_manager/dnnengine_manager.h @@ -63,6 +63,8 @@ class DNNEngineManager { // If can't find appropriate engine name, return "", report error string GetDNNEngineName(const OpDescPtr &op_desc); const map &GetSchedulers() const; + const map &GetCheckSupportCost() const; + void InitPerformanceStaistic(); private: DNNEngineManager(); @@ -78,6 +80,7 @@ class DNNEngineManager { std::map engines_map_; std::map engines_attrs_map_; std::map schedulers_; + std::map checksupport_cost_; bool init_flag_; }; } // namespace ge diff --git a/src/ge/executor/CMakeLists.txt b/src/ge/executor/CMakeLists.txt index cddf25b7..1b0b8131 100755 --- a/src/ge/executor/CMakeLists.txt +++ b/src/ge/executor/CMakeLists.txt @@ -26,6 +26,7 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "ge_executor.cc" + "../common/ge/op_tiling_manager.cc" "../common/ge/plugin_manager.cc" "../common/profiling/profiling_manager.cc" "../graph/execute/graph_execute.cc" @@ -58,8 +59,8 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} "../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc" "../graph/load/new_model_manager/task_info/task_info.cc" "../graph/load/new_model_manager/tbe_handle_store.cc" + "../graph/load/new_model_manager/zero_copy_offset.cc" "../graph/load/new_model_manager/zero_copy_task.cc" - "../graph/load/output/output.cc" "../graph/manager/graph_caching_allocator.cc" "../graph/manager/graph_manager_utils.cc" "../graph/manager/graph_mem_allocator.cc" diff --git a/src/ge/executor/ge_executor.cc b/src/ge/executor/ge_executor.cc index b5a3b3cf..ee65faec 100644 --- a/src/ge/executor/ge_executor.cc +++ b/src/ge/executor/ge_executor.cc @@ -36,6 +36,9 @@ #include "mmpa/mmpa_api.h" #include "single_op/single_op_manager.h" +using std::string; +using std::vector; + namespace { const size_t kDynamicBatchSizeVecSize = 1; const size_t kStaticBatchInfoSize = 1; @@ -102,20 +105,36 @@ void SetDynamicInputDataFlag(const ge::RunModelData &input_data, const std::vect ge::InputData &inputs) { inputs.is_dynamic_batch = true; std::string batch_label; + size_t match_idx = 0; for (size_t i = 0; i < batch_info.size(); ++i) { - if (batch_info[i].size() == kDynamicBatchSizeVecSize && - batch_info[i][0] == static_cast(input_data.dynamic_batch_size)) { - batch_label = kBatchLabel + std::to_string(i); - inputs.batch_label = batch_label; + // dynamic_dims + if (input_data.dynamic_dims.size() != 0) { + bool is_match = true; + for (size_t j = 0; j < static_cast(input_data.dynamic_dims.size()); ++j) { + if (static_cast(batch_info[i][j]) != input_data.dynamic_dims[j]) { + is_match = false; + break; + } + } + if (is_match) { + match_idx = i; + break; + } + // dynamic_batch_size + } else if (batch_info[i].size() == kDynamicBatchSizeVecSize && + batch_info[i][0] == static_cast(input_data.dynamic_batch_size)) { + match_idx = i; break; + // dynamic_image_size } else if (batch_info[i].size() == kDynamicImageSizeVecSize && batch_info[i][0] == static_cast(input_data.dynamic_image_height) && batch_info[i][1] == static_cast(input_data.dynamic_image_width)) { - batch_label = kBatchLabel + std::to_string(i); - inputs.batch_label = batch_label; + match_idx = i; break; } } + batch_label = kBatchLabel + std::to_string(match_idx); + inputs.batch_label = batch_label; GELOGI("current batch label:%s", batch_label.c_str()); } @@ -225,39 +244,41 @@ Status GeExecutor::Finalize() { Status GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t batch_size) { if (dynamic_input_addr == nullptr) { - GELOGE(FAILED, "Dynamic input addr is nullptr!"); - return FAILED; + GELOGE(PARAM_INVALID, "Dynamic input addr is nullptr!"); + return PARAM_INVALID; } uint64_t size = sizeof(uint64_t); if (length < size) { - GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, size); - return FAILED; + GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, size); + return PARAM_INVALID; } // Verify whether the input dynamic batch matches the model gear std::vector> batch_info; std::vector batch_num{batch_size}; - Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info); + int32_t dynamic_type = static_cast(FIXED); + Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(FAILED, "Get dynamic input info failed."); - return FAILED; + GELOGE(ret, "Get dynamic input info failed."); + return ret; } if (!IsDynamicBatchSizeMatchModel(batch_size, batch_info)) { - GELOGE(FAILED, "The current dynamic input does not match the gear of the model."); - return FAILED; + GELOGE(PARAM_INVALID, "The current dynamic input does not match the gear of the model."); + return PARAM_INVALID; } - ret = GraphExecutor::SetDynamicSize(model_id, batch_num); + ret = GraphExecutor::SetDynamicSize(model_id, batch_num, static_cast(DYNAMIC_BATCH)); if (ret != SUCCESS) { - GELOGE(FAILED, "Set dynamic size failed"); - return FAILED; + GELOGE(ret, "Set dynamic size failed"); + return ret; } // memcpy dynamic_batch_size from host to device - if (rtMemcpy(dynamic_input_addr, length, &batch_size, size, RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { - GELOGE(FAILED, "memcpy dynamic batch input data failed!"); - return FAILED; + rtError_t rt_ret = rtMemcpy(dynamic_input_addr, length, &batch_size, size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "memcpy dynamic batch input data failed! ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; } @@ -265,40 +286,42 @@ Status GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_ad Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_addr, uint64_t length, uint64_t image_height, uint64_t image_width) { if (dynamic_input_addr == nullptr) { - GELOGE(FAILED, "Dynamic input addr is nullptr!"); - return FAILED; + GELOGE(PARAM_INVALID, "Dynamic input addr is nullptr!"); + return PARAM_INVALID; } uint64_t dynamic_input_size = kDynamicImageSizeInputSize * sizeof(uint64_t); if (length < dynamic_input_size) { - GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); - return FAILED; + GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); + return PARAM_INVALID; } // Verify whether the input dynamic resolution matches the model gear std::vector> batch_info; std::vector batch_num{image_height, image_width}; - Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info); + int32_t dynamic_type = static_cast(FIXED); + Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(FAILED, "Get dynamic input info failed."); - return FAILED; + GELOGE(ret, "Get dynamic input info failed."); + return ret; } if (!IsDynamicImageSizeMatchModel(image_height, image_width, batch_info)) { - GELOGE(FAILED, "The current dynamic input does not match the gear of the model."); - return FAILED; + GELOGE(PARAM_INVALID, "The current dynamic input does not match the gear of the model."); + return PARAM_INVALID; } - ret = GraphExecutor::SetDynamicSize(model_id, batch_num); + ret = GraphExecutor::SetDynamicSize(model_id, batch_num, static_cast(DYNAMIC_IMAGE)); if (ret != SUCCESS) { - GELOGE(FAILED, "Set dynamic size failed"); - return FAILED; + GELOGE(ret, "Set dynamic size failed"); + return ret; } // Memcpy dynamic resolution height from host to device - if (rtMemcpy(dynamic_input_addr, sizeof(uint64_t), &image_height, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE) != - RT_ERROR_NONE) { - GELOGE(FAILED, "memcpy dynamic resolution input data failed!"); - return FAILED; + rtError_t rt_ret = + rtMemcpy(dynamic_input_addr, sizeof(uint64_t), &image_height, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "memcpy dynamic resolution input data failed! ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } uint64_t remain_size = length - sizeof(uint64_t); @@ -311,16 +334,109 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad return SUCCESS; } -Status GeExecutor::GetCurShape(const uint32_t model_id, std::vector &batch_info) { +Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, uint64_t length, + const vector &dynamic_dims) { + if (dynamic_input_addr == nullptr) { + GELOGE(FAILED, "Dynamic input addr is nullptr!"); + return FAILED; + } + + Status ret = GraphExecutor::SetDynamicSize(model_id, dynamic_dims, static_cast(DYNAMIC_DIMS)); + if (ret != SUCCESS) { + GELOGE(FAILED, "Set dynamic size failed"); + return FAILED; + } + + vector cur_dynamic_dims; + if (GetCurDynamicDims(model_id, dynamic_dims, cur_dynamic_dims) != SUCCESS) { + GELOGE(FAILED, "GetCurDynamicDims failed."); + return FAILED; + } + + size_t dynamic_dim_num = cur_dynamic_dims.size(); + uint64_t dynamic_input_size = static_cast(dynamic_dim_num * sizeof(uint64_t)); + if (length < dynamic_input_size) { + GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); + return FAILED; + } + + for (uint32_t i = 0; i < dynamic_dim_num; ++i) { + // Memcpy dynamic dim[i] from host to device + if (rtMemcpy(reinterpret_cast(reinterpret_cast(dynamic_input_addr) + sizeof(uint64_t) * i), + length - sizeof(uint64_t) * i, &cur_dynamic_dims[i], sizeof(uint64_t), + RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { + GELOGE(FAILED, "memcpy dynamic resolution input data failed!"); + return FAILED; + } + } + return SUCCESS; +} + +Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector &combined_dims, + vector &cur_dynamic_dims) { + vector> combined_batch; + if (GraphExecutor::GetCombinedDynamicDims(model_id, combined_batch) != SUCCESS) { + GELOGE(FAILED, "Get combined dynamic dims info failed."); + return FAILED; + } + if (combined_batch.empty()) { + GELOGE(FAILED, "Combined dynamic dims is empty."); + return FAILED; + } + + if (combined_dims.size() != combined_batch[0].size()) { + GELOGE(FAILED, "Input dynamic dims's dimension size[%zu] is different from model[%zu].", combined_dims.size(), + combined_batch[0].size()); + return FAILED; + } + bool matched = false; + size_t idx = 0; + for (size_t i = 0; i < combined_batch.size(); i++) { + bool is_match = true; + for (size_t j = 0; j < combined_dims.size(); j++) { + if (combined_dims[j] != static_cast(combined_batch[i][j])) { + is_match = false; + break; + } + } + if (is_match) { + idx = i; + matched = true; + break; + } + } + + if (!matched) { + GELOGE(FAILED, "Input dynamic dims can not match model."); + return FAILED; + } + + // batch_info save the dynamic info of combined_dims + vector> batch_info; + int32_t dynamic_type = static_cast(FIXED); + if (GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type) != SUCCESS) { + GELOGE(FAILED, "Get dynamic input info failed."); + return FAILED; + } + + cur_dynamic_dims.clear(); + for (size_t i = 0; i < batch_info[idx].size(); i++) { + cur_dynamic_dims.emplace_back(static_cast(batch_info[idx][i])); + } + + return SUCCESS; +} + +Status GeExecutor::GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type) { GELOGI("Begin to get current shape"); if (!isInit_) { GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); return GE_EXEC_NOT_INIT; } - Status ret = GraphExecutor::GetCurShape(model_id, batch_info); + Status ret = GraphExecutor::GetCurShape(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(FAILED, "Get current shape failed"); - return FAILED; + GELOGE(ret, "Get current shape failed"); + return ret; } return SUCCESS; } @@ -330,12 +446,12 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add const kAippDynamicPara &aippParms) { GELOGI("Enter to SetDynamicAippData."); if (dynamic_input_addr == nullptr) { - GELOGE(FAILED, "Dynamic aipp input addr is nullptr!"); - return FAILED; + GELOGE(PARAM_INVALID, "Dynamic aipp input addr is nullptr!"); + return PARAM_INVALID; } if (aippBatchPara.empty()) { - GELOGE(FAILED, "aippBatchPara is empty."); - return FAILED; + GELOGE(PARAM_INVALID, "aippBatchPara is empty."); + return PARAM_INVALID; } uint64_t batch_num = aippBatchPara.size(); uint64_t real_aippParms_size = sizeof(kAippDynamicPara) - sizeof(kAippDynamicBatchPara); @@ -345,24 +461,25 @@ Status GeExecutor::SetDynamicAippData(uint32_t model_id, void *dynamic_input_add "batch num is %lu, struct_len is %lu", model_id, length, batch_num, struct_len); if (struct_len > length) { - GELOGE(FAILED, "input dynamic aipp param len [%lu] is larger than aipp_data size [%lu]", struct_len, length); - return FAILED; + GELOGE(PARAM_INVALID, "input dynamic aipp param len [%lu] is larger than aipp_data size [%lu]", struct_len, length); + return PARAM_INVALID; } // Memcpy real kAippDynamicBatchPara from host to device - if (rtMemcpy(dynamic_input_addr, length, &aippParms, real_aippParms_size, RT_MEMCPY_HOST_TO_DEVICE) != - RT_ERROR_NONE) { - GELOGE(FAILED, "memcpy real_aippParms_size failed!"); - return FAILED; + rtError_t rt_ret = rtMemcpy(dynamic_input_addr, length, &aippParms, real_aippParms_size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "memcpy real_aippParms_size failed! ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } uint64_t remain_len = length - real_aippParms_size; uint8_t *aipp_batch_para_dev = reinterpret_cast(dynamic_input_addr) + real_aippParms_size; for (uint64_t i = 0; i < batch_num; ++i) { - if (rtMemcpy(reinterpret_cast(aipp_batch_para_dev + i * sizeof(kAippDynamicBatchPara)), - (remain_len - i * sizeof(kAippDynamicBatchPara)), &(aippBatchPara[i]), sizeof(kAippDynamicBatchPara), - RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { - GELOGE(FAILED, "memcpy kAippDynamicBatchPara input data failed!"); - return FAILED; + rt_ret = rtMemcpy(reinterpret_cast(aipp_batch_para_dev + i * sizeof(kAippDynamicBatchPara)), + (remain_len - i * sizeof(kAippDynamicBatchPara)), &(aippBatchPara[i]), + sizeof(kAippDynamicBatchPara), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "memcpy kAippDynamicBatchPara input data failed! ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } } return SUCCESS; @@ -429,7 +546,7 @@ Status GeExecutor::UnloadModel(uint32_t model_id) { } Status ret = GraphLoader::DestroyAicpuSessionForInfer(model_id); if (ret != SUCCESS) { - GELOGE(ret, "[GraphLoader] DestroyAicpuSessionForInfer failed."); + GELOGE(ret, "[GraphLoader] DestroyAicpuSessionForInfer failed. model id: %u", model_id); return FAILED; } return GraphLoader::UnloadModel(model_id); @@ -468,17 +585,19 @@ Status GeExecutor::GetModelDescInfo(uint32_t model_id, std::vector> &batch_info) { +Status GeExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type) { GELOGI("Begin to get dynamic batch info."); if (!isInit_) { GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); return GE_EXEC_NOT_INIT; } - Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info); + Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { GELOGE(ret, "GetDynamicBatchInfo failed."); return ret; @@ -513,6 +634,30 @@ Status GeExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info) { + GELOGI("Begin to get combined dynamic dims info."); + if (!isInit_) { + GELOGE(GE_EXEC_NOT_INIT, "GeExecutor has not been initialized!"); + return GE_EXEC_NOT_INIT; + } + + Status ret = GraphExecutor::GetCombinedDynamicDims(model_id, batch_info); + if (ret != SUCCESS) { + GELOGE(ret, "GetCombinedDynamicDims failed."); + return ret; + } + + GELOGI("Get combined dynamic dims succ."); + return SUCCESS; +} + /// /// @ingroup ge /// @brief Get AIPP input format @@ -628,8 +773,8 @@ Status GeExecutor::LoadDataFromFile(const std::string &path, ModelData &model_da string filePath = RealPath(path.c_str()); if (filePath.empty()) { - GELOGE(ge::FAILED, "File path is invalid. please check your text file '%s'.", path.c_str()); - return ge::FAILED; + GELOGE(GE_EXEC_MODEL_PATH_INVALID, "File path is invalid. please check your text file '%s'.", path.c_str()); + return GE_EXEC_MODEL_PATH_INVALID; } GELOGI("load modelData from file: %s.", path.c_str()); std::string key_path; @@ -710,12 +855,20 @@ Status GeExecutor::ExecModel(uint32_t model_id, void *stream, const ge::RunModel GetDomiOutputData(run_output_data, output_data); if ((run_input_data.dynamic_batch_size != 0) || (run_input_data.dynamic_image_width != 0) || - (run_input_data.dynamic_image_height != 0)) { + (run_input_data.dynamic_image_height != 0) || (run_input_data.dynamic_dims.size() != 0)) { std::vector> batch_info; - Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info); + int32_t dynamic_type = static_cast(FIXED); + Status ret = GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { - GELOGE(FAILED, "Get dynamic input info failed."); - return FAILED; + GELOGE(ret, "Get dynamic input info failed."); + return ret; + } + if (dynamic_type == static_cast(DYNAMIC_DIMS)) { + ret = GraphExecutor::GetCombinedDynamicDims(model_id, batch_info); + if (ret != SUCCESS) { + GELOGE(FAILED, "Get dynamic input info failed."); + return FAILED; + } } if (!batch_info.empty()) { SetDynamicInputDataFlag(run_input_data, batch_info, input_data); @@ -790,6 +943,11 @@ Status GeExecutor::LoadSingleOp(const std::string &model_name, const ge::ModelDa return SingleOpManager::GetInstance().GetOpFromModel(model_name, modelData, stream, single_op); } +Status GeExecutor::LoadDynamicSingleOp(const std::string &model_name, const ge::ModelData &modelData, void *stream, + DynamicSingleOp **single_op) { + return SingleOpManager::GetInstance().GetDynamicOpFromModel(model_name, modelData, stream, single_op); +} + Status GeExecutor::ExecuteAsync(SingleOp *executor, const std::vector &inputs, std::vector &outputs) { if (executor == nullptr) { @@ -800,13 +958,21 @@ Status GeExecutor::ExecuteAsync(SingleOp *executor, const std::vectorExecuteAsync(inputs, outputs); } +ge::Status GeExecutor::ExecuteAsync(DynamicSingleOp *executor, const vector &input_desc, + const vector &inputs, vector &output_desc, + vector &outputs) { + GE_CHECK_NOTNULL(executor); + return executor->ExecuteAsync(input_desc, inputs, output_desc, outputs); +} + Status GeExecutor::ReleaseSingleOpResource(void *stream) { return SingleOpManager::GetInstance().ReleaseResource(stream); } Status GeExecutor::GetBatchInfoSize(uint32_t model_id, size_t &shape_count) { std::vector> batch_info; - Status ret = GetDynamicBatchInfo(model_id, batch_info); + int32_t dynamic_type = static_cast(FIXED); + Status ret = GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { GELOGE(ret, "Calc batch info size failed. ret = %d", ret); return ret; @@ -854,5 +1020,4 @@ Status GeExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, GELOGI("GetAllAippInputOutputDims succ."); return SUCCESS; } - } // namespace ge diff --git a/src/ge/executor/module.mk b/src/ge/executor/module.mk index efed8854..b19f3c24 100644 --- a/src/ge/executor/module.mk +++ b/src/ge/executor/module.mk @@ -4,6 +4,7 @@ local_ge_executor_src_files := \ ge_executor.cc \ ../common/profiling/profiling_manager.cc \ ../common/ge/plugin_manager.cc \ + ../common/ge/op_tiling_manager.cc \ ../graph/load/graph_loader.cc \ ../graph/execute/graph_execute.cc \ ../omm/csa_interact.cc \ @@ -25,6 +26,7 @@ local_ge_executor_src_files := \ ../graph/load/new_model_manager/data_inputer.cc \ ../graph/load/new_model_manager/data_dumper.cc \ ../graph/load/new_model_manager/zero_copy_task.cc \ + ../graph/load/new_model_manager/zero_copy_offset.cc \ ../graph/load/new_model_manager/task_info/task_info.cc \ ../graph/load/new_model_manager/task_info/event_record_task_info.cc \ ../graph/load/new_model_manager/task_info/event_wait_task_info.cc \ @@ -44,7 +46,6 @@ local_ge_executor_src_files := \ ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ - ../graph/load/output/output.cc \ ../single_op/single_op_manager.cc \ ../single_op/single_op_model.cc \ ../single_op/single_op.cc \ @@ -53,6 +54,7 @@ local_ge_executor_src_files := \ ../single_op/task/build_task_utils.cc \ ../single_op/task/tbe_task_builder.cc \ ../single_op/task/aicpu_task_builder.cc \ + ../single_op/task/aicpu_kernel_task_builder.cc \ ../hybrid/hybrid_davinci_model_stub.cc\ local_ge_executor_c_include := \ @@ -78,6 +80,7 @@ local_ge_executor_shared_library := \ libslog \ libmmpa \ libgraph \ + libregister \ libmsprof \ local_ge_executor_ldflags := -lrt -ldl \ @@ -127,6 +130,7 @@ LOCAL_SHARED_LIBRARIES := \ libslog \ libmmpa \ libgraph \ + libregister \ libmsprof \ LOCAL_LDFLAGS += $(local_ge_executor_ldflags) @@ -152,6 +156,7 @@ LOCAL_C_INCLUDES := $(local_ge_executor_c_include) LOCAL_STATIC_LIBRARIES := \ libge_common \ libgraph \ + libregister \ libprotobuf \ LOCAL_SHARED_LIBRARIES := \ @@ -183,6 +188,7 @@ LOCAL_C_INCLUDES := $(local_ge_executor_c_include) LOCAL_STATIC_LIBRARIES := \ libge_common \ libgraph \ + libregister \ libprotobuf \ LOCAL_SHARED_LIBRARIES := \ diff --git a/src/ge/ge_inference.mk b/src/ge/ge_inference.mk index 2b26b214..e3e1e10c 100644 --- a/src/ge/ge_inference.mk +++ b/src/ge/ge_inference.mk @@ -32,6 +32,7 @@ COMMON_LOCAL_SRC_FILES := \ GRAPH_MANAGER_LOCAL_SRC_FILES := \ common/ge/plugin_manager.cc\ + common/ge/op_tiling_manager.cc\ init/gelib.cc \ session/inner_session.cc \ session/session_manager.cc \ @@ -45,6 +46,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ graph/execute/graph_execute.cc \ graph/load/graph_loader.cc \ graph/optimize/graph_optimize.cc \ + graph/optimize/mem_rw_conflict_optimize.cc \ graph/optimize/summary_optimize.cc \ graph/build/graph_builder.cc \ graph/partition/engine_place.cc \ @@ -69,6 +71,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/resource_pair_remove_control_pass.cc \ graph/passes/pass_utils.cc \ graph/passes/base_pass.cc \ + graph/passes/bitcast_pass.cc \ graph/passes/constant_folding_pass.cc \ graph/passes/aicpu_constant_folding_pass.cc \ graph/passes/reshape_remove_pass.cc \ @@ -90,7 +93,10 @@ OMG_HOST_SRC_FILES := \ graph/passes/print_op_pass.cc \ graph/passes/no_use_reshape_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ + graph/passes/input_output_connection_identify_pass.cc \ graph/passes/atomic_addr_clean_pass.cc \ + graph/passes/mark_same_addr_pass.cc \ + graph/passes/mark_graph_unknown_status_pass.cc \ graph/common/omg_util.cc \ graph/common/bcast.cc \ graph/passes/dimension_compute_pass.cc \ @@ -105,6 +111,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/isolated_op_remove_pass.cc \ graph/passes/permute_pass.cc \ graph/passes/ctrl_edge_transfer_pass.cc \ + graph/passes/end_of_sequence_add_control_pass.cc \ host_kernels/broadcast_gradient_args_kernel.cc \ host_kernels/greater_kernel.cc \ host_kernels/gather_v2_kernel.cc \ @@ -145,6 +152,7 @@ OMG_HOST_SRC_FILES := \ graph/passes/stop_gradient_pass.cc \ graph/passes/prevent_gradient_pass.cc \ graph/passes/identity_pass.cc \ + graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/placeholder_with_default_pass.cc \ graph/passes/snapshot_pass.cc \ graph/passes/guarantee_const_pass.cc \ @@ -153,7 +161,9 @@ OMG_HOST_SRC_FILES := \ graph/passes/folding_pass.cc \ graph/passes/cast_translate_pass.cc \ graph/passes/prune_pass.cc \ - graph/passes/switch_op_pass.cc \ + graph/passes/merge_to_stream_merge_pass.cc \ + graph/passes/switch_to_stream_switch_pass.cc \ + graph/passes/attach_stream_label_pass.cc \ graph/passes/multi_batch_pass.cc \ graph/passes/next_iteration_pass.cc \ graph/passes/control_trigger_pass.cc \ @@ -173,7 +183,6 @@ OMG_HOST_SRC_FILES := \ graph/passes/variable_op_pass.cc \ graph/passes/cast_remove_pass.cc \ graph/passes/transpose_transdata_pass.cc \ - graph/passes/identify_reference_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ graph/passes/flow_ctrl_pass.cc \ graph/passes/link_gen_mask_nodes_pass.cc \ @@ -181,6 +190,8 @@ OMG_HOST_SRC_FILES := \ graph/passes/hccl_group_pass.cc \ graph/passes/switch_fusion_pass.cc \ graph/passes/switch_split_pass.cc \ + graph/passes/memcpy_addr_async_pass.cc \ + graph/passes/set_input_output_offset_pass.cc \ OMG_DEVICE_SRC_FILES := $(OMG_HOST_SRC_FILES) @@ -199,7 +210,7 @@ OME_HOST_SRC_FILES := \ graph/load/new_model_manager/tbe_handle_store.cc \ graph/load/new_model_manager/cpu_queue_schedule.cc \ graph/load/new_model_manager/zero_copy_task.cc \ - graph/load/output/output.cc \ + graph/load/new_model_manager/zero_copy_offset.cc \ graph/load/new_model_manager/data_dumper.cc \ graph/load/new_model_manager/task_info/task_info.cc \ graph/load/new_model_manager/task_info/event_record_task_info.cc \ @@ -224,6 +235,7 @@ OME_HOST_SRC_FILES := \ single_op/task/build_task_utils.cc \ single_op/task/tbe_task_builder.cc \ single_op/task/aicpu_task_builder.cc \ + single_op/task/aicpu_kernel_task_builder.cc \ single_op/single_op.cc \ single_op/single_op_model.cc \ single_op/stream_resource.cc \ @@ -368,7 +380,7 @@ endif LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) -LOCAL_SRC_FILES := ../../out/atc/lib64/stub/ge_ir_build.cc +LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_ir_build.cc LOCAL_SHARED_LIBRARIES := diff --git a/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc index 0f33ae2a..badca5a3 100644 --- a/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc +++ b/src/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc @@ -61,5 +61,6 @@ REGISTER_OP_CREATOR(SwitchN, GeDeletedOp); REGISTER_OP_CREATOR(RefMerge, GeDeletedOp); REGISTER_OP_CREATOR(RefSwitch, GeDeletedOp); REGISTER_OP_CREATOR(TransShape, GeDeletedOp); +REGISTER_OP_CREATOR(Bitcast, GeDeletedOp); } // namespace ge_local } // namespace ge diff --git a/src/ge/ge_runner.mk b/src/ge/ge_runner.mk index a9cfdd82..a3119b50 100644 --- a/src/ge/ge_runner.mk +++ b/src/ge/ge_runner.mk @@ -23,6 +23,7 @@ LIBGE_LOCAL_SRC_FILES := \ common/formats/utils/formats_trans_utils.cc \ common/fp16_t.cc \ common/ge/plugin_manager.cc\ + common/ge/op_tiling_manager.cc\ common/helper/model_cache_helper.cc \ common/profiling/profiling_manager.cc \ engine_manager/dnnengine_manager.cc \ @@ -77,7 +78,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/load/new_model_manager/task_info/task_info.cc \ graph/load/new_model_manager/tbe_handle_store.cc \ graph/load/new_model_manager/zero_copy_task.cc \ - graph/load/output/output.cc \ + graph/load/new_model_manager/zero_copy_offset.cc \ graph/manager/graph_context.cc \ graph/manager/graph_manager.cc \ graph/manager/graph_manager_utils.cc \ @@ -91,6 +92,7 @@ LIBGE_LOCAL_SRC_FILES := \ graph/manager/util/rt_context_util.cc \ graph/manager/util/variable_accelerate_ctrl.cc \ graph/optimize/graph_optimize.cc \ + graph/optimize/mem_rw_conflict_optimize.cc \ graph/optimize/optimizer/allreduce_fusion_pass.cc \ graph/optimize/summary_optimize.cc \ graph/partition/engine_place.cc \ @@ -98,9 +100,13 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/addn_pass.cc \ graph/passes/aicpu_constant_folding_pass.cc \ graph/passes/assert_pass.cc \ + graph/passes/input_output_connection_identify_pass.cc \ graph/passes/atomic_addr_clean_pass.cc \ + graph/passes/mark_same_addr_pass.cc \ + graph/passes/mark_graph_unknown_status_pass.cc \ graph/partition/dynamic_shape_partition.cc \ graph/passes/base_pass.cc \ + graph/passes/bitcast_pass.cc \ graph/passes/cast_remove_pass.cc \ graph/passes/cast_translate_pass.cc \ graph/passes/common_subexpression_elimination_pass.cc \ @@ -158,8 +164,8 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/get_original_format_pass.cc \ graph/passes/guarantee_const_pass.cc \ graph/passes/hccl_memcpy_pass.cc \ - graph/passes/identify_reference_pass.cc \ graph/passes/identity_pass.cc \ + graph/passes/ref_identity_delete_op_pass.cc \ graph/passes/infershape_pass.cc \ graph/passes/isolated_op_remove_pass.cc \ graph/passes/iterator_op_pass.cc \ @@ -191,7 +197,9 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/data_pass.cc \ graph/passes/switch_data_edges_bypass.cc \ graph/passes/switch_logic_remove_pass.cc \ - graph/passes/switch_op_pass.cc \ + graph/passes/merge_to_stream_merge_pass.cc \ + graph/passes/switch_to_stream_switch_pass.cc \ + graph/passes/attach_stream_label_pass.cc \ graph/passes/switch_dead_branch_elimination.cc \ graph/passes/replace_transshape_pass.cc \ graph/passes/transop_breadth_fusion_pass.cc \ @@ -211,6 +219,9 @@ LIBGE_LOCAL_SRC_FILES := \ graph/passes/variable_prepare_op_pass.cc \ graph/passes/variable_ref_delete_op_pass.cc \ graph/passes/variable_ref_useless_control_out_delete_pass.cc \ + graph/passes/end_of_sequence_add_control_pass.cc \ + graph/passes/memcpy_addr_async_pass.cc \ + graph/passes/set_input_output_offset_pass.cc \ graph/preprocess/graph_preprocess.cc \ graph/preprocess/insert_op/ge_aipp_op.cc \ graph/preprocess/insert_op/util_insert_aipp_op.cc \ @@ -230,6 +241,7 @@ LIBGE_LOCAL_SRC_FILES := \ single_op/task/op_task.cc \ single_op/task/tbe_task_builder.cc \ single_op/task/aicpu_task_builder.cc \ + single_op/task/aicpu_kernel_task_builder.cc \ hybrid/common/tensor_value.cc \ hybrid/common/npu_memory_allocator.cc \ hybrid/executor/rt_callback_manager.cc \ @@ -239,12 +251,15 @@ LIBGE_LOCAL_SRC_FILES := \ hybrid/executor/hybrid_model_executor.cc \ hybrid/executor/hybrid_model_async_executor.cc \ hybrid/executor/hybrid_execution_context.cc \ + hybrid/executor/subgraph_context.cc \ + hybrid/executor/subgraph_executor.cc \ hybrid/executor/worker/task_compile_engine.cc \ hybrid/executor/worker/shape_inference_engine.cc \ hybrid/executor/worker/execution_engine.cc \ hybrid/model/hybrid_model.cc \ hybrid/model/hybrid_model_builder.cc \ hybrid/model/node_item.cc \ + hybrid/model/graph_item.cc \ hybrid/node_executor/aicore/aicore_node_executor.cc \ hybrid/node_executor/aicore/aicore_op_task.cc \ hybrid/node_executor/aicore/aicore_task_builder.cc \ @@ -253,6 +268,9 @@ LIBGE_LOCAL_SRC_FILES := \ hybrid/node_executor/aicpu/aicpu_node_executor.cc \ hybrid/node_executor/compiledsubgraph/known_node_executor.cc \ hybrid/node_executor/hostcpu/ge_local_node_executor.cc \ + hybrid/node_executor/controlop/control_op_executor.cc \ + hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc \ + hybrid/node_executor/hccl/hccl_node_executor.cc \ hybrid/node_executor/node_executor.cc \ hybrid/node_executor/task_context.cc \ hybrid/hybrid_davinci_model.cc \ @@ -338,6 +356,28 @@ LOCAL_SHARED_LIBRARIES += \ include $(BUILD_HOST_SHARED_LIBRARY) +#compiler for GeRunner +include $(CLEAR_VARS) + +LOCAL_MODULE := stub/libge_runner + +LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 +LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD +ifeq ($(DEBUG), 1) +LOCAL_CFLAGS += -g -O0 +endif + + +LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) + +LOCAL_SRC_FILES := ../../out/ge/lib64/stub/ge_api.cc + + +LOCAL_SHARED_LIBRARIES := + +LOCAL_LDFLAGS := -lrt -ldl + +include $(BUILD_HOST_SHARED_LIBRARY) # add engine_conf.json to host include $(CLEAR_VARS) @@ -407,6 +447,7 @@ LOCAL_CFLAGS += -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -DDAVINCI_CLOUD LOCAL_CFLAGS += -g -O0 LOCAL_C_INCLUDES := $(RUNNER_LOCAL_C_INCLUDES) + LOCAL_SRC_FILES := $(LIBGE_LOCAL_SRC_FILES) LOCAL_SRC_FILES += $(LIBCLIENT_LOCAL_SRC_FILES) diff --git a/src/ge/ge_runtime/model_runner.cc b/src/ge/ge_runtime/model_runner.cc index 59952e39..9961ab4e 100644 --- a/src/ge/ge_runtime/model_runner.cc +++ b/src/ge/ge_runtime/model_runner.cc @@ -49,6 +49,24 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint return true; } +bool ModelRunner::DistributeTask(uint32_t model_id) { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + return false; + } + return model_iter->second->DistributeTask(); +} + +bool ModelRunner::LoadModelComplete(uint32_t model_id) { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + return false; + } + return model_iter->second->LoadComplete(); +} + const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const { auto model_iter = runtime_models_.find(model_id); if (model_iter == runtime_models_.end()) { @@ -60,6 +78,38 @@ const std::vector &ModelRunner::GetTaskIdList(uint32_t model_id) const return model_iter->second->GetTaskIdList(); } +const std::vector &ModelRunner::GetStreamIdList(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGE(PARAM_INVALID, "Model id %u not found.", model_id); + static const std::vector empty_ret; + return empty_ret; + } + + return model_iter->second->GetStreamIdList(); +} + +const std::map> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGW("Model id %u not found.", model_id); + static const std::map> empty_ret; + return empty_ret; + } + + return model_iter->second->GetRuntimeInfoMap(); +} + +void *ModelRunner::GetModelHandle(uint32_t model_id) const { + auto model_iter = runtime_models_.find(model_id); + if (model_iter == runtime_models_.end()) { + GELOGW("Model id %u not found.", model_id); + return nullptr; + } + + return model_iter->second->GetModelHandle(); +} + bool ModelRunner::UnloadModel(uint32_t model_id) { auto iter = runtime_models_.find(model_id); if (iter != runtime_models_.end()) { diff --git a/src/ge/ge_runtime/output.cc b/src/ge/ge_runtime/output.cc index 90c33bb4..5153f688 100644 --- a/src/ge/ge_runtime/output.cc +++ b/src/ge/ge_runtime/output.cc @@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde DataBuffer data_buf = rslt->blobs[data_begin + data_count]; bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); if (!ret) { - GELOGE(FAILED, "Copy data to host failed. index: %lu, addr: %p", i, v_input_data_addr_[i]); + GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); return ret; } data_index = data_begin + data_count; diff --git a/src/ge/ge_runtime/runtime_model.cc b/src/ge/ge_runtime/runtime_model.cc index c89ced91..f0405056 100644 --- a/src/ge/ge_runtime/runtime_model.cc +++ b/src/ge/ge_runtime/runtime_model.cc @@ -28,7 +28,6 @@ namespace ge { namespace model_runner { - RuntimeModel::~RuntimeModel() { GELOGI("RuntimeModel destructor start"); @@ -116,23 +115,34 @@ bool RuntimeModel::InitEvent(uint32_t event_num) { return true; } -bool RuntimeModel::InitLabel(uint32_t batch_num) { - GELOGI("batch number:%u.", batch_num); - for (uint32_t i = 0; (batch_num != 0 && i <= batch_num); ++i) { - rtLabel_t rt_lLabel = nullptr; - rtError_t rt_ret = rtLabelCreate(&rt_lLabel); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, i; %u; ret: 0x%X", i, rt_ret); - return false; +bool RuntimeModel::InitLabel(std::shared_ptr &davinci_model) { + GELOGI("batch number:%u.", davinci_model->GetBatchNum()); + label_list_.resize(davinci_model->GetBatchNum()); + for (auto &task_info : davinci_model->GetTaskInfoList()) { + if (task_info == nullptr) { + GELOGE(PARAM_INVALID, "task_info is null."); + continue; } - if (rt_lLabel == nullptr) { - GELOGE(RT_FAILED, "rtLabel is nullptr!"); + if (task_info->type() != TaskInfoType::LABEL_SET) { + continue; + } + auto label_set_task_info = std::static_pointer_cast(task_info); + + if (label_set_task_info->stream_id() >= stream_list_.size()) { + GELOGE(PARAM_INVALID, "Invalid stream id."); return false; } - label_list_.emplace_back(rt_lLabel); + rtLabel_t rt_label = nullptr; + rtError_t rt_ret = rtLabelCreateEx(&rt_label, stream_list_[label_set_task_info->stream_id()]); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api rtLabelCreate failed, ret: 0x%X", rt_ret); + return false; + } + label_list_[label_set_task_info->label_id()] = rt_label; } + return true; } @@ -164,7 +174,7 @@ bool RuntimeModel::InitResource(std::shared_ptr &davinci_model) { return false; } - if (!InitLabel(davinci_model->GetBatchNum())) { + if (!InitLabel(davinci_model)) { return false; } @@ -209,20 +219,41 @@ bool RuntimeModel::LoadTask() { return false; } task_id_list_.push_back(task_id); + stream_id_list_.push_back(stream_id); + if (task->Args() != nullptr) { + std::shared_ptr runtime_tuple = nullptr; + GE_MAKE_SHARED(runtime_tuple = std::make_shared(task_id, stream_id, task->Args()), return false); + auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple); + if (!emplace_ret.second) { + GELOGW("Task name exist:%s", task->task_name().c_str()); + } + } } if (task_list_.empty()) { GELOGE(FAILED, "Task list is empty"); return false; } - GELOGI("Distribute task succ."); - auto rt_ret = rtModelLoadComplete(rt_model_handle_); + GELOGI("LoadTask succ."); + return true; +} + +bool RuntimeModel::LoadComplete() { + uint32_t task_id = 0; + uint32_t stream_id = 0; + auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret); + return RT_FAILED; + } + task_id_list_.push_back(task_id); + stream_id_list_.push_back(stream_id); + + rt_ret = rtModelLoadComplete(rt_model_handle_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); return false; } - - GELOGI("LoadTask succ."); return true; } @@ -252,14 +283,16 @@ bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr } GenerateTask(device_id, session_id, davinci_model); + return status; +} - status = LoadTask(); +bool RuntimeModel::DistributeTask() { + bool status = LoadTask(); if (!status) { GELOGE(FAILED, "DistributeTask failed"); - return status; + return false; } - - return status; + return true; } bool RuntimeModel::Run() { @@ -270,10 +303,14 @@ bool RuntimeModel::Run() { return false; } - GELOGI("Run rtModelExecute success"); + GELOGI("Run rtModelExecute success, ret = 0x%X", ret); ret = rtStreamSynchronize(rt_model_stream_); if (ret != RT_ERROR_NONE) { + if (ret == RT_ERROR_END_OF_SEQUENCE) { + GELOGI("Model stream RT_ERROR_END_OF_SEQUENCE signal received, ret = 0x%X", ret); + return true; + } GELOGE(RT_FAILED, "Model stream sync failed, ret = 0x%X", ret); return false; } @@ -433,7 +470,7 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model } if (constant->output_tensors[0].size < constant->weight_data.size()) { - GELOGE(PARAM_INVALID, "Output size:%u is less than weight data size:%zu", constant->output_tensors[0].size, + GELOGE(PARAM_INVALID, "Output size:%u less than weight data size:%zu", constant->output_tensors[0].size, constant->weight_data.size()); return false; } @@ -448,11 +485,8 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr &davinci_model /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero /// and that of unknown shape is zero too. /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. - int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); - if (elem_num == 0 && constant->weight_tensors[0].size == 0) { - elem_num = 1; - } - + int64_t elem_num = + (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); if (constant->weight_data.size() < sizeof(uint64_t)) { GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); return false; @@ -495,5 +529,6 @@ void RuntimeModel::CreateOutput(uint32_t index, const OpInfo &op_info, InputOutp const std::vector &RuntimeModel::GetTaskIdList() const { return task_id_list_; } +const std::vector &RuntimeModel::GetStreamIdList() const { return stream_id_list_; } } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/runtime_model.h b/src/ge/ge_runtime/runtime_model.h index e8ff4057..d0c466d4 100644 --- a/src/ge/ge_runtime/runtime_model.h +++ b/src/ge/ge_runtime/runtime_model.h @@ -27,7 +27,7 @@ namespace ge { namespace model_runner { - +using RuntimeInfo = std::tuple; class Task; class RuntimeModel { public: @@ -35,7 +35,12 @@ class RuntimeModel { ~RuntimeModel(); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr &davinci_model); + bool DistributeTask(); + bool LoadComplete(); const std::vector &GetTaskIdList() const; + const std::vector &GetStreamIdList() const; + const std::map> &GetRuntimeInfoMap() const { return runtime_info_map_; } + rtModel_t GetModelHandle() const { return rt_model_handle_; } bool Run(); bool CopyInputData(const InputData &input_data); bool GetInputOutputDescInfo(bool zero_copy, std::vector *input_desc, @@ -48,7 +53,7 @@ class RuntimeModel { bool LoadTask(); bool InitStream(std::shared_ptr &davinci_model); bool InitEvent(uint32_t event_num); - bool InitLabel(uint32_t batch_num); + bool InitLabel(std::shared_ptr &davinci_model); bool InitDataInfo(std::shared_ptr &davinci_model); bool InitOutputInfo(std::shared_ptr &davinci_model); bool InitConstantInfo(std::shared_ptr &davinci_model); @@ -77,6 +82,8 @@ class RuntimeModel { std::vector> constant_info_list_{}; std::vector task_id_list_{}; + std::vector stream_id_list_{}; + std::map> runtime_info_map_; }; } // namespace model_runner diff --git a/src/ge/ge_runtime/task/aicpu_task.cc b/src/ge/ge_runtime/task/aicpu_task.cc index 4cb71866..9b126ec0 100644 --- a/src/ge/ge_runtime/task/aicpu_task.cc +++ b/src/ge/ge_runtime/task/aicpu_task.cc @@ -85,11 +85,15 @@ bool AicpuTask::Distribute() { return false; } - GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size, - io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data()); - rt_ret = rtCpuKernelLaunch(reinterpret_cast(task_info_->so_name().data()), - reinterpret_cast(task_info_->kernel_name().data()), 1, args_, args_size, - nullptr, stream_); + input_output_addr_ = reinterpret_cast(reinterpret_cast(args_) + io_addr_offset); + + auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; + GELOGI( + "Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.", + args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag); + rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(task_info_->so_name().data()), + reinterpret_cast(task_info_->kernel_name().data()), 1, args_, + args_size, nullptr, stream_, dump_flag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; diff --git a/src/ge/ge_runtime/task/aicpu_task.h b/src/ge/ge_runtime/task/aicpu_task.h index f5cdc617..cc21af8a 100644 --- a/src/ge/ge_runtime/task/aicpu_task.h +++ b/src/ge/ge_runtime/task/aicpu_task.h @@ -18,6 +18,7 @@ #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ #include +#include #include "ge_runtime/task/task.h" namespace ge { @@ -30,12 +31,17 @@ class AicpuTask : public TaskRepeater { bool Distribute() override; + void *Args() override { return input_output_addr_; } + + std::string task_name() const override { return task_info_->op_name(); } + private: static void ReleaseRtMem(void **ptr) noexcept; std::shared_ptr task_info_; void *stream_; void *args_; + void *input_output_addr_; }; } // namespace model_runner } // namespace ge diff --git a/src/ge/ge_runtime/task/hccl_task.cc b/src/ge/ge_runtime/task/hccl_task.cc index 54ae3bf3..3d5f8504 100644 --- a/src/ge/ge_runtime/task/hccl_task.cc +++ b/src/ge/ge_runtime/task/hccl_task.cc @@ -115,7 +115,6 @@ bool HcclTask::Distribute() { rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - (void)rtStreamDestroy(stream); return false; } diff --git a/src/ge/ge_runtime/task/label_goto_task.cc b/src/ge/ge_runtime/task/label_goto_task.cc new file mode 100644 index 00000000..d357accb --- /dev/null +++ b/src/ge/ge_runtime/task/label_goto_task.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_goto_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelGotoTask::LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGW("Stream/Label id invalid."); + return; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelGotoTask::~LabelGotoTask() {} + +bool LabelGotoTask::Distribute() { + GELOGI("LabelGotoTask Distribute start."); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); + return false; + } + rtError_t rt_ret = rtLabelGotoEx(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_GOTO, LabelGotoTask, LabelGotoTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/src/ge/ge_runtime/task/label_goto_task.h b/src/ge/ge_runtime/task/label_goto_task.h new file mode 100644 index 00000000..4fd6d1bc --- /dev/null +++ b/src/ge/ge_runtime/task/label_goto_task.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelGotoTask : public TaskRepeater { + public: + LabelGotoTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelGotoTask() override; + + bool Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_GOTO_TASK_H_ diff --git a/src/ge/ge_runtime/task/label_set_task.cc b/src/ge/ge_runtime/task/label_set_task.cc new file mode 100644 index 00000000..3ab5802c --- /dev/null +++ b/src/ge/ge_runtime/task/label_set_task.cc @@ -0,0 +1,70 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_set_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelSetTask::LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + label_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + auto stream_list = model_context.stream_list(); + auto label_list = model_context.label_list(); + uint32_t stream_id = task_info->stream_id(); + uint32_t label_id = task_info->label_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + GELOGI("Label list size:%zu, label id:%u.", label_list.size(), label_id); + if (stream_id >= stream_list.size() || label_id >= label_list.size()) { + GELOGW("Stream/Label id invalid."); + return; + } + stream_ = stream_list[stream_id]; + label_ = label_list[label_id]; +} + +LabelSetTask::~LabelSetTask() {} + +bool LabelSetTask::Distribute() { + GELOGI("LabelSetTask Distribute start."); + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + if (label_ == nullptr) { + GELOGE(PARAM_INVALID, "label is null!"); + return false; + } + rtError_t rt_ret = rtLabelSet(label_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_SET, LabelSetTask, LabelSetTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/src/ge/ge_runtime/task/label_set_task.h b/src/ge/ge_runtime/task/label_set_task.h new file mode 100644 index 00000000..70bf1584 --- /dev/null +++ b/src/ge/ge_runtime/task/label_set_task.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelSetTask : public TaskRepeater { + public: + LabelSetTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSetTask() override; + + bool Distribute() override; + + private: + std::shared_ptr task_info_; + void *stream_; + void *label_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_SET_TASK_H_ diff --git a/src/ge/ge_runtime/task/label_switch_task.cc b/src/ge/ge_runtime/task/label_switch_task.cc new file mode 100644 index 00000000..a3c2d41a --- /dev/null +++ b/src/ge/ge_runtime/task/label_switch_task.cc @@ -0,0 +1,131 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ge_runtime/task/label_switch_task.h" +#include "ge_runtime/task/task_factory.h" + +namespace ge { +namespace model_runner { +LabelSwitchTask::LabelSwitchTask(const ModelContext &model_context, + const std::shared_ptr &task_info) + : TaskRepeater(model_context, task_info), + task_info_(task_info), + stream_(nullptr), + all_label_resource_(), + label_info_(nullptr) { + if (task_info_ == nullptr) { + GELOGW("task_info_ is null!"); + return; + } + + all_label_resource_ = model_context.label_list(); + auto stream_list = model_context.stream_list(); + uint32_t stream_id = task_info->stream_id(); + GELOGI("Stream list size:%zu, stream id:%u.", stream_list.size(), stream_id); + if (stream_id >= stream_list.size()) { + GELOGW("Stream id invalid."); + return; + } + stream_ = stream_list[stream_id]; +} + +LabelSwitchTask::~LabelSwitchTask() { + if (label_info_ != nullptr) { + rtError_t rt_ret = rtFree(label_info_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtFree fwkOpBuf failed! ret: 0x%X.", rt_ret); + } + label_info_ = nullptr; + } +} + +bool LabelSwitchTask::Distribute() { + GELOGI("LabelSwitchTask Distribute start."); + if (!CheckParamValid()) { + return false; + } + + const std::vector &label_index_list = task_info_->label_list(); + std::vector label_list(task_info_->label_size(), nullptr); + + for (size_t i = 0; i < task_info_->label_size(); ++i) { + uint32_t label_index = label_index_list[i]; + if (label_index >= all_label_resource_.size()) { + GELOGE(PARAM_INVALID, "label %zu index is %u, but there are %zu labels in total.", i, label_index, + all_label_resource_.size()); + return false; + } + label_list[i] = all_label_resource_[label_index]; + GELOGI("Case %zu: label id %zu.", i, label_index); + } + + uint32_t label_info_size = sizeof(rtLabelDevInfo) * task_info_->label_size(); + rtError_t rt_ret = rtMalloc(&label_info_, label_info_size, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + rt_ret = rtLabelListCpy(label_list.data(), label_list.size(), label_info_, label_info_size); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + rt_ret = rtLabelSwitchByIndex(task_info_->cond(), label_list.size(), label_info_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + + GELOGI("DistributeTask end."); + return true; +} + +bool LabelSwitchTask::CheckParamValid() { + if (stream_ == nullptr) { + GELOGE(PARAM_INVALID, "stream is null!"); + return false; + } + + if (task_info_->label_list().empty()) { + GELOGE(PARAM_INVALID, "label_list is empty."); + return false; + } + + if (task_info_->label_size() != task_info_->label_list().size()) { + GELOGE(PARAM_INVALID, "label_list size %zu but label_size is %u.", task_info_->label_list().size(), + task_info_->label_size()); + return false; + } + + if (task_info_->label_size() >= UINT32_MAX / sizeof(rtLabelDevInfo)) { + GELOGE(PARAM_INVALID, "label_size %u will overflow.", task_info_->label_size()); + return false; + } + + if (label_info_ != nullptr) { + GELOGE(PARAM_INVALID, "label_info_ has dirty data."); + return false; + } + + return true; +} + +REGISTER_TASK(TaskInfoType::LABEL_SWITCH, LabelSwitchTask, LabelSwitchTaskInfo); + +} // namespace model_runner +} // namespace ge diff --git a/src/ge/ge_runtime/task/label_switch_task.h b/src/ge/ge_runtime/task/label_switch_task.h new file mode 100644 index 00000000..463faa31 --- /dev/null +++ b/src/ge/ge_runtime/task/label_switch_task.h @@ -0,0 +1,44 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ +#define GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ + +#include +#include "ge_runtime/task/task.h" + +namespace ge { +namespace model_runner { +class LabelSwitchTask : public TaskRepeater { + public: + LabelSwitchTask(const ModelContext &model_context, const std::shared_ptr &task_info); + + ~LabelSwitchTask() override; + + bool Distribute() override; + + private: + bool CheckParamValid(); + + std::shared_ptr task_info_; + void *stream_; + std::vector all_label_resource_; + void *label_info_; +}; +} // namespace model_runner +} // namespace ge + +#endif // GE_GE_RUNTIME_TASK_LABEL_SWITCH_TASK_H_ diff --git a/src/ge/ge_runtime/task/stream_switch_task.cc b/src/ge/ge_runtime/task/stream_switch_task.cc index 91141139..2adcb4bd 100644 --- a/src/ge/ge_runtime/task/stream_switch_task.cc +++ b/src/ge/ge_runtime/task/stream_switch_task.cc @@ -51,7 +51,7 @@ bool StreamSwitchTask::Distribute() { } if (static_cast(task_info_->true_stream_id()) >= stream_list_.size()) { - GELOGE(PARAM_INVALID, "true_stream_id %ld must be less than stream_list_ size %zu!", task_info_->true_stream_id(), + GELOGE(PARAM_INVALID, "true_stream_id %ld must less than stream_list_ size %zu!", task_info_->true_stream_id(), stream_list_.size()); return false; } diff --git a/src/ge/ge_runtime/task/task.h b/src/ge/ge_runtime/task/task.h index 7c748a7d..6c4df248 100644 --- a/src/ge/ge_runtime/task/task.h +++ b/src/ge/ge_runtime/task/task.h @@ -18,7 +18,9 @@ #define GE_GE_RUNTIME_TASK_TASK_H_ #include +#include #include +#include #include "runtime/rt_model.h" #include "ge_runtime/model_context.h" #include "ge_runtime/task_info.h" @@ -32,6 +34,10 @@ class Task { virtual ~Task() {} virtual bool Distribute() = 0; + + virtual void *Args() { return nullptr; } + + virtual std::string task_name() const { return ""; } }; template diff --git a/src/ge/ge_runtime/task/tbe_task.cc b/src/ge/ge_runtime/task/tbe_task.cc index 8a3c36a4..e7025ae8 100644 --- a/src/ge/ge_runtime/task/tbe_task.cc +++ b/src/ge/ge_runtime/task/tbe_task.cc @@ -95,15 +95,14 @@ bool TbeTask::Distribute() { return false; } - GELOGI("InitTbeTask end."); GELOGI("DistributeTbeTask start."); - rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_); + auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT; + rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); return false; } - - GELOGI("DistributeTbeTask end."); + GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag); return true; } diff --git a/src/ge/ge_runtime/task/tbe_task.h b/src/ge/ge_runtime/task/tbe_task.h index 994ba5e2..a8ce6268 100644 --- a/src/ge/ge_runtime/task/tbe_task.h +++ b/src/ge/ge_runtime/task/tbe_task.h @@ -30,6 +30,10 @@ class TbeTask : public TaskRepeater { bool Distribute() override; + void *Args() override { return args_; } + + std::string task_name() const override { return task_info_->op_name(); } + private: std::shared_ptr task_info_; void *stream_; diff --git a/src/ge/ge_train.mk b/src/ge/ge_train.mk deleted file mode 100644 index 767ce86b..00000000 --- a/src/ge/ge_train.mk +++ /dev/null @@ -1,333 +0,0 @@ -LOCAL_PATH := $(call my-dir) - -COMMON_LOCAL_SRC_FILES := \ - proto/fusion_model.proto \ - proto/optimizer_priority.proto \ - session/inner_session.cc \ - session/session_manager.cc \ - common/ge/plugin_manager.cc\ - common/fp16_t.cc \ - common/formats/utils/formats_trans_utils.cc \ - common/formats/format_transfers/datatype_transfer.cc \ - common/formats/format_transfers/format_transfer_transpose.cc \ - common/formats/format_transfers/format_transfer_nchw_nc1hwc0.cc \ - common/formats/format_transfers/format_transfer_fractal_z.cc \ - common/formats/format_transfers/format_transfer_fractal_nz.cc \ - common/formats/format_transfers/format_transfer_fractal_zz.cc \ - common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.cc \ - common/formats/format_transfers/format_transfer_nc1hwc0_nchw.cc \ - common/formats/format_transfers/format_transfer_nc1hwc0_nhwc.cc \ - common/formats/format_transfers/format_transfer_hwcn_c1hwncoc0.cc \ - common/formats/format_transfers/format_transfer_c1hwncoc0_hwcn.cc \ - common/formats/format_transfers/format_transfer_fracz_nchw.cc \ - common/formats/format_transfers/format_transfer_fracz_nhwc.cc \ - common/formats/format_transfers/format_transfer_fracz_hwcn.cc \ - common/formats/format_transfers/format_transfer_dhwcn_fracz3D.cc \ - common/formats/format_transfers/format_transfer_dhwnc_fracz3D_transpose.cc \ - common/formats/formats.cc \ - init/gelib.cc \ - engine_manager/dnnengine_manager.cc \ - opskernel_manager/ops_kernel_manager.cc \ - graph/manager/graph_manager.cc \ - graph/manager/graph_manager_utils.cc \ - graph/manager/graph_context.cc \ - graph/preprocess/graph_preprocess.cc \ - graph/preprocess/multi_batch_copy_graph.cc \ - graph/execute/graph_execute.cc \ - graph/load/graph_loader.cc \ - graph/optimize/graph_optimize.cc \ - graph/passes/folding_pass.cc \ - graph/optimize/summary_optimize.cc \ - graph/build/graph_builder.cc \ - graph/partition/engine_place.cc \ - graph/partition/graph_partition.cc \ - graph/partition/dynamic_shape_partition.cc \ - generator/ge_generator.cc \ - generator/generator_api.cc \ - common/profiling/profiling_manager.cc \ - ge_local_engine/engine/host_cpu_engine.cc \ - common/helper/model_cache_helper.cc \ - -OMG_HOST_SRC_FILES := \ - model/ge_model.cc \ - model/ge_root_model.cc \ - graph/common/transop_util.cc \ - graph/manager/graph_var_manager.cc \ - graph/manager/trans_var_data_utils.cc \ - omm/csa_interact.cc \ - graph/passes/pass_manager.cc \ - graph/passes/pass_utils.cc \ - graph/passes/base_pass.cc \ - graph/passes/resource_pair_add_control_pass.cc \ - graph/passes/resource_pair_remove_control_pass.cc \ - graph/passes/constant_folding_pass.cc \ - graph/passes/aicpu_constant_folding_pass.cc \ - graph/passes/reshape_remove_pass.cc \ - graph/passes/reshape_recovery_pass.cc \ - graph/passes/transop_breadth_fusion_pass.cc \ - graph/passes/transop_depth_fusion_pass.cc \ - graph/passes/same_transdata_breadth_fusion_pass.cc \ - graph/passes/transop_without_reshape_fusion_pass.cc \ - graph/passes/compile_nodes_pass.cc \ - graph/passes/transop_nearby_allreduce_fusion_pass.cc \ - graph/passes/variable_prepare_op_pass.cc \ - graph/passes/variable_ref_delete_op_pass.cc \ - graph/passes/variable_ref_useless_control_out_delete_pass.cc \ - graph/passes/variable_op_pass.cc \ - graph/passes/cast_remove_pass.cc \ - graph/passes/replace_transshape_pass.cc \ - graph/passes/transpose_transdata_pass.cc \ - graph/passes/identify_reference_pass.cc \ - graph/passes/variable_format_pass.cc \ - graph/passes/subgraph_pass.cc \ - graph/passes/data_pass.cc \ - graph/passes/net_output_pass.cc \ - graph/passes/constant_fuse_same_pass.cc \ - graph/passes/print_op_pass.cc \ - graph/passes/no_use_reshape_remove_pass.cc \ - graph/passes/iterator_op_pass.cc \ - graph/passes/atomic_addr_clean_pass.cc \ - graph/optimize/optimizer/allreduce_fusion_pass.cc \ - graph/common/omg_util.cc \ - graph/common/bcast.cc \ - graph/passes/dimension_compute_pass.cc \ - graph/passes/dimension_adjust_pass.cc \ - graph/passes/get_original_format_pass.cc \ - graph/passes/shape_operate_op_remove_pass.cc \ - graph/passes/unused_op_remove_pass.cc \ - graph/passes/assert_pass.cc \ - graph/passes/dropout_pass.cc \ - graph/passes/infershape_pass.cc \ - graph/passes/unused_const_pass.cc \ - graph/passes/isolated_op_remove_pass.cc \ - graph/passes/permute_pass.cc \ - graph/passes/ctrl_edge_transfer_pass.cc \ - host_kernels/broadcast_gradient_args_kernel.cc \ - host_kernels/greater_kernel.cc \ - host_kernels/gather_v2_kernel.cc \ - host_kernels/maximum_kernel.cc \ - host_kernels/floormod_kernel.cc \ - host_kernels/floordiv_kernel.cc \ - host_kernels/range_kernel.cc \ - host_kernels/shape_kernel.cc \ - host_kernels/size_kernel.cc \ - host_kernels/shape_n_kernel.cc \ - host_kernels/rank_kernel.cc \ - host_kernels/broadcast_args_kernel.cc \ - host_kernels/fill_kernel.cc \ - host_kernels/empty_kernel.cc \ - host_kernels/expanddims_kernel.cc \ - host_kernels/reshape_kernel.cc \ - host_kernels/squeeze_kernel.cc \ - host_kernels/kernel_utils.cc \ - host_kernels/cast_kernel.cc \ - host_kernels/transdata_kernel.cc \ - host_kernels/transpose_kernel.cc \ - host_kernels/permute_kernel.cc \ - host_kernels/pack_kernel.cc \ - host_kernels/concat_v2_kernel.cc \ - host_kernels/concat_offset_kernel.cc \ - host_kernels/strided_slice_kernel.cc \ - host_kernels/ssd_prior_box_kernel.cc \ - host_kernels/add_kernel.cc \ - host_kernels/unpack_kernel.cc \ - host_kernels/sub_kernel.cc \ - host_kernels/mul_kernel.cc \ - host_kernels/reduce_prod_kernel.cc \ - host_kernels/rsqrt_kernel.cc \ - host_kernels/slice_kernel.cc \ - host_kernels/slice_d_kernel.cc \ - host_kernels/dynamic_stitch_kernel.cc \ - graph/passes/stop_gradient_pass.cc \ - graph/passes/prevent_gradient_pass.cc \ - graph/passes/identity_pass.cc \ - graph/passes/placeholder_with_default_pass.cc \ - graph/passes/snapshot_pass.cc \ - graph/passes/guarantee_const_pass.cc \ - graph/passes/var_is_initialized_op_pass.cc \ - graph/passes/parallel_concat_start_op_pass.cc \ - graph/passes/cast_translate_pass.cc \ - graph/passes/addn_pass.cc \ - graph/passes/common_subexpression_elimination_pass.cc \ - graph/passes/transop_symmetry_elimination_pass.cc \ - graph/passes/save_pass.cc \ - graph/passes/switch_dead_branch_elimination.cc \ - graph/passes/merge_pass.cc \ - graph/passes/prune_pass.cc \ - graph/passes/flow_ctrl_pass.cc \ - graph/passes/control_trigger_pass.cc \ - graph/passes/switch_data_edges_bypass.cc \ - graph/passes/switch_op_pass.cc \ - graph/passes/multi_batch_pass.cc \ - graph/passes/switch_logic_remove_pass.cc \ - graph/passes/next_iteration_pass.cc \ - graph/passes/cond_pass.cc \ - graph/passes/cond_remove_pass.cc \ - graph/passes/for_pass.cc \ - graph/passes/enter_pass.cc \ - graph/passes/hccl_memcpy_pass.cc \ - graph/passes/link_gen_mask_nodes_pass.cc \ - graph/passes/replace_with_empty_const_pass.cc \ - graph/passes/hccl_group_pass.cc \ - -OME_SRC_FILES := \ - graph/manager/graph_mem_allocator.cc \ - graph/manager/graph_caching_allocator.cc \ - graph/manager/model_manager/event_manager.cc \ - graph/manager/util/debug.cc \ - graph/manager/util/rt_context_util.cc \ - graph/manager/util/variable_accelerate_ctrl.cc \ - graph/manager/util/hcom_util.cc \ - graph/load/new_model_manager/model_manager.cc \ - graph/load/new_model_manager/data_inputer.cc \ - graph/load/new_model_manager/davinci_model.cc \ - graph/load/new_model_manager/davinci_model_parser.cc \ - graph/load/new_model_manager/model_utils.cc \ - graph/load/new_model_manager/tbe_handle_store.cc \ - graph/load/new_model_manager/cpu_queue_schedule.cc \ - graph/load/new_model_manager/zero_copy_task.cc \ - graph/load/output/output.cc \ - graph/load/new_model_manager/data_dumper.cc \ - graph/load/new_model_manager/task_info/task_info.cc \ - graph/load/new_model_manager/task_info/event_record_task_info.cc \ - graph/load/new_model_manager/task_info/event_wait_task_info.cc \ - graph/load/new_model_manager/task_info/fusion_start_task_info.cc \ - graph/load/new_model_manager/task_info/fusion_stop_task_info.cc \ - graph/load/new_model_manager/task_info/hccl_task_info.cc \ - graph/load/new_model_manager/task_info/kernel_ex_task_info.cc \ - graph/load/new_model_manager/task_info/kernel_task_info.cc \ - graph/load/new_model_manager/task_info/label_set_task_info.cc \ - graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc \ - graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc \ - graph/load/new_model_manager/task_info/memcpy_async_task_info.cc \ - graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc \ - graph/load/new_model_manager/task_info/profiler_trace_task_info.cc \ - graph/load/new_model_manager/task_info/stream_active_task_info.cc \ - graph/load/new_model_manager/task_info/stream_switch_task_info.cc \ - graph/load/new_model_manager/task_info/stream_switchn_task_info.cc \ - graph/load/new_model_manager/task_info/end_graph_task_info.cc \ - graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ - graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ - single_op/task/op_task.cc \ - single_op/task/build_task_utils.cc \ - single_op/task/tbe_task_builder.cc \ - single_op/task/aicpu_task_builder.cc \ - single_op/single_op.cc \ - single_op/single_op_model.cc \ - single_op/stream_resource.cc \ - single_op/single_op_manager.cc \ - hybrid/hybrid_davinci_model_stub.cc \ - - -COMMON_LOCAL_C_INCLUDES := \ - proto/om.proto \ - proto/task.proto \ - proto/insert_op.proto \ - proto/ge_ir.proto \ - proto/fwk_adapter.proto \ - proto/op_mapping_info.proto \ - proto/tensorflow/attr_value.proto \ - proto/tensorflow/function.proto \ - proto/tensorflow/graph.proto \ - proto/tensorflow/node_def.proto \ - proto/tensorflow/op_def.proto \ - proto/tensorflow/resource_handle.proto \ - proto/tensorflow/tensor.proto \ - proto/tensorflow/tensor_shape.proto \ - proto/tensorflow/types.proto \ - proto/tensorflow/versions.proto \ - $(LOCAL_PATH) ./ \ - $(TOPDIR)inc \ - $(TOPDIR)inc/external \ - $(TOPDIR)inc/external/graph \ - $(TOPDIR)inc/framework \ - $(TOPDIR)inc/framework/common \ - $(TOPDIR)inc/runtime \ - $(TOPDIR)libc_sec/include \ - $(TOPDIR)ops/built-in/op_proto/inc \ - third_party/json/include \ - third_party/protobuf/include \ - third_party/opencv/include \ - -NEW_OMG_HOST_SRC_FILES := \ - graph/preprocess/insert_op/util_insert_aipp_op.cc \ - graph/preprocess/insert_op/ge_aipp_op.cc \ - graph/build/model_builder.cc \ - graph/build/task_generator.cc \ - graph/build/stream_allocator.cc \ - graph/build/logical_stream_allocator.cc \ - graph/build/stream_graph_optimizer.cc \ - graph/build/run_context.cc \ - graph/build/label_allocator.cc \ - graph/label/label_maker.cc \ - graph/label/if_label_maker.cc \ - graph/label/case_label_maker.cc \ - graph/label/while_label_maker.cc \ - graph/label/partitioned_call_label_maker.cc \ - - - -#compiler for host train -include $(CLEAR_VARS) - -LOCAL_MODULE := libge_train - -LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DREUSE_MEMORY=1 -O2 -LOCAL_CFLAGS += -DDAVINCI_CLOUD -DDAVINCI_TRAIN -DFMK_SUPPORT_DUMP -DDAVINCI_SUPPORT_PROFILING -LOCAL_CFLAGS += -DFMK_SUPPORT_DEBUG -ifeq ($(DEBUG), 1) -LOCAL_CFLAGS += -g -O0 -endif - -LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) - -LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) -LOCAL_SRC_FILES += $(OMG_HOST_SRC_FILES) -LOCAL_SRC_FILES += $(OME_SRC_FILES) -LOCAL_SRC_FILES += $(NEW_OMG_HOST_SRC_FILES) - -LOCAL_STATIC_LIBRARIES := libge_memory \ - -LOCAL_SHARED_LIBRARIES := \ - libc_sec \ - libprotobuf \ - libslog \ - libmmpa \ - libgraph \ - libregister \ - libge_common \ - libhccl \ - libmsprof \ - - -LOCAL_LDFLAGS := -lrt -ldl - -LOCAL_SHARED_LIBRARIES += \ - libruntime \ - libresource \ - -include $(BUILD_HOST_SHARED_LIBRARY) - -# add engine_conf.json to host -include $(CLEAR_VARS) - -LOCAL_MODULE := engine_conf.json - -LOCAL_SRC_FILES := engine_manager/engine_conf.json - -LOCAL_MODULE_CLASS := ETC - -LOCAL_INSTALLED_PATH := $(HOST_OUT_ROOT)/engine_conf.json -include $(BUILD_HOST_PREBUILT) - -# add optimizer_priority.pbtxt to host -include $(CLEAR_VARS) - -LOCAL_MODULE := optimizer_priority.pbtxt - -LOCAL_SRC_FILES := opskernel_manager/optimizer_priority.pbtxt - -LOCAL_MODULE_CLASS := ETC - -LOCAL_INSTALLED_PATH := $(HOST_OUT_ROOT)/optimizer_priority.pbtxt -include $(BUILD_HOST_PREBUILT) diff --git a/src/ge/generator/ge_generator.cc b/src/ge/generator/ge_generator.cc index b01f7591..bc1e78c1 100644 --- a/src/ge/generator/ge_generator.cc +++ b/src/ge/generator/ge_generator.cc @@ -23,15 +23,15 @@ #include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "ge/ge_api.h" -#include "graph/ge_context.h" #include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" #include "graph/manager/graph_manager.h" #include "graph/manager/util/rt_context_util.h" #include "graph/opsproto_manager.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" -#include "model/ge_model.h" #include "init/gelib.h" +#include "model/ge_model.h" using std::map; using std::string; @@ -46,6 +46,16 @@ const char *const kFileNameSuffix = "online"; std::map engine_type_map{ {ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}}; + +bool ContainsDynamicInpus(const ge::OpDesc &op_desc) { + for (auto &tensor_desc : op_desc.GetAllInputsDescPtr()) { + if (tensor_desc->MutableShape().IsUnknownShape()) { + GELOGI("Contains unknown shape input. set is_dynamic_input to true."); + return true; + } + } + return false; +} } // namespace namespace ge { @@ -55,6 +65,7 @@ static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engi GELOGI("CheckEngineType: use default engine."); return SUCCESS; } + // get op engine name string op_engine_name; auto iter = engine_type_map.find(engine_type); @@ -65,6 +76,12 @@ static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engi GELOGE(FAILED, "CheckEngineType: engine type: %d not support", static_cast(engine_type)); return FAILED; } + + if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { + op_desc->SetOpEngineName(op_engine_name); + op_desc->SetOpKernelLibName(op_engine_name); + return SUCCESS; + } // set op engine name and opkernelLib. when engine support std::shared_ptr instance_ptr = ge::GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { @@ -195,18 +212,26 @@ static void GetOpsProtoPath(string &opsproto_path) { class GeGenerator::Impl { public: - Status BuildModel(const Graph &graph, const vector &inputs, GraphId &graph_id, GeRootModelPtr &ge_models); + Status BuildModel(const Graph &graph, const vector &inputs, GeRootModelPtr &ge_models); Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); Status SaveParams(GeModelPtr &ge_model, const string &type, const map &attrs, const vector &inputs, const vector &outputs); - Status GenerateInfershapeGraph(const Graph &graph, GraphId &graph_id); + Status GenerateInfershapeGraph(const Graph &graph); GraphManager graph_manager_; SaveParam save_param_; bool is_offline_ = true; + bool is_singleop_unregistered_ = false; + + private: + static std::string Trim(const std::string &str); + bool ParseVersion(const std::string &line, std::string &version); + bool GetVersionFromPath(const std::string &file_path, std::string &version); + bool SetAtcVersionInfo(AttrHolder &obj); + bool SetOppVersionInfo(AttrHolder &obj); }; Status GeGenerator::Initialize(const map &options) { @@ -273,10 +298,9 @@ Status GeGenerator::GenerateOnlineModel(const Graph &graph, const vectorGenerateInfershapeGraph(graph, graph_id); + Status ret = impl_->GenerateInfershapeGraph(graph); if (ret != SUCCESS) { GELOGE(ret, "Dump infershape json failed"); if (impl_->graph_manager_.Finalize() != SUCCESS) { @@ -288,6 +312,124 @@ Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) { return SUCCESS; } +// Remove the space and tab before and after the string +std::string GeGenerator::Impl::Trim(const std::string &str) { + if (str.empty()) { + return str; + } + + std::string::size_type start = str.find_first_not_of(" \t\r\n"); + if (start == std::string::npos) { + return str; + } + + std::string::size_type end = str.find_last_not_of(" \t\r\n") + 1; + return str.substr(start, end); +} + +// Parsing the command line +bool GeGenerator::Impl::ParseVersion(const std::string &line, std::string &version) { + std::string flag = "Version="; + std::string temp = Trim(line); + + if (temp.empty()) { + GELOGW("line is empty."); + return false; + } + + std::string::size_type pos = temp.find(flag); + if (pos == std::string::npos) { + GELOGW("Incorrect line [%s], it must include [%s].", line.c_str(), flag.c_str()); + return false; + } + + if (temp.size() == flag.size()) { + GELOGW("version information is empty. %s", line.c_str()); + return false; + } + + version = temp.substr(pos + flag.size()); + GELOGI("Version=%s", version.c_str()); + + return true; +} + +bool GeGenerator::Impl::GetVersionFromPath(const std::string &file_path, std::string &version) { + // Normalize the path + string resolved_file_path = RealPath(file_path.c_str()); + if (resolved_file_path.empty()) { + GELOGW("Invalid input file path [%s], make sure that the file path is correct.", file_path.c_str()); + return false; + } + std::ifstream fs(resolved_file_path, std::ifstream::in); + if (!fs.is_open()) { + GELOGW("Open %s failed.", file_path.c_str()); + return false; + } + + std::string line; + if (getline(fs, line)) { + if (!ParseVersion(line, version)) { + GELOGW("Parse version failed. content is [%s].", line.c_str()); + fs.close(); + return false; + } + } else { + GELOGW("No version information found in the file path:%s", file_path.c_str()); + fs.close(); + return false; + } + + fs.close(); // close the file + return true; +} + +// Set package version information in the model +bool GeGenerator::Impl::SetAtcVersionInfo(AttrHolder &obj) { + std::string path_base = ge::GELib::GetPath(); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + + std::string version_path = path_base + "version.info"; + GELOGI("version_path is %s", version_path.c_str()); + std::string version; + if (!GetVersionFromPath(version_path, version)) { + GELOGW("Get atc version information failed!"); + return false; + } + // set version info + if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_ATC_VERSION, version)) { + GELOGW("Ge model set atc version failed!"); + return false; + } + GELOGI("Ge model set atc version information success."); + return true; +} + +// Set package version information in the model +bool GeGenerator::Impl::SetOppVersionInfo(AttrHolder &obj) { + const char *path_env = std::getenv("ASCEND_OPP_PATH"); + if (path_env == nullptr) { + GELOGW("Get environment variable ASCEND_OPP_PATH failed!"); + return false; + } + std::string version_path = path_env; + version_path += "/version.info"; + GELOGI("version_path is %s", version_path.c_str()); + std::string version; + if (!GetVersionFromPath(version_path, version)) { + GELOGW("Get opp version information failed!"); + return false; + } + // set version info + if (!ge::AttrUtils::SetStr(obj, ATTR_MODEL_OPP_VERSION, version)) { + GELOGW("Ge model set opp version failed!"); + return false; + } + GELOGI("Ge Model set opp version information success."); + return true; +} + Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector &inputs, ModelBufferData &model, bool is_offline) { rtContext_t ctx = nullptr; @@ -297,11 +439,11 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr } else { ge::RtContextUtil::GetInstance().SetNormalModeContext(ctx); } - GraphId graph_id; + GeRootModelPtr ge_root_model = nullptr; GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); impl_->is_offline_ = is_offline; - Status ret = impl_->BuildModel(graph, inputs, graph_id, ge_root_model); + Status ret = impl_->BuildModel(graph, inputs, ge_root_model); if (ret != SUCCESS) { GELOGE(ret, "Build model failed."); if (impl_->graph_manager_.Finalize() != SUCCESS) { @@ -315,6 +457,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr string model_name = ""; Status name_ret = model_helper.GetModelNameFromMergedGraphName(ge_root_model->GetRootGraph()->GetName(), model_name); if (name_ret != SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); GELOGE(FAILED, "Get model_name failed. Param --output is invalid"); return PARAM_INVALID; } @@ -352,6 +495,12 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in return PARAM_INVALID; } + domi::GetContext().is_dynamic_input = ContainsDynamicInpus(*op_desc); + + if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { + impl_->is_singleop_unregistered_ = true; + } + // 0. Save original attributes. OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); GE_CHECK_NOTNULL(op_desc_tmp); @@ -368,9 +517,6 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in // 2. Create ComputeGraph. string name = ge::CurrentTimeInStr() + "_" + model_file_name; ge::ComputeGraphPtr compute_graph = MakeShared(name); - if (compute_graph == nullptr) { - return INTERNAL_ERROR; - } GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); // 3. Add Node to ComputeGraph. @@ -403,16 +549,19 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector &in Graph graph = ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph); GELOGI("ATC parser success in single op build."); - GraphId graph_id; GeRootModelPtr ge_root_model = nullptr; GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); impl_->is_offline_ = is_offline; - GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, graph_id, ge_root_model)); + GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, ge_root_model)); map op_attrs = op_desc_tmp->GetAllAttrs(); GE_CHECK_NOTNULL(ge_root_model); GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); map name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); - GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; + if (name_to_ge_model.empty()) { + GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty."); + return PARAM_INVALID; + } + GeModelPtr &ge_model = name_to_ge_model.begin()->second; GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str()); GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff)); @@ -464,6 +613,14 @@ Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, c } Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) { + // set atc version + if (!SetAtcVersionInfo(*(model.get()))) { + GELOGW("SetPackageVersionInfo of atc failed!"); + } + // set opp version + if (!SetOppVersionInfo(*(model.get()))) { + GELOGW("SetPackageVersionInfo of ops failed!"); + } ModelHelper model_helper; model_helper.SetSaveMode(is_offline_); Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); @@ -474,7 +631,7 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr & return SUCCESS; } -Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector &inputs, GraphId &graph_id, +Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector &inputs, GeRootModelPtr &ge_root_model) { static GraphId id = 0; const std::map options; @@ -493,19 +650,22 @@ Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector return INTERNAL_ERROR; } uint64_t session_id = static_cast(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us - ret = graph_manager_.BuildGraph(id, inputs, ge_root_model, session_id); + if (is_singleop_unregistered_) { + ret = graph_manager_.BuildGraphForUnregisteredOp(id, inputs, ge_root_model, session_id); + } else { + ret = graph_manager_.BuildGraph(id, inputs, ge_root_model, session_id); + } + if (ret != SUCCESS) { GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", id); return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } - - graph_id = id; id += 1; return SUCCESS; } -Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &graph_id) { +Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) { static GraphId id = 0; const std::map options; Status ret = graph_manager_.AddGraph(id, graph, options); @@ -520,11 +680,8 @@ Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &g GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed"); return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; } - - graph_id = id; id += 1; return SUCCESS; } - } // namespace ge diff --git a/src/ge/graph/build/graph_builder.cc b/src/ge/graph/build/graph_builder.cc index f2fa4ada..51519023 100644 --- a/src/ge/graph/build/graph_builder.cc +++ b/src/ge/graph/build/graph_builder.cc @@ -18,11 +18,14 @@ #include "common/ge/ge_util.h" #include "common/helper/model_helper.h" #include "common/opskernel/ops_kernel_info_types.h" +#include "graph/build/logical_stream_allocator.h" #include "graph/build/run_context.h" #include "graph/build/stream_graph_optimizer.h" #include "graph/manager/graph_var_manager.h" +#include "graph/passes/mark_same_addr_pass.h" #include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "init/gelib.h" #include "model/ge_model.h" @@ -54,7 +57,7 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { return GE_CLI_GE_NOT_INITIALIZED; } - for (const auto &node_ptr : graph->GetAllNodes()) { + for (const auto &node_ptr : graph->GetNodes(graph->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); if (kernel_lib_name.empty()) { @@ -102,11 +105,7 @@ Status GraphBuilder::UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph graph->GetName().c_str()); auto parent_op_desc = parent_node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(parent_op_desc); - bool is_unknown_shape = false; - if (!AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { - GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", parent_op_desc->GetName().c_str()); - return PARAM_INVALID; - } + bool is_unknown_shape = graph->GetGraphUnknownFlag(); if (is_unknown_shape) { GELOGI("Current graph[%s] is unknown, no need to update parent node[%s] output size.", graph->GetName().c_str(), parent_node_ptr->GetName().c_str()); @@ -121,14 +120,14 @@ Status GraphBuilder::UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph for (const auto &in_data_anchor : node_ptr->GetAllInDataAnchors()) { auto index = in_data_anchor->GetIdx(); ge::GeTensorDesc desc_temp = op_desc->GetInputDesc(index); - int64_t size = 0; - GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); uint32_t parent_index = 0; if (!AttrUtils::GetInt(desc_temp, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { - GELOGE(INTERNAL_ERROR, "NetOutput input tensor %d, attr %s not found.", index, - ATTR_NAME_PARENT_NODE_INDEX.c_str()); - return INTERNAL_ERROR; + GELOGI("NetOutput input tensor %d, attr %s not found.", index, ATTR_NAME_PARENT_NODE_INDEX.c_str()); + continue; } + + int64_t size = 0; + GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); ge::GeTensorDesc parent_desc_temp = parent_op_desc->GetOutputDesc(parent_index); ge::TensorUtils::SetSize(parent_desc_temp, size); GE_CHK_STATUS_RET(parent_op_desc->UpdateOutputDesc(parent_index, parent_desc_temp)); @@ -176,7 +175,7 @@ Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, auto subgraph_map = graph_partitioner_.GetSubGraphMap(); GE_TIMESTAMP_START(BuildSubgraph); - ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); GE_DUMP(comp_graph, "BeforePreBuildModel"); GE_TIMESTAMP_START(PreBuildModel); GE_CHK_STATUS_RET(builder.PreBuildModel(), "Graph[%s] builder PreBuildModel() return fail.", @@ -229,7 +228,7 @@ Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeMo GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); GE_DUMP(comp_graph, "AfterCalcOpParam"); Graph2SubGraphInfoList subgraph_map; - ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); + ge::ModelBuilder builder(session_id, comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); ModelPtr model_ptr = MakeShared(); if (model_ptr == nullptr) { return MEMALLOC_FAILED; @@ -263,51 +262,41 @@ Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, uint64_t session_id) { GELOGI("Start to build BuildForDynamicShape for dynamic shape."); - for (const auto &node : comp_graph->GetDirectNode()) { + // Update Root Graph Data size + for (auto &node : comp_graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + op_desc->SetStreamId(kInvalidStream); if (node->GetType() == DATA) { GE_CHK_STATUS_RET(CalcDynShapeRootGraphDataSize(op_desc), "Calc dynamic shape root graph data[%s] size failed.", op_desc->GetName().c_str()); } - - // ATTR_NAME_IS_UNKNOWN_SHAPE is set on "graph partion" stage, but afer fusion , the graph may - // be changed so here need to renew. For example , the scene followed: - // (known)partioncall(known) (known)partioncall(known) - // After fusion - // | --> - // (known)Unique(unknown)--->(unknow)Shape(unknown) (known)FuncDef(known) - // if scene like this , it should be process as known shape graph - bool is_unknown_shape = false; - GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), - "Get node[%s] shape status failed!", node->GetName().c_str()); - if (!is_unknown_shape) { - GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), return FAILED, - "Renew node [%s] attr[%s] failed!", node->GetName().c_str(), ATTR_NAME_IS_UNKNOWN_SHAPE.c_str()); - GELOGD("renew node [%s] attr[%s] success! value is %d", node->GetName().c_str(), - ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), is_unknown_shape); + } + // + for (auto &sub_graph : comp_graph->GetAllSubgraphs()) { + // exclude functional subgraph in known subgraph + if (sub_graph->GetParentGraph() != comp_graph && !sub_graph->GetParentGraph()->GetGraphUnknownFlag()) { + continue; } - - vector subgraph_names = op_desc->GetSubgraphInstanceNames(); - for (auto subgraph_name : subgraph_names) { - ComputeGraphPtr subgraph = comp_graph->GetSubgraph(subgraph_name); - bool is_unknown_shape = false; - if (!AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { - GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", op_desc->GetName().c_str()); - return PARAM_INVALID; - } - if (is_unknown_shape) { - // unknown shape build flow - GE_CHK_STATUS_RET(BuildForUnknownShapeGraph(subgraph, ge_model_ptr, session_id), - "Build for unknown shape graph failed."); - } else { - // known shape build flow - GE_CHK_STATUS_RET(BuildForKnownShapeGraph(subgraph, subgraph_ptr_list, ge_model_ptr, session_id), - "Build for known shape graph failed."); + if (sub_graph->GetGraphUnknownFlag()) { + // unknown shape build flow + GE_CHK_STATUS_RET(BuildForUnknownShapeGraph(sub_graph, ge_model_ptr, session_id), + "Build for unknown shape graph failed."); + } else { + // reset functional subgraph parent graph as known subgraph + for (const auto &node : sub_graph->GetDirectNode()) { + for (const auto &sub_graph_name : node->GetOpDesc()->GetSubgraphInstanceNames()) { + auto sub_sub_graph = comp_graph->GetSubgraph(sub_graph_name); + GE_CHK_STATUS_RET(sub_graph->AddSubgraph(sub_sub_graph), "Failed add subgraph to known graph."); + } } - ge_root_model_ptr->SetSubgraphInstanceNameToModel(subgraph_name, ge_model_ptr); + // known shape build flow + GE_CHK_STATUS_RET(BuildForKnownShapeGraph(sub_graph, subgraph_ptr_list, ge_model_ptr, session_id), + "Build for known shape graph failed."); } + ge_root_model_ptr->SetSubgraphInstanceNameToModel(sub_graph->GetName(), ge_model_ptr); } + return SUCCESS; } @@ -327,8 +316,9 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr GELOGE(INTERNAL_ERROR, "Get weight memory size fail."); return INTERNAL_ERROR; } - auto *get_mem_base = - reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemMaxSize())); + + auto var_manager = VarManager::Instance(session_id); + auto *get_mem_base = reinterpret_cast(reinterpret_cast(var_manager->GetVarMemMaxSize())); uint8_t *get_weight_mem_base = get_mem_base; if (weight_size > 0) { get_weight_mem_base = get_mem_base + memory_size; @@ -354,11 +344,8 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr return ret; } GE_DUMP(comp_graph, "AfterOptimizeStreamedSubGraph"); - auto *get_var_mem_base = - reinterpret_cast(reinterpret_cast(ge::VarManager::Instance(0)->GetVarMemLogicBase())); - uint64_t var_size = (ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) > 0) - ? ge::VarManager::Instance(0)->GetVarMemMaxSize() - : 0; + auto *get_var_mem_base = reinterpret_cast(reinterpret_cast(var_manager->GetVarMemLogicBase())); + uint64_t var_size = (var_manager->GetVarMemSize(RT_MEMORY_HBM) > 0) ? var_manager->GetVarMemMaxSize() : 0; TaskGenerator task_generator(get_var_mem_base, var_size); ret = task_generator.GetTaskInfo(*model_ptr, comp_graph, session_id, run_context.GetRunContext()); @@ -368,6 +355,13 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { // set input_desc.size = src_node.output_desc.size if (node_ptr->GetType() == DATA) { + bool is_unknown_shape = false; + GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node_ptr, is_unknown_shape), + "Get data node[%s] shape status failed!", node_ptr->GetName().c_str()); + if (is_unknown_shape) { + GELOGD("data node: %s is unknown shape, do not set input size!", node_ptr->GetName().c_str()); + return SUCCESS; + } if (UpdateDataInputSize(node_ptr) != SUCCESS) { GELOGE(FAILED, "Update data input size failed."); return FAILED; @@ -398,7 +392,7 @@ Status GraphBuilder::SetInputSize(const ge::NodePtr &node_ptr) { GE_CHECK_NOTNULL(input_desc); ge::TensorUtils::SetSize(const_cast(*input_desc), size); GE_CHK_STATUS_RET(node_op_desc->UpdateInputDesc(in_data_anchor->GetIdx(), *input_desc)); - GELOGD("%s input desc, dim_size: %zu, mem_size: %u, format: %s, type: %s.", node_ptr->GetName().c_str(), + GELOGD("%s input desc, dim_size: %zu, mem_size: %ld, format: %s, type: %s.", node_ptr->GetName().c_str(), input_desc->GetShape().GetDimNum(), size, TypeUtils::FormatToSerialString(input_desc->GetFormat()).c_str(), TypeUtils::DataTypeToSerialString(input_desc->GetDataType()).c_str()); } @@ -444,6 +438,11 @@ Status GraphBuilder::CalcDynShapeRootGraphDataSize(const ge::OpDescPtr &op_desc) GELOGI("Begin to calc dynamic shape graph data[%s] size.", op_desc->GetName().c_str()); // data op only has one output anchor ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); + if (output_desc.MutableShape().IsUnknownShape()) { + GELOGI("No need to update dynamic shape graph data output size for unknown shape data."); + return SUCCESS; + } + int64_t output_size = 0; if (ge::TensorUtils::GetSize(output_desc, output_size) != SUCCESS) { GELOGW("Get size failed!"); diff --git a/src/ge/graph/build/label_allocator.cc b/src/ge/graph/build/label_allocator.cc index 46c092f5..f8fbe28b 100644 --- a/src/ge/graph/build/label_allocator.cc +++ b/src/ge/graph/build/label_allocator.cc @@ -24,7 +24,6 @@ #include "graph/label/label_maker.h" namespace ge { - LabelAllocator::LabelAllocator(const ComputeGraphPtr &graph) : compute_graph_(graph) {} Status LabelAllocator::AssignFunctionalLabels(uint32_t &label_index) { @@ -76,5 +75,4 @@ bool LabelAllocator::CollectFunctionalNode(ComputeGraphPtr &graph, std::setGetOpDesc(), kAttrNameParentOpType, parent_op_type)) { - if ((parent_op_type != CONSTANT) && (parent_op_type != CONSTANTOP)) { - return true; - } - } - } - } - - return false; -} - Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; int64_t &next_stream = context.next_stream; @@ -133,21 +110,6 @@ Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector &subgraphs, Context &context) { bool changed = false; - if (IsHeadNodeExceeded(subgraphs)) { - int64_t &next_stream = context.next_stream; - for (const SubgraphPtr &subgraph : subgraphs) { - if (!HasAssignedStream(*subgraph)) { - subgraph->stream_id = next_stream; - changed = true; - } - } - if (changed) { - ++next_stream; - return SUCCESS; - } - return NOT_CHANGED; - } - map end_subgraph_map; map pld_subgraph_map; InitEndSubgraphMap(subgraphs, end_subgraph_map); @@ -190,24 +152,6 @@ Status AssignByDependencyPass::Run(ComputeGraphPtr graph, const vector &subgraphs) const { - size_t aicpu_node_num = 0; - for (const SubgraphPtr &subgraph : subgraphs) { - if (subgraph->engine_conf.id == kAICPUEngineName && !HasNonConstInputNode(*subgraph)) { - const SubGraphInfo &subgraph_info = subgraph->subgraph_info; - auto compute_graph = subgraph_info.GetSubGraph(); - aicpu_node_num += compute_graph->GetDirectNode().size() - subgraph_info.GetPld2EndMap().size() - - subgraph_info.GetEnd2PldMap().size(); - if (aicpu_node_num > kHeadNodeMaxNum) { - GELOGI("aicpu_node_num, %zu", aicpu_node_num); - return true; - } - } - } - - return false; -} - void AssignByDependencyPass::InitEndSubgraphMap(const vector &subgraphs, map &end_subgraph_map) { for (const auto &subgraph : subgraphs) { @@ -727,7 +671,7 @@ void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &gra int64_t stream_num = context_.next_stream; vector stream_has_node(stream_num); - for (const NodePtr &node : graph->GetAllNodes()) { + for (const NodePtr &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { if (node != nullptr) { auto op_desc = node->GetOpDesc(); if (op_desc != nullptr) { @@ -748,7 +692,7 @@ void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &gra } } - for (const NodePtr &node : graph->GetAllNodes()) { + for (const NodePtr &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { auto op_desc = node->GetOpDesc(); if (op_desc != nullptr) { int64_t stream_id = op_desc->GetStreamId(); diff --git a/src/ge/graph/build/logical_stream_allocator.h b/src/ge/graph/build/logical_stream_allocator.h index 71946630..280a4104 100644 --- a/src/ge/graph/build/logical_stream_allocator.h +++ b/src/ge/graph/build/logical_stream_allocator.h @@ -81,9 +81,6 @@ class LogicalStreamPass { bool HasStreamLabel(const Subgraph &subgraph) const; bool HasAssignedStream(const Subgraph &subgraph) const; - // Determine if the input of the subgraph is a constant. - bool HasNonConstInputNode(const Subgraph &subgraph) const; - private: std::string name_; }; @@ -121,7 +118,6 @@ class AssignByDependencyPass : public LogicalStreamPass { void UpdateAssignedSubgraphs(Context &context); void UpdateReusedSubgraphs(); - bool IsHeadNodeExceeded(const std::vector &subgraphs) const; bool CouldReuse(const SubgraphPtr &subgraph, const SubgraphPtr &pred_subgraph, const std::map &pld_subgraph_map); diff --git a/src/ge/graph/build/memory/block_mem_assigner.cc b/src/ge/graph/build/memory/block_mem_assigner.cc index df7912fa..99b2fd7d 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.cc +++ b/src/ge/graph/build/memory/block_mem_assigner.cc @@ -18,6 +18,7 @@ #include #include +#include "external/ge/ge_api_types.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" #include "graph/buffer.h" @@ -35,11 +36,19 @@ #include "omg/omg_inner_types.h" #include "runtime/mem.h" +using std::map; +using std::pair; +using std::set; +using std::string; +using std::stringstream; +using std::unordered_map; +using std::unordered_set; +using std::vector; + namespace { const char *const kAttrNameWorkspaceReuseFlag = "workspace_reuse_flag"; const char *const kL2FusionDynamicConvergeOp = "l2fusion_dynamic_converge_op"; const char *const kOpNoReuseMem = "no_reuse_mem_flag"; -const char *const kDisableReuseMemory = "ge.exec.disableReuseMemory"; const char *const OP_NO_REUSE_MEM = "OP_NO_REUSE_MEM"; const int kReuseMaxCount = 10; const int kReuseMaxOpNum = 10; @@ -47,13 +56,12 @@ const int kReuseMaxCharNum = 2000; } // namespace namespace ge { -using std::map; -using std::pair; -using std::string; -using std::stringstream; -using std::unordered_map; -using std::unordered_set; -using std::vector; +void AlignMemOffset(size_t &mem_align_size) { + if (mem_align_size <= 0) { + return; + } + mem_align_size = (mem_align_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; +} void MemoryBlock::SetHeadOffset(size_t offset) { head_offset_ = offset; @@ -92,7 +100,7 @@ void MemoryBlock::Resize() { } else { size_t block_size = (child_block_size > *iter) ? child_block_size : *iter; if ((block_size > 0) && (block_size % MEM_ALIGN_SIZE != 0)) { - block_size = (block_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; + AlignMemOffset(block_size); } block_size_ = block_size; if (last_continuous_block_) { @@ -101,6 +109,20 @@ void MemoryBlock::Resize() { } } +size_t MemoryBlock::AlignSize() const { + size_t align_block_size = 0; + auto iter = std::max_element(real_size_list_.begin(), real_size_list_.end()); + if (iter == real_size_list_.end()) { + GELOGW("real_size_list_ is empty"); + } else { + align_block_size = *iter; + if ((align_block_size > 0) && (align_block_size % MEM_ALIGN_SIZE != 0)) { + AlignMemOffset(align_block_size); + } + } + return align_block_size; +} + bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { if (node_type_index_list_.empty()) { return false; @@ -133,32 +155,69 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { } bool CanNotLifeReuse(MemoryBlock *block) { - if (block == nullptr || !block->reuse_mem_ || block->deleted_block_ || block->continuous_block_ || - block->GetLifeEnd() == kMaxLifeTime) { + if ((block == nullptr) || !block->reuse_mem_ || block->deleted_block_) { return true; } return false; } -void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block) { +void MemoryBlock::AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { + // continuous memory case:only real_size is maximum can be reused and only one continuous memory in one block + auto it_block = std::max_element(std::begin(block->NoAlignSizeList()), std::end(block->NoAlignSizeList())); + auto it_this = std::max_element(std::begin(NoAlignSizeList()), std::end(NoAlignSizeList())); + if (it_block != std::end(block->NoAlignSizeList()) && it_this != std::end(NoAlignSizeList())) { + if ((continuous_block_ && block->continuous_block_) || (continuous_block_ && (*it_this < *it_block)) || + (block->continuous_block_ && (*it_this > *it_block))) { + GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d", *it_this, + continuous_block_, *it_block, block->continuous_block_); + return; + } + } + + MemoryBlock *parent = nullptr; + MemoryBlock *child = nullptr; + // merge small block to large block + if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) { + if ((block->child_offset_ + AlignSize()) <= *it_block) { + parent = block; + child = this; + } + } + if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { + parent->child_blocks_.emplace_back(child); + parent->child_offset_ += child->AlignSize(); + child->deleted_block_ = true; + GELOGI( + "Add continuous block[%p size:%zu, stream id:%ld life time[begin:%zu, end:%zu]] to" + " block[%p size:%zu, stream id:%ld, life time[begin:%zu, end:%zu]]", + child, child->block_size_, child->stream_id_, child->GetLifeBegin(), child->GetLifeEnd(), parent, + parent->block_size_, parent->stream_id_, parent->GetLifeBegin(), parent->GetLifeEnd()); + } +} + +void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life) { if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { return; } + if (block->continuous_block_) { + AddContinuousLifeReuseBlock(block, total_node_depend_stream_life); + return; + } MemoryBlock *parent = nullptr; MemoryBlock *child = nullptr; // merge small block to large block - if ((block->GetLifeBegin() > GetLifeEnd()) && (block->stream_id_ == stream_id_)) { - if ((child_offset_ + block->block_size_) <= block_size_) { + if (block->GetDependLifeBegin(stream_id_, total_node_depend_stream_life) > GetLifeEnd()) { + if ((child_offset_ + block->AlignSize()) <= AlignSize()) { parent = this; child = block; - } else if ((block->child_offset_ + block_size_) <= block->block_size_) { + } else if ((block->child_offset_ + AlignSize()) <= block->AlignSize()) { parent = block; child = this; } } if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { parent->child_blocks_.emplace_back(child); - parent->child_offset_ += child->block_size_; + parent->child_offset_ += child->AlignSize(); child->deleted_block_ = true; GELOGI( "Add block[%p size:%zu, stream id:%ld life time[begin:%zu, end:%zu]] to" @@ -181,6 +240,87 @@ size_t MemoryBlock::GetLifeBegin() { return life_time; } +/// |-stream 1-| |-stream 2-| +/// |--block1--| |--block---| +/// |--block2--| |--block---| +/// |--block3--|\ |--block---| +/// |--block---| \ |--block---| +/// |--block---| \|--block---| +/// |--block---| |--block7--| +/// |--block---| |--block---| +/// block7's first node's input node's life begin > block2's life end, block7 can reuse block1~block2 +size_t MemoryBlock::GetDependLifeBegin(int64_t stream_id, DependStreamLife &total_node_depend_stream_life) { + AddDependLifeBegin(total_node_depend_stream_life); + auto it = depend_stream_life_.find(stream_id); + if (it == depend_stream_life_.end()) { + return 0; + } + return it->second; +} + +void AddDependLife(const ge::NodePtr &org_node, const ge::NodePtr &node, int64_t stream_id, + std::map &depend_stream_life, DependStreamLife &total_node_depend_stream_life) { + GE_CHECK_NOTNULL_EXEC(node, return ); + auto node_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(node_desc, return ); + auto node_id = node_desc->GetId(); + auto stream_life = total_node_depend_stream_life.find(node_id); + if (stream_life != total_node_depend_stream_life.end()) { + for (auto &it : stream_life->second) { + if (depend_stream_life.find(it.first) == depend_stream_life.end()) { + depend_stream_life[it.first] = it.second; + } + } + return; + } + + for (const auto &in_anchor : node->GetAllInAnchors()) { + GE_CHECK_NOTNULL_EXEC(in_anchor, continue); + for (auto peer_out_anchor : in_anchor->GetPeerAnchors()) { + GE_CHECK_NOTNULL_EXEC(peer_out_anchor, continue); + auto peer_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL_EXEC(peer_node, continue); + auto peer_node_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL_EXEC(peer_node_desc, continue); + auto peer_node_stream_id = peer_node_desc->GetStreamId(); + if (peer_node_stream_id < 0) { + continue; + } + size_t peer_node_life_time = peer_node_desc->GetId(); + auto it = depend_stream_life.find(peer_node_stream_id); + if (it == depend_stream_life.end() || peer_node_life_time > it->second) { + depend_stream_life[peer_node_stream_id] = peer_node_life_time; + if (peer_node_stream_id != stream_id) { + GELOGI("Node:%s stream id:%ld depend node:%s stream id:%ld index[%d] life time[%zu].", + org_node->GetName().c_str(), stream_id, peer_node_desc->GetName().c_str(), peer_node_stream_id, + peer_out_anchor->GetIdx(), peer_node_life_time); + } + AddDependLife(org_node, peer_node, stream_id, depend_stream_life, total_node_depend_stream_life); + } + } + } + + // save on node to save next calculation + for (auto &it : depend_stream_life) { + if (total_node_depend_stream_life[node_id].find(it.first) == total_node_depend_stream_life[node_id].end()) { + total_node_depend_stream_life[node_id][it.first] = it.second; + } + } +} + +void MemoryBlock::AddDependLifeBegin(DependStreamLife &total_node_depend_stream_life) { + if (!depend_stream_life_.empty()) { + return; + } + if (!node_type_index_list_.empty()) { + auto node = node_type_index_list_.front().node; + if (node != nullptr) { + AddDependLife(node, node, stream_id_, depend_stream_life_, total_node_depend_stream_life); + } + } + depend_stream_life_[stream_id_] = GetLifeBegin(); +} + size_t MemoryBlock::GetLifeEnd() { if (!node_type_index_list_.empty()) { return node_type_index_list_.back().life_time_end; @@ -249,15 +389,15 @@ string ToString(ge::NodeTypeIndex &x) { string MemoryBlock::String() { stringstream ss; - ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << ""; - ss << "real_size_list: " << ToString(real_size_list_) << ""; - ss << "ref_count: " << ref_count_ << ""; + ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << " "; + ss << "real_size_list: " << ToString(real_size_list_) << " "; + ss << "ref_count: " << ref_count_ << " "; ss << "members: "; for (auto x : NodeTypeIndexList()) { - ss << "__node: " << ToString(x) << ""; + ss << "__node: " << ToString(x) << " "; } for (const auto &symbol : SymbolList()) { - ss << "__symbol: " << symbol << ""; + ss << "__symbol: " << symbol << " "; } return ss.str(); } @@ -302,7 +442,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { if (iter1 == anchor_to_symbol_.end()) { continue; } - std::string symbol = iter1->second; + const std::string &symbol = iter1->second; auto iter2 = symbol_size_.find(symbol); if (iter2 == symbol_size_.end()) { symbol_size_[symbol] = size; @@ -317,7 +457,7 @@ void BlockMemAssigner::GetOutAndWorkSpaceMem(vector &all_memory_size) { all_memory_size.insert(all_memory_size.end(), temp.begin(), temp.end()); } GELOGI("The last atomic_addr_clean node id: %ld", atomic_addr_clean_id_); - for (auto &pair : symbol_size_) { + for (const auto &pair : symbol_size_) { all_memory_size.emplace_back(pair.second); } sort(all_memory_size.begin(), all_memory_size.end()); @@ -351,7 +491,7 @@ size_t GetBlockSize(size_t size, const vector &ranges) { } GELOGW("Memory needed size:%zu is beyond the biggest block in memory ranges.", size); - return 0; + return size; } bool IsDirectOutputNode(const NodePtr &node, int idx) { @@ -385,34 +525,8 @@ void ReduceReusableBlockCount(const MemoryBlock &mem_block, map &reusable_block_counts, const MemoryBlock &reusable_block, - size_t block_size, size_t real_size, bool continuous, int64_t atomic_addr_clean_id) { + size_t block_size, size_t real_size, bool continuous) { bool can_reuse = false; - - // If node is before atomic_addr_clean node, the continus memory can't be reused. - if (!reusable_block.NodeTypeIndexList().empty()) { - auto node = reusable_block.NodeTypeIndexList()[0].node; - if (node != nullptr) { - auto op_desc = node->GetOpDesc(); - if (op_desc != nullptr) { - if ((op_desc->GetId() < atomic_addr_clean_id) && continuous) { - return false; - } - } - } - } - - // continuous memory case:only real_size is maximum can be reused and only one continuous memory in one block - if (continuous || reusable_block.continuous_block_) { - auto it = - std::max_element(std::begin(reusable_block.NoAlignSizeList()), std::end(reusable_block.NoAlignSizeList())); - if (it != std::end(reusable_block.NoAlignSizeList())) { - GE_IF_BOOL_EXEC((continuous && reusable_block.continuous_block_) || (continuous && (real_size < *it)) || - (reusable_block.continuous_block_ && (real_size > *it)), - GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d", - real_size, continuous, *it, reusable_block.continuous_block_); - return false;); - } - } if (reusable_block.Size() == block_size) { can_reuse = true; } else { @@ -427,14 +541,6 @@ bool CanReuseBySize(const map &reusable_block_counts, const Me return can_reuse; } -bool CanReuseByStream(const std::unordered_set &reuse_stream, MemoryBlock &reusable_block) { - bool can_reuse = false; - if (reuse_stream.find(reusable_block.stream_id_) != reuse_stream.cend()) { - can_reuse = true; - } - return can_reuse; -} - bool BlockMemAssigner::IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, uint32_t &peer_input_index) { if (n == nullptr || n->GetAllOutDataAnchors().size() <= 0) { @@ -495,11 +601,11 @@ void BlockMemAssigner::InitReuseFlag() { ge::CONSTANT, ge::CONSTANTOP}; static const std::set kPostReuseTypes = {ge::DATA_TYPE, ge::AIPP_DATA_TYPE, ge::ENTER, ge::REFENTER, ge::NEXTITERATION, ge::REFNEXTITERATION}; - for (auto &pair : symbol_to_anchors_) { + for (const auto &pair : symbol_to_anchors_) { std::string symbol = pair.first; bool pre_reuse_flag = true; bool post_reuse_flag = true; - for (auto &node_index_io : pair.second) { + for (const auto &node_index_io : pair.second) { if (node_index_io.io_type_ == kIn) { continue; } @@ -513,13 +619,13 @@ void BlockMemAssigner::InitReuseFlag() { if (node_index_io.node_->GetOutDataNodes().empty()) { out_flg = true; } - for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + for (const auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { if (IsDirectOutputNode(in_anchor->GetOwnerNode(), in_anchor->GetIdx())) { out_flg = true; break; } } - std::string type = out_anchor->GetOwnerNode()->GetType(); + const std::string &type = out_anchor->GetOwnerNode()->GetType(); pre_reuse_flag = pre_reuse_flag && !out_flg && (kPreReuseTypes.count(type) == 0); post_reuse_flag = post_reuse_flag && (kPostReuseTypes.count(type) == 0); if (!pre_reuse_flag && !post_reuse_flag) { @@ -552,7 +658,7 @@ bool BlockMemAssigner::IsPreReuse(const NodePtr &node, uint32_t out_index) const return false; } - std::string symbol = iter1->second; + const std::string &symbol = iter1->second; auto iter2 = pre_reuse_flag_.find(symbol); if (iter2 == pre_reuse_flag_.end()) { return false; @@ -570,7 +676,7 @@ bool BlockMemAssigner::IsPostReuse(const MemoryBlock *mem_block) const { if (mem_block == nullptr) { return false; } - for (auto &symbol : mem_block->SymbolList()) { + for (const auto &symbol : mem_block->SymbolList()) { auto iter = post_reuse_flag_.find(symbol); if (iter == post_reuse_flag_.end()) { continue; @@ -593,8 +699,7 @@ bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) { if (iter == anchor_to_symbol_.end()) { return false; } - std::string symbol = iter->second; - return symbol_blocks_.find(symbol) != symbol_blocks_.end(); + return symbol_blocks_.find(iter->second) != symbol_blocks_.end(); } /// @@ -603,15 +708,43 @@ bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) { /// @return void /// void BlockMemAssigner::PrintSymbolMap() { - for (auto &pair : symbol_to_anchors_) { + for (const auto &pair : symbol_to_anchors_) { GELOGD("symbol=%s, max_size=%zu, pre_reuse=%s, post_reuse=%s", pair.first.c_str(), symbol_size_[pair.first], pre_reuse_flag_[pair.first] ? "true" : "false", post_reuse_flag_[pair.first] ? "true" : "false"); - for (auto &node_index_io : pair.second) { + for (const auto &node_index_io : pair.second) { GELOGD("anchor:%s", node_index_io.ToString().c_str()); } } } +bool BlockMemAssigner::IsContinuousOutput(const NodePtr &n) { + if (n == nullptr) { + GELOGE(FAILED, "Node is null."); + return false; + } + + // Get the continuous output type of the node, default is false + bool is_output_continuous = false; + auto node_desc = n->GetOpDesc(); + if (node_desc == nullptr) { + GELOGE(FAILED, "Node[%s] nodedesc is null.", n->GetName().c_str()); + return false; + } + + // If GetBool fail, is_output_continuous is false. + (void)ge::AttrUtils::GetBool(node_desc, ATTR_NAME_CONTINUOUS_OUTPUT, is_output_continuous); + if (is_output_continuous) { + if (n->GetOwnerComputeGraph() != nullptr) { + string graph_name = n->GetOwnerComputeGraph()->GetName(); + GELOGI("%s name[%s] set continuous, output size[%u].", graph_name.c_str(), n->GetName().c_str(), + n->GetAllOutDataAnchorsSize()); + return true; + } + } + + return false; +} + MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, MemoryType mem_type, const NodePtr &n, uint32_t out_index, const vector &workspace_reuse_flag, const bool is_op_reuse_mem, @@ -622,15 +755,14 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, bool is_reuse_memory = false; string ge_disable_reuse_mem_env = "0"; - (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env); + (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env); if (ge_disable_reuse_mem_env != "1") { bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && !node_op_desc->HasAttr(kOpNoReuseMem) && reuse_mem_flag && is_op_reuse_mem && (IsPreReuse(n, out_index)); auto stream_id = node_op_desc->GetStreamId(); - auto map_iter = reusable_streams_map_.find(stream_id); - if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { - for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { + if (is_reuse_memory && !continuous) { + for (auto it = reusable_blocks_[stream_id].begin(); it != reusable_blocks_[stream_id].end(); ++it) { MemoryBlock *reusable_block = *it; if (!IsPostReuse(reusable_block)) { reusable_block->reuse_mem_ = false; @@ -639,11 +771,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, } // A node can reuse blocks of the same stream and preorder streams - auto id = GetAtomicAddrCleanId(); - if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous, id) && - CanReuseByStream(map_iter->second, *reusable_block)) { - GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, - stream_id); + if (CanReuseBySize(reusable_block_counts_, *reusable_block, block_size, real_size, continuous)) { reusable_block->AddNodeTypeIndex({n, mem_type, out_index, false}, real_size, no_align_size); if (mem_type == kOutput) { auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); @@ -654,7 +782,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, reusable_block->continuous_block_ = continuous; reusable_block->ref_count_++; ReduceReusableBlockCount(*reusable_block, reusable_block_counts_); - reusable_blocks_.erase(it); + reusable_blocks_[stream_id].erase(it); return reusable_block; } } @@ -683,6 +811,47 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, return block; } +MemoryBlock *BlockMemAssigner::ApplyContinuousMemory(const NodePtr &n, const vector &ranges, + const bool is_op_reuse_mem) { + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "input node is null."); + auto node_op_desc = n->GetOpDesc(); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(node_op_desc == nullptr, return nullptr, "node_op_desc is null."); + MemoryBlock *block = nullptr; + int64_t total_size = 0; + for (uint32_t index = 0; index < static_cast(node_op_desc->GetOutputsSize()); index++) { + auto output_op_desc = node_op_desc->GetOutputDescPtr(index); + if (output_op_desc == nullptr) { + return nullptr; + } + int64_t size = 0; + if (ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS) { + GELOGI("Get size failed"); + return nullptr; + } + size_t align_size = static_cast(size); + AlignMemOffset(align_size); + total_size += align_size; + + // only apply total size in first block + if (index != 0) { + zero_memory_list_.emplace_back(n, kOutput, index); + } + } + + auto block_size = GetBlockSize(total_size, ranges); + GELOGI("Node[%s] continuous out memory size[%ld] block size[%zu]", node_op_desc->GetName().c_str(), total_size, + block_size); + + vector workspace_reuse_flag; + block = ApplyMemory(block_size, total_size, total_size, kOutput, n, 0, workspace_reuse_flag, is_op_reuse_mem, true); + if (block != nullptr) { + // hccl task need align header and tail + block->first_continuous_block_ = true; + block->last_continuous_block_ = true; + } + return block; +} + MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, const vector &ranges, const bool is_op_reuse_mem, const bool continuous) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "input node is null."); @@ -700,7 +869,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, "Get no align size failed"); if (IsSymbolExist(node_index_io)) { - std::string symbol = anchor_to_symbol_[node_index_io.ToString()]; + const std::string &symbol = anchor_to_symbol_[node_index_io.ToString()]; block = symbol_blocks_[symbol]; block->AddNodeTypeIndex({n, kOutput, index, true}, size, no_align_size); block->ref_count_++; @@ -923,7 +1092,11 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_ATOMIC_NODE, is_atomic); // Allocate memory for the current node and release node memory of the same size in the workspace GE_IF_BOOL_EXEC(ge_disable_reuse_mem_env_ != "1", - ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_);) + ReleaseMemorys(stream_workspace_blocks_[stream_id], reusable_blocks_[stream_id]);) + if (IsContinuousOutput(node)) { + (void)ApplyContinuousMemory(node, ranges, is_op_reuse_mem_); + return SUCCESS; + } for (uint32_t i = 0; i < static_cast(op_desc->GetOutputsSize()); i++) { int64_t size = 0; auto output_op_desc = op_desc->GetOutputDescPtr(i); @@ -950,7 +1123,8 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector continue; } // atomic can't be reused - if (is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic) { + bool need_change = is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic; + if (need_change) { is_op_reuse_mem_ = false; } MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, is_op_reuse_mem_, out_node_set_continuous_input); @@ -977,10 +1151,7 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector /// @return Status result /// void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { - // Init reusable streams map - InitReusableStreamMap(); - - (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env_); + (void)ge::GetContext().GetOption(OPTION_EXEC_DISABLE_REUSED_MEMORY, ge_disable_reuse_mem_env_); GEEVENT("Reuse memory %s", ge_disable_reuse_mem_env_ == "1" ? "close" : "open"); string op_no_reuse_mem_str; const char *op_no_reuse_mem = std::getenv(OP_NO_REUSE_MEM); @@ -1033,7 +1204,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block); } - ReleaseInputNodeOutMemory(node_out_blocks_, reusable_blocks_, n); + ReleaseInputNodeOutMemory(node_out_blocks_, reusable_blocks_[stream_id], n); } GELOGD("Assigned memory blocks:"); @@ -1043,8 +1214,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector &ranges) { } bool merge_dynamic_batch = false; - GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks();) - GE_IF_BOOL_EXEC(!merge_dynamic_batch, ReuseBlocksByLifeTime();) + GE_IF_BOOL_EXEC(!(ge_disable_reuse_mem_env_ == "1"), merge_dynamic_batch = MergeDynamicBatchBlocks()); + GE_IF_BOOL_EXEC((!(ge_disable_reuse_mem_env_ == "1") && !merge_dynamic_batch), ReuseBlocksByLifeTime(ranges.size())); AssignContinuousBlocks(); ResizeMemoryBlocks(); @@ -1161,10 +1332,12 @@ static bool CompareBlockIndex(MemoryBlock *left, MemoryBlock *right) { /// @param [in] input blocks need continuous /// @param [out] blocks after continuous order /// @param [in/out] blocks ordered +/// @param [in] input or output /// void ReAssignContinuousBlocks(const std::vector &org_blocks, const std::map block_map, - std::vector &dest_blocks, std::vector &continuous_blocks) { + std::vector &dest_blocks, std::vector &continuous_blocks, + const std::string &type) { for (auto &memory_block : org_blocks) { if (memory_block == nullptr || memory_block->deleted_block_) { continue; @@ -1181,7 +1354,7 @@ void ReAssignContinuousBlocks(const std::vector &org_blocks, for (auto &memory_block : continuous_blocks) { GE_IF_BOOL_EXEC(memory_block == nullptr, continue); - GELOGI("Block continuous input index:%d", memory_block->input_index_); + GELOGI("Block continuous %s index:%d", type.c_str(), memory_block->input_index_); count++; if (count == 1) { memory_block->first_continuous_block_ = true; @@ -1216,22 +1389,37 @@ void BlockMemAssigner::AssignContinuousBlocks() { continuous_block_map.size(), continuous_blocks.size()); continue; } - ReAssignContinuousBlocks(memory_blocks_, continuous_block_map, dest_memory_blocks, continuous_blocks); + ReAssignContinuousBlocks(memory_blocks_, continuous_block_map, dest_memory_blocks, continuous_blocks, "input"); memory_blocks_.swap(dest_memory_blocks); } } -void BlockMemAssigner::ReuseBlocksByLifeTime() { +void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) { + // 1 means block size is same so no need to do this + if (range_size <= 1) { + return; + } for (size_t i = 0; i < memory_blocks_.size(); ++i) { auto parent = memory_blocks_[i]; - if (parent == nullptr || parent->deleted_block_) { + if (parent == nullptr || parent->deleted_block_ || parent->continuous_block_) { continue; } if (parent->reuse_mem_ && !IsPostReuse(parent)) { parent->reuse_mem_ = false; } for (size_t j = i + 1; j < memory_blocks_.size(); ++j) { - parent->AddLifeReuseBlock(memory_blocks_[j]); + auto child = memory_blocks_[j]; + if (child == nullptr) { + continue; + } + // If node is before atomic_addr_clean node, the continus memory can't be reused. + if (!parent->NodeTypeIndexList().empty() && child->continuous_block_) { + auto node = parent->NodeTypeIndexList()[0].node; + if (node == nullptr || node->GetOpDesc() == nullptr || (node->GetOpDesc()->GetId() < GetAtomicAddrCleanId())) { + continue; + } + } + parent->AddLifeReuseBlock(child, total_node_depend_stream_life_); } } } @@ -1318,10 +1506,10 @@ void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, siz } GELOGI( "[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" - " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d] isref[%d].", + " noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d:%d:%d:%d] isref[%d].", graph_name.c_str(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, op_desc->GetStreamId(), block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block, - node_type.ref_input); + block->reuse_mem_, block->continuous_block_, block->deleted_block_, node_type.ref_input); } void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { @@ -1380,143 +1568,10 @@ Status BlockMemAssigner::Assign() { return SUCCESS; } -void BlockMemAssigner::InitReusableStreamMap() { - // save a stream's id and its first Node and last node. - map> stream_head_tail_node_map; - // save a stream's id and its directly child stream. - map> stream_dependency_map; - // save a stream's id and its occupied memory. - unordered_map stream_mem_map; - - // Find streams's first and last node. - FindHeadAndTailNodesForStream(stream_head_tail_node_map, stream_mem_map); - - // If streamB's first node is the output of streamA's last node, then B depends on A. - FindDependentStream(stream_head_tail_node_map, stream_dependency_map); - - // If a stream has more than one child stream, select the one that occupies the closest memory - for (const auto &iter : stream_dependency_map) { - if (iter.second.empty()) { - continue; - } - int64_t target_size = stream_mem_map[iter.first]; - int64_t min_size_gap = LONG_MAX; - int64_t target_reuse_stream_id = 0; - for (auto id : iter.second) { - if (labs(stream_mem_map[id] - target_size) < min_size_gap) { - target_reuse_stream_id = id; - min_size_gap = labs(stream_mem_map[id] - target_size); - } - } - // If b can reuse a, then b should also be able to reuse all blocks that a can reuse. - reusable_streams_map_[target_reuse_stream_id].insert(reusable_streams_map_[iter.first].begin(), - reusable_streams_map_[iter.first].end()); - } -} - -void BlockMemAssigner::FindHeadAndTailNodesForStream(map> &stream_head_tail_node_map, - unordered_map &stream_mem_map) { - for (const auto &n : compute_graph_->GetAllNodes()) { - GE_IF_BOOL_EXEC(n->GetOpDesc() == nullptr, GELOGW("Op desc is nullptr"); continue); - auto stream_id = n->GetOpDesc()->GetStreamId(); - // traverse to find streams's first and last node. - if (stream_head_tail_node_map.find(stream_id) == stream_head_tail_node_map.end()) { - stream_head_tail_node_map[stream_id] = std::make_pair(n, n); - reusable_streams_map_[stream_id].insert(stream_id); // a node can reuse blocks from same stream. - } else { - stream_head_tail_node_map[stream_id].second = n; - } - - // Accumulate the output size of the node in the stream. - for (size_t i = 0; i < n->GetOpDesc()->GetOutputsSize(); i++) { - int64_t size = 0; - if (ge::TensorUtils::GetSize(*n->GetOpDesc()->GetOutputDescPtr(static_cast(i)), size) != SUCCESS) { - GELOGW("Get output size failed!"); - continue; - } - stream_mem_map[stream_id] += size; - } - // Accumulate the workspace size of the node in the stream. - for (auto size : n->GetOpDesc()->GetWorkspaceBytes()) { - stream_mem_map[stream_id] += size; - } - } -} - -void BlockMemAssigner::FindDependentStream(map> &stream_head_tail_node_map, - map> &stream_dependency_map) { - for (const auto &it1 : stream_head_tail_node_map) { - for (const auto &it2 : stream_head_tail_node_map) { - if (it1 == it2) { - continue; - } - NodePtr pre_node = it1.second.second; - NodePtr post_node = it2.second.first; - std::vector out_nodes; - // Direct link out_node - for (const auto &out_node : pre_node->GetOutNodes()) { - if ((out_node->GetOpDesc() == nullptr) || (post_node->GetOpDesc() == nullptr) || - (pre_node->GetOpDesc() == nullptr)) { - continue; - } - out_nodes.emplace_back(out_node); - } - - FindDependentStreamBetweenGraphs(pre_node, out_nodes); - - for (auto &out_node : out_nodes) { - if (out_node->GetOpDesc()->GetId() == post_node->GetOpDesc()->GetId()) { - stream_dependency_map[pre_node->GetOpDesc()->GetStreamId()].insert(post_node->GetOpDesc()->GetStreamId()); - } - } - } - } -} - -/// -/// @ingroup GE -/// @brief Find dependent link between parent_graph and sub_graph -/// @param [in] pre_node -/// @param [out] out_nodes -/// @return void -/// @author -/// -void BlockMemAssigner::FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector &out_nodes) { - if ((pre_node == nullptr) || (pre_node->GetOpDesc() == nullptr)) { - return; - } - - // FunctionOp & subgraph input - std::vector subgraph_names = pre_node->GetOpDesc()->GetSubgraphInstanceNames(); - for (auto &subgraph_name : subgraph_names) { - ComputeGraphPtr subgraph = compute_graph_->GetSubgraph(subgraph_name); - if (subgraph == nullptr) { - continue; - } - for (auto &node : subgraph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - continue; - } - if (op_desc->HasAttr(ATTR_NAME_PARENT_NODE_INDEX)) { - out_nodes.emplace_back(node); - } - } - } - - // subgraph output & parent_node output - if (NodeUtils::IsSubgraphOutput(pre_node)) { - NodePtr parent_node = pre_node->GetOwnerComputeGraph()->GetParentNode(); - for (const auto &out_node : parent_node->GetOutNodes()) { - out_nodes.emplace_back(out_node); - } - } -} - bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || - (node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) || - (node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN) || (node_type == HVDWAIT) || - (node_type == HVDCALLBACKBROADCAST) || (node_type == HVDCALLBACKALLREDUCE); + (node_type == HCOMBROADCAST) || (node_type == CONSTANTOP) || (node_type == ASSIGNADD) || + (node_type == ASSIGNSUB) || (node_type == ASSIGN) || (node_type == HVDWAIT) || + (node_type == HVDCALLBACKBROADCAST); } } // namespace ge diff --git a/src/ge/graph/build/memory/block_mem_assigner.h b/src/ge/graph/build/memory/block_mem_assigner.h index 8ee4506e..3dfba4c5 100644 --- a/src/ge/graph/build/memory/block_mem_assigner.h +++ b/src/ge/graph/build/memory/block_mem_assigner.h @@ -34,6 +34,8 @@ namespace ge { const size_t kMaxLifeTime = 0xffffffff; +using DependStreamLife = std::map>; + enum MemoryType { kOutput, kWorkspace }; struct NodeTypeIndex { @@ -88,6 +90,8 @@ class MemoryBlock { } size_t Size() const { return block_size_; } + size_t AlignSize() const; + void SetHeadOffset(size_t offset); void SetTailOffset(size_t offset); @@ -116,7 +120,9 @@ class MemoryBlock { bool IsSameLabel(std::string &first_batch_label); - void AddLifeReuseBlock(MemoryBlock *block); + void AddContinuousLifeReuseBlock(MemoryBlock *block, DependStreamLife &total_node_depend_stream_life); + + void AddLifeReuseBlock(MemoryBlock *block, DependStreamLife &node_depend_stream_life); void SetLifeTimeEnd(size_t time); @@ -124,6 +130,10 @@ class MemoryBlock { size_t GetLifeEnd(); + void AddDependLifeBegin(DependStreamLife &node_depend_stream_life); + + size_t GetDependLifeBegin(int64_t stream_id, DependStreamLife &node_depend_stream_life); + int ref_count_; int64_t stream_id_; bool deleted_block_; @@ -194,47 +204,6 @@ class BlockMemAssigner : public MemAssigner { void GetNodeWorkSpaceSize(const ge::NodePtr &node, std::vector &workspace_memory); - /// - /// @ingroup GE - /// @brief Traversing the compute_graph_ to find the reuse relationship between streams - /// @param [in] reusable_stream_map map to save stream_id and its reusable stream_ids - /// @return void - /// @author - /// - void InitReusableStreamMap(); - - /// - /// @ingroup GE - /// @brief Traversing the compute_graph_ to find the first and last nodeptr of a stream. - /// @param [in] stream_head_tail_node_map map to save stream_id and its first and last nodeptr. - /// @param [in] stream_mem_map map to save stream_id and its memory capacity. - /// @return void - /// @author - /// - void FindHeadAndTailNodesForStream(std::map> &stream_head_tail_node_map, - std::unordered_map &stream_mem_map); - - /// - /// @ingroup GE - /// @brief Traversing the compute_graph_ to find the reuse relationship between streams. - /// @param [in] stream_head_tail_node_map map to save stream_id and its first and last nodeptr. - /// @param [in] stream_dependency_map map to save stream_id and stream_ids depends on it. - /// @return void - /// @author - /// - void FindDependentStream(std::map> &stream_head_tail_node_map, - std::map> &stream_dependency_map); - - /// - /// @ingroup GE - /// @brief Find dependent link between parent_graph and sub_graph - /// @param [in] pre_node - /// @param [out] out_nodes - /// @return void - /// @author - /// - void FindDependentStreamBetweenGraphs(const NodePtr &pre_node, std::vector &out_nodes); - /// /// @ingroup GE /// @brief Determine whether it is the type of zero memory node. @@ -395,9 +364,13 @@ class BlockMemAssigner : public MemAssigner { /// @return void /// @author /// - void ReuseBlocksByLifeTime(); + void ReuseBlocksByLifeTime(size_t range_size); - std::vector reusable_blocks_; + bool IsContinuousOutput(const NodePtr &n); + + MemoryBlock *ApplyContinuousMemory(const NodePtr &n, const vector &ranges, const bool is_op_reuse_mem); + + std::unordered_map> reusable_blocks_; std::map reusable_block_counts_; @@ -411,9 +384,6 @@ class BlockMemAssigner : public MemAssigner { std::unordered_map node_continuous_input_counts_; - // save stream_id and reusable stream_ids - std::unordered_map> reusable_streams_map_; - // reuse memory vector op_no_reuse_mem_vec_; @@ -426,6 +396,8 @@ class BlockMemAssigner : public MemAssigner { size_t life_time_; int64_t atomic_addr_clean_id_ = 0; + + DependStreamLife total_node_depend_stream_life_; }; } // namespace ge #endif // GE_GRAPH_BUILD_MEMORY_BLOCK_MEM_ASSIGNER_H_ diff --git a/src/ge/graph/build/memory/graph_mem_assigner.cc b/src/ge/graph/build/memory/graph_mem_assigner.cc index c4aca639..c5060dbd 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.cc +++ b/src/ge/graph/build/memory/graph_mem_assigner.cc @@ -222,9 +222,10 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, size_t &mem_offse mem_offset = memory_offset_[0].mem_offset_; - if (mem_offset > VarManager::Instance(0)->GetGraphMemoryMaxSize()) { + auto session_id = compute_graph_->GetSessionID(); + if (mem_offset > VarManager::Instance(session_id)->GetGraphMemoryMaxSize()) { GELOGE(ge::FAILED, "Current memoffset %zu is greater than memory manager malloc max size %zu", mem_offset, - VarManager::Instance(0)->GetGraphMemoryMaxSize()); + VarManager::Instance(session_id)->GetGraphMemoryMaxSize()); return ge::FAILED; } return SUCCESS; @@ -292,7 +293,8 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { } else if (is_loop_graph) { GE_CHK_STATUS_RET(SetLoopGraphAtomicAttr(node, mem_clean_start)); } else { - GE_CHK_STATUS_RET(SetAtomicCleanAttr(nullptr, mem_clean_start, mem_clean_size), "SetAtomicCleanAttr failed."); + GE_CHK_STATUS_RET(SetAtomicCleanAttr(nullptr, {mem_clean_start}, {mem_clean_size}), + "SetAtomicCleanAttr failed."); } } } @@ -440,35 +442,33 @@ Status GraphMemoryAssigner::AssignContinuousOutputMemory(const ge::NodePtr &node GE_IF_BOOL_EXEC(out_op_desc == nullptr, GELOGE(ge::FAILED, "out_op_desc is null."); return ge::FAILED); vector output_list = out_op_desc->GetOutputOffset(); - if (out_op_desc->GetOutputsSize() > output_list.size()) { + if ((out_op_desc->GetOutputsSize() > output_list.size()) || (output_list.size() == 0)) { GELOGE(ge::FAILED, "The size %zu of node output desc is more than output_list's size %zu.", out_op_desc->GetOutputsSize(), output_list.size()); return ge::FAILED; } - memory_offset_[0].mem_offset_ += MEM_ALIGN_SIZE; + size_t mem_offset = output_list[0]; for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { - output_list[out_data_anchor->GetIdx()] = memory_offset_[0].mem_offset_; - size_t pre_mem_offset = memory_offset_[0].mem_offset_; - + output_list[out_data_anchor->GetIdx()] = mem_offset; int64_t tensor_desc_size = 0; if (ge::TensorUtils::GetSize(*(out_op_desc->GetOutputDescPtr(out_data_anchor->GetIdx())), tensor_desc_size) != ge::SUCCESS) { GELOGE(FAILED, "GetSize failed."); return FAILED; } - memory_offset_[0].mem_offset_ += tensor_desc_size; - - AlignMemOffset(MEM_ALIGN_SIZE); + mem_offset += tensor_desc_size; + if (mem_offset <= 0) { + return FAILED; + } + mem_offset = (mem_offset + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; GELOGI( - "[IMAS]Continuous output : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%zu] " + "[IMAS]Continuous output : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%ld] " "real_size[%ld].", node->GetOwnerComputeGraph()->GetName().c_str(), out_op_desc->GetName().c_str(), out_data_anchor->GetIdx(), - pre_mem_offset, out_op_desc->GetStreamId(), (memory_offset_[0].mem_offset_ - pre_mem_offset), tensor_desc_size); + output_list[out_data_anchor->GetIdx()], out_op_desc->GetStreamId(), tensor_desc_size, tensor_desc_size); } - out_op_desc->SetOutputOffset(output_list); - memory_offset_[0].mem_offset_ += MEM_ALIGN_SIZE; return ge::SUCCESS; } @@ -808,14 +808,12 @@ Status GraphMemoryAssigner::ReAssignVirtualNodesMemory(map(memory_offset_[0].mem_offset_); GELOGI("Begin to reAssign atomic memory, atomic initial address mem_offset = %zu!", memory_offset_[0].mem_offset_); + vector connect_netoutput_nodes; for (auto &node : compute_graph_->GetAllNodes()) { auto node_op_desc = node->GetOpDesc(); if (node_op_desc == nullptr) { @@ -838,36 +836,20 @@ Status GraphMemoryAssigner::ReAssignAtomicMemory(bool is_loop_graph) { return ge::PARAM_INVALID; } - // Atomic op memory start addr of loop graph - int64_t loop_graph_atomic_mem_start = static_cast(memory_offset_[0].mem_offset_); - - // Reassign atomic node output memory - Status ret = AssignAtomicOutputMemory(node); - if (ret != SUCCESS) { - GELOGE(ret, "Assign atomic output memory failed, node is %s.", node_op_desc->GetName().c_str()); - return ret; + vector is_connect_netoutput; + // If GetBool fail, attr is_connect_netoutput is an empty vector. + (void)ge::AttrUtils::GetListInt(node_op_desc, ATTR_NAME_NODE_CONNECT_OUTPUT, is_connect_netoutput); + if (!is_connect_netoutput.empty()) { + connect_netoutput_nodes.emplace_back(node); + continue; } - // Check atomic workspace - map> sub_node_workspace_info; - sub_node_workspace_info = node_op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, sub_node_workspace_info); - if (!sub_node_workspace_info.empty()) { - bool is_fusion_node = false; - // If GetBool fail, is_fusion_node is false. - (void)ge::AttrUtils::GetBool(node_op_desc, ATOMIC_ATTR_IS_FUSION_NODE, is_fusion_node); - - if (is_fusion_node) { - // Assign fusion atomic node workspace memory - ret = AssignFusionAtomicWorkspaceMemory(node_op_desc, sub_node_workspace_info); - } else { - // Assign single ordinary atomic node workspace memory, not include fusion node - ret = AssignOrdinaryAtomicWorkspaceMemory(node_op_desc, sub_node_workspace_info); - } - - if (ret != SUCCESS) { - GELOGE(ret, "Assign atomic workspace memory failed, node is %s.", node_op_desc->GetName().c_str()); - return ret; - } + // Atomic op memory start addr of loop graph + int64_t loop_graph_atomic_mem_start = static_cast(memory_offset_[0].mem_offset_); + vector mem_offset_end; + if (AssignAtomicOutputAndWorkspaceMemory(node, mem_offset_end) != SUCCESS) { + GELOGE(FAILED, "Assign atomic output and workspace memory failed, node is %s.", node->GetName().c_str()); + return FAILED; } /// In networks with loop op, atomic op uses atomic_addr_clean op independently, @@ -882,10 +864,77 @@ Status GraphMemoryAssigner::ReAssignAtomicMemory(bool is_loop_graph) { // Set the address attr of atomic clean operator int64_t atomic_mem_size = memory_offset_[0].mem_offset_ - atomic_mem_start; if (atomic_mem_size != 0) { - GE_CHK_STATUS_RET(SetAtomicCleanAttr(nullptr, atomic_mem_start, atomic_mem_size), "SetAtomicCleanAttr failed."); + GE_CHK_STATUS_RET(SetAtomicCleanAttr(nullptr, {atomic_mem_start}, {atomic_mem_size}), + "SetAtomicCleanAttr failed."); } } + if (AssignConnectNetOutputAtomicMemory(connect_netoutput_nodes) != SUCCESS) { + GELOGE(FAILED, "Failed to assign memory of nodes that connect to netoutput."); + return FAILED; + } + + return SUCCESS; +} + +Status GraphMemoryAssigner::AssignAtomicOutputAndWorkspaceMemory(const ge::NodePtr &node, + vector &mem_offset_end) { + auto node_op_desc = node->GetOpDesc(); + // Assign atomic node output memory + Status ret = AssignAtomicOutputMemory(node, mem_offset_end); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to assign atomic output memory, node is %s.", node_op_desc->GetName().c_str()); + return ret; + } + + // Check and assign atomic node workspace memory + map> atomic_workspace_info; + atomic_workspace_info = node_op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, atomic_workspace_info); + if (!atomic_workspace_info.empty()) { + bool is_fusion_node = false; + // If GetBool fail, is_fusion_node is false. + (void)ge::AttrUtils::GetBool(node_op_desc, ATOMIC_ATTR_IS_FUSION_NODE, is_fusion_node); + + if (is_fusion_node) { + // Assign fusion atomic node workspace memory + ret = AssignFusionAtomicWorkspaceMemory(node_op_desc, atomic_workspace_info, mem_offset_end); + } else { + // Assign single ordinary atomic node workspace memory, not include fusion node + ret = AssignOrdinaryAtomicWorkspaceMemory(node_op_desc, atomic_workspace_info, mem_offset_end); + } + if (ret != SUCCESS) { + GELOGE(ret, "Assign atomic workspace memory failed, node is %s.", node_op_desc->GetName().c_str()); + return ret; + } + } + + return SUCCESS; +} + +Status GraphMemoryAssigner::AssignConnectNetOutputAtomicMemory(vector &connect_netoutput_nodes) { + for (auto &node : connect_netoutput_nodes) { + GE_CHECK_NOTNULL(node); + if (node->GetOpDesc() == nullptr) { + GELOGW("Current node %s op desc is nullptr, memory assignment is skipped.", node->GetName().c_str()); + continue; + } + + // Atomic memory start addr + int64_t original_atomic_mem_start = static_cast(memory_offset_[0].mem_offset_); + GELOGD("Start to assign memory of atomic node, node name: %s, node type: %s, mem_offset: %ld.", + node->GetName().c_str(), node->GetOpDesc()->GetType().c_str(), original_atomic_mem_start); + vector mem_offset_end; + if (AssignAtomicOutputAndWorkspaceMemory(node, mem_offset_end) != SUCCESS) { + GELOGE(FAILED, "Assign atomic output and workspace memory failed, node is %s.", node->GetName().c_str()); + return FAILED; + } + + // All atomic nodes use atomic_addr_clean op independently, so we need to set the attr separately. + if (SetIndependentAtomicAttr(node, original_atomic_mem_start, mem_offset_end) != SUCCESS) { + GELOGE(FAILED, "Failed to set atomic attr separately."); + return FAILED; + } + } return SUCCESS; } @@ -970,9 +1019,10 @@ bool GraphMemoryAssigner::CheckInputIsSupportAtomic(const ge::NodePtr &node) { return true; } -Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { +Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node, vector &mem_offset_end) { auto op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(ge::FAILED, "op_desc is null."); return ge::FAILED); + mem_offset_end.clear(); GELOGD("Begin to assign atomic output memory, node = %s.", op_desc->GetName().c_str()); vector atomic_output_index; @@ -995,24 +1045,9 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { // If the input of the cascade op needs to clear the atomic addr, there is no need to clear it separately here bool is_assigned_mem = false; - if (static_cast(output_index) >= node->GetAllOutDataAnchors().size()) { - GELOGE(ge::PARAM_INVALID, "Output index %ld is more than the size of node's AllOutDataAnchors.", output_index); - return ge::PARAM_INVALID; - } - auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index); - GE_CHECK_NOTNULL(out_data_anchor); - auto input_anchors = out_data_anchor->GetPeerInDataAnchors(); - for (auto &input_anchor : input_anchors) { - auto output_node = input_anchor->GetOwnerNode(); - - /// Get input atomic attr of peer output op, if atomic_input_index[0] = -1, indicates that the atomic address - /// has been assigned - vector atomic_input_index; - (void)ge::AttrUtils::GetListInt(output_node->GetOpDesc(), ATOMIC_ATTR_INPUT_INDEX, atomic_input_index); - if (!atomic_input_index.empty() && (atomic_input_index[0] == kAllInputAddrIsAtomic)) { - is_assigned_mem = true; - break; - } + if (GetMemoryAssignmentStatus(node, output_index, is_assigned_mem) != SUCCESS) { + GELOGE(ge::FAILED, "Failed to get memory assignment of node %s.", node->GetName().c_str()); + return ge::FAILED; } // If you have already assigned an atomic address, skip it, and you don't need to reassign it. @@ -1037,6 +1072,7 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { memory_offset_[0].mem_offset_ += size; AlignMemOffset(MEM_ALIGN_SIZE); + mem_offset_end.emplace_back(memory_offset_[0].mem_offset_); } op_desc->SetOutputOffset(output_list); @@ -1044,8 +1080,33 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { return ge::SUCCESS; } +Status GraphMemoryAssigner::GetMemoryAssignmentStatus(const ge::NodePtr &node, int64_t output_index, + bool &is_mem_assigned) { + if (static_cast(output_index) >= node->GetAllOutDataAnchors().size()) { + GELOGE(ge::PARAM_INVALID, "Output index %ld is more than the size of node's AllOutDataAnchors.", output_index); + return ge::PARAM_INVALID; + } + auto out_data_anchor = node->GetAllOutDataAnchors().at(output_index); + GE_CHECK_NOTNULL(out_data_anchor); + auto input_anchors = out_data_anchor->GetPeerInDataAnchors(); + for (auto &input_anchor : input_anchors) { + auto output_node = input_anchor->GetOwnerNode(); + + /// Get input atomic attr of peer output op, if atomic_input_index[0] = -1, indicates that the atomic address + /// has been assigned + vector atomic_input_index; + (void)ge::AttrUtils::GetListInt(output_node->GetOpDesc(), ATOMIC_ATTR_INPUT_INDEX, atomic_input_index); + if (!atomic_input_index.empty() && (atomic_input_index[0] == kAllInputAddrIsAtomic)) { + is_mem_assigned = true; + break; + } + } + return SUCCESS; +} + Status GraphMemoryAssigner::AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDescPtr &op_desc, - map> &workspace_info) { + map> &workspace_info, + vector &mem_offset_end) { GELOGI("Begin to reassign normal atomic memory, node = %s.", op_desc->GetName().c_str()); vector workspace_vector = op_desc->GetWorkspace(); @@ -1077,6 +1138,7 @@ Status GraphMemoryAssigner::AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDesc op_desc->GetStreamId(), workspace_size, workspace_size); memory_offset_[0].mem_offset_ += workspace_size; + mem_offset_end.emplace_back(memory_offset_[0].mem_offset_); } } op_desc->SetWorkspace(workspace_vector); @@ -1085,7 +1147,8 @@ Status GraphMemoryAssigner::AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDesc } Status GraphMemoryAssigner::AssignFusionAtomicWorkspaceMemory(const ge::OpDescPtr &op_desc, - map> &workspace_info) { + map> &workspace_info, + vector &mem_offset_end) { GELOGI("Begin to reassign fusion atomic memory, node = %s.", op_desc->GetName().c_str()); map> sub_node_workspace_offset; @@ -1107,6 +1170,7 @@ Status GraphMemoryAssigner::AssignFusionAtomicWorkspaceMemory(const ge::OpDescPt op_desc->GetStreamId(), workspace_size, workspace_size); memory_offset_[0].mem_offset_ += workspace_size; + mem_offset_end.emplace_back(memory_offset_[0].mem_offset_); index_offset.insert(std::make_pair(workspace_index, workspace_offset)); } sub_node_workspace_offset.insert(std::make_pair(iter.first, index_offset)); @@ -1222,10 +1286,16 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_out_anchor->GetIdx(), input_list.back()); } else { + int64_t output_offset = output_list.at(peer_out_anchor->GetIdx()); + if (peer_out_anchor->GetOwnerNode()->GetType() == CONSTANT) { + GeTensorDesc tensor_desc = tmp_op_desc->GetInputDesc(input_index); + GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, output_offset)); + } + GELOGI("node[%s] input[%d] is set from node[%s] out index[%d] offset[%ld]", tmp_op_desc->GetName().c_str(), input_index, peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), peer_out_anchor->GetIdx(), - output_list.at(peer_out_anchor->GetIdx())); - input_list.emplace_back(output_list.at(peer_out_anchor->GetIdx())); + output_offset); + input_list.emplace_back(output_offset); } } } @@ -1264,7 +1334,7 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { } } } - } else if (node->GetType() == DATA) { + } else if (node->GetType() == DATA_TYPE) { if (UpdateConstArgsOffset(node, input_list) != SUCCESS) { GELOGE(FAILED, "Update data: %s args offset failed.", node->GetName().c_str()); return FAILED; @@ -1280,6 +1350,47 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { return SUCCESS; } +Status GraphMemoryAssigner::SetIndependentAtomicAttr(const ge::NodePtr &node, int64_t atomic_mem_start, + const vector &mem_offset_end) { + GELOGD("Start to set independent atomic attr, atomic_addr_clean memory offset start is %ld", atomic_mem_start); + + // Parsing offset and size vectors + vector memory_offset_start; + vector memory_offset_size; + memory_offset_start.emplace_back(atomic_mem_start); + for (size_t i = 0; i < mem_offset_end.size(); ++i) { + memory_offset_start.emplace_back(mem_offset_end[i]); + // Number 1 means element index + auto size = memory_offset_start[i + 1] - memory_offset_start[i]; + memory_offset_size.emplace_back(size); + } + memory_offset_start.pop_back(); + + const auto &in_control_anchor = node->GetInControlAnchor(); + if (!memory_offset_size.empty() && in_control_anchor != nullptr) { + for (auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + if (peer_out_control_anchor == nullptr) { + continue; + } + auto peer_out_node = peer_out_control_anchor->GetOwnerNode(); + auto peer_out_node_desc = peer_out_node->GetOpDesc(); + if (peer_out_node_desc == nullptr) { + continue; + } + + GELOGD("Current node memory_offset vector size is %zu, node name %s, node type is %s.", memory_offset_size.size(), + peer_out_node_desc->GetName().c_str(), peer_out_node_desc->GetType().c_str()); + if (peer_out_node_desc->GetType() == ATOMICADDRCLEAN) { + if (SetAtomicCleanAttr(peer_out_node, memory_offset_start, memory_offset_size) != SUCCESS) { + GELOGE(FAILED, "Set atomic clean attr failed."); + return FAILED; + } + } + } + } + return SUCCESS; +} + Status GraphMemoryAssigner::SetLoopGraphAtomicAttr(const ge::NodePtr &node, int64_t atomic_mem_start) { // set the address attr of atomic clean operator for loop graph int64_t atomic_mem_size = memory_offset_[0].mem_offset_ - atomic_mem_start; @@ -1301,7 +1412,7 @@ Status GraphMemoryAssigner::SetLoopGraphAtomicAttr(const ge::NodePtr &node, int6 peer_out_node_desc->GetType().c_str()); if (peer_out_node_desc->GetType() == ATOMICADDRCLEAN) { - GE_CHK_STATUS_EXEC(SetAtomicCleanAttr(peer_out_node, atomic_mem_start, atomic_mem_size), + GE_CHK_STATUS_EXEC(SetAtomicCleanAttr(peer_out_node, {atomic_mem_start}, {atomic_mem_size}), GELOGE(FAILED, "SetAtomicCleanAttr failed."); return FAILED); } @@ -1310,8 +1421,8 @@ Status GraphMemoryAssigner::SetLoopGraphAtomicAttr(const ge::NodePtr &node, int6 return SUCCESS; } -ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, int64_t atomic_mem_start, - int64_t atomic_mem_size) { +ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, const vector &atomic_mem_start, + const vector &atomic_mem_size) { for (ge::NodePtr &node : compute_graph_->GetAllNodes()) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); @@ -1320,15 +1431,15 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, int64_t ato ((n == nullptr) && (node_op_desc->GetType() == ATOMICADDRCLEAN))) { vector workspace_vector = node_op_desc->GetWorkspace(); vector workspace_byte_vector = node_op_desc->GetWorkspaceBytes(); - workspace_vector.emplace_back(atomic_mem_start); - workspace_byte_vector.emplace_back(atomic_mem_size); + workspace_vector.insert(workspace_vector.end(), atomic_mem_start.begin(), atomic_mem_start.end()); + workspace_byte_vector.insert(workspace_byte_vector.end(), atomic_mem_size.begin(), atomic_mem_size.end()); node_op_desc->SetWorkspace(workspace_vector); node_op_desc->SetWorkspaceBytes(workspace_byte_vector); std::vector mem_start_vector; // If GetListInt fail, mem_start_vector is empty. (void)ge::AttrUtils::GetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_START, mem_start_vector); - mem_start_vector.emplace_back(atomic_mem_start); + mem_start_vector.insert(mem_start_vector.end(), atomic_mem_start.begin(), atomic_mem_start.end()); GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_START, mem_start_vector), GELOGE(FAILED, "SetListInt failed."); return FAILED); @@ -1336,16 +1447,26 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, int64_t ato std::vector mem_size_vector; // If GetListInt fail, mem_size_vector is empty. (void)ge::AttrUtils::GetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_MEM_SIZE, mem_size_vector); - mem_size_vector.emplace_back(atomic_mem_size); + mem_size_vector.insert(mem_size_vector.end(), atomic_mem_size.begin(), atomic_mem_size.end()); GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_MEM_SIZE, mem_size_vector), GELOGE(FAILED, "SetListInt failed."); return FAILED); - GELOGI( - "[IMAS]SetAtomicCleanAttr : Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%ld] " - "realsize[%ld].", - node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), 0, atomic_mem_start, - node->GetOpDesc()->GetStreamId(), atomic_mem_size, atomic_mem_size); + std::stringstream ss; + for (auto iter : atomic_mem_start) { + ss << iter << " "; + } + string atomic_mem_start_str = ss.str(); + ss.clear(); + ss.str(""); + for (auto iter : atomic_mem_size) { + ss << iter << " "; + } + string atomic_mem_size_str = ss.str(); + + GELOGI("[IMAS]SetAtomicCleanAttr : Set graph[%s] atomic_node[%s] output offset [%s] size[%s] streamid[%ld]", + node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), + atomic_mem_start_str.c_str(), atomic_mem_size_str.c_str(), node->GetOpDesc()->GetStreamId()); } } return SUCCESS; diff --git a/src/ge/graph/build/memory/graph_mem_assigner.h b/src/ge/graph/build/memory/graph_mem_assigner.h index 67008918..afe9a4fa 100644 --- a/src/ge/graph/build/memory/graph_mem_assigner.h +++ b/src/ge/graph/build/memory/graph_mem_assigner.h @@ -147,22 +147,33 @@ class GraphMemoryAssigner { /// bool CheckInputIsSupportAtomic(const ge::NodePtr &node); - ge::Status AssignAtomicOutputMemory(const ge::NodePtr &node); + ge::Status GetMemoryAssignmentStatus(const ge::NodePtr &node, int64_t output_index, bool &is_mem_assigned); + + ge::Status AssignAtomicOutputMemory(const ge::NodePtr &node, std::vector &mem_offset_end); ge::Status AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDescPtr &op_desc, - std::map> &workspace_info); + std::map> &workspace_info, + std::vector &mem_offset_end); ge::Status AssignFusionAtomicWorkspaceMemory(const ge::OpDescPtr &op_desc, - std::map> &workspace_info); + std::map> &workspace_info, + std::vector &mem_offset_end); + + ge::Status AssignAtomicOutputAndWorkspaceMemory(const ge::NodePtr &node, std::vector &mem_offset_end); + ge::Status AssignConnectNetOutputAtomicMemory(vector &connect_netoutput_nodes); + + ge::Status SetIndependentAtomicAttr(const ge::NodePtr &node, int64_t atomic_mem_start, + const std::vector &mem_offset_end); /// /// @brief set loop graph atomic attr - /// @param node + /// @param node, atomic memory assignment start offset /// @param atomic_mem_start: atomic op memory start address /// ge::Status SetLoopGraphAtomicAttr(const ge::NodePtr &node, int64_t atomic_mem_start); - ge::Status SetAtomicCleanAttr(const ge::NodePtr &n, int64_t atomic_mem_start, int64_t atomic_mem_size); + ge::Status SetAtomicCleanAttr(const ge::NodePtr &n, const std::vector &atomic_mem_start, + const std::vector &atomic_mem_size); void AlignMemOffset(const int64_t &mem_align_size); diff --git a/src/ge/graph/build/memory/var_mem_assign_util.cc b/src/ge/graph/build/memory/var_mem_assign_util.cc index 111adc7a..a352cf65 100644 --- a/src/ge/graph/build/memory/var_mem_assign_util.cc +++ b/src/ge/graph/build/memory/var_mem_assign_util.cc @@ -299,21 +299,33 @@ Status VarMemAssignUtil::SetOutTransNodeToAssign(const ge::NodePtr &node, const Status VarMemAssignUtil::AssignMemory2HasRefAttrNode(ge::ComputeGraphPtr &compute_graph) { for (const ge::NodePtr &n : compute_graph->GetAllNodes()) { string ref_var_src_var_name; - GE_CHECK_NOTNULL(n->GetOpDesc()); - bool is_ref = ge::AttrUtils::GetStr(n->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - GE_IF_BOOL_EXEC(is_ref, - GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID()))); + auto op_desc = n->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (uint32_t idx = 0; idx < op_desc->GetOutputsSize(); idx += 1) { + const auto out_desc = op_desc->MutableOutputDesc(idx); + if (ge::AttrUtils::GetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name)) { + GE_CHK_STATUS_RET(AssignData2VarRef(n, ref_var_src_var_name, compute_graph->GetSessionID(), idx)); + } + } } return SUCCESS; } Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, const string &src_var_name, - uint64_t session_id) { - if (!TransOpUtil::IsTransOp(has_ref_attr_node)) { - return SUCCESS; - } + uint64_t session_id, uint32_t out_index) { // Get ref_var_src_var address - ge::NodePtr var_ref_src_var = has_ref_attr_node->GetOwnerComputeGraph()->FindNode(src_var_name); + auto root_graph = GraphUtils::FindRootGraph(has_ref_attr_node->GetOwnerComputeGraph()); + GE_CHECK_NOTNULL(root_graph); + ge::NodePtr var_ref_src_var = root_graph->FindNode(src_var_name); + if (var_ref_src_var == nullptr) { + for (auto sub_graph : root_graph->GetAllSubgraphs()) { + auto node_ptr = sub_graph->FindNode(src_var_name); + if (node_ptr != nullptr) { + var_ref_src_var = node_ptr; + break; + } + } + } GE_IF_BOOL_EXEC(var_ref_src_var == nullptr || var_ref_src_var->GetOpDesc() == nullptr, return FAILED); GeTensorDesc src_tensor_desc = var_ref_src_var->GetOpDesc()->GetOutputDesc(0); uint8_t *dev_ptr = nullptr; @@ -322,14 +334,8 @@ Status VarMemAssignUtil::AssignData2VarRef(const ge::NodePtr &has_ref_attr_node, vector ref_attr_node_output_list = has_ref_attr_node->GetOpDesc()->GetOutputOffset(); GE_CHECK_SIZE(ref_attr_node_output_list.size()); - int out_index = 0; - bool is_get = ge::AttrUtils::GetInt(var_ref_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, out_index); - if (!is_get) { - GELOGI("%s failed to get attr [REF_VAR_PRE_PEER_OUT_INDEX]", var_ref_src_var->GetName().c_str()); - } - - GE_CHK_BOOL_RET_STATUS(static_cast(out_index) < ref_attr_node_output_list.size(), FAILED, - "out_index %d >= ref_attr_node_output_list.size() %zu", out_index, + GE_CHK_BOOL_RET_STATUS(out_index < ref_attr_node_output_list.size(), FAILED, + "out_index %u >= ref_attr_node_output_list.size() %zu", out_index, ref_attr_node_output_list.size()); ref_attr_node_output_list[out_index] = static_cast(reinterpret_cast(dev_ptr)); diff --git a/src/ge/graph/build/memory/var_mem_assign_util.h b/src/ge/graph/build/memory/var_mem_assign_util.h index 036fed06..cb38af29 100644 --- a/src/ge/graph/build/memory/var_mem_assign_util.h +++ b/src/ge/graph/build/memory/var_mem_assign_util.h @@ -46,8 +46,8 @@ class VarMemAssignUtil { static Status DealTransNode(const ge::NodePtr &final_trans_node); static Status DealExportTransNode(const ge::NodePtr &node, const ge::NodePtr &final_trans_node); - static Status AssignData2VarRef(const ge::NodePtr &variable_ref, const std::string &src_var_name, - uint64_t session_id); + static Status AssignData2VarRef(const ge::NodePtr &variable_ref, const std::string &src_var_name, uint64_t session_id, + uint32_t out_index); static Status SetOutTransNodeToAssign(const ge::NodePtr &node, const ge::NodePtr &final_trans_node, size_t index); }; diff --git a/src/ge/graph/build/model_builder.cc b/src/ge/graph/build/model_builder.cc index 62abd4ab..5435eb7b 100644 --- a/src/ge/graph/build/model_builder.cc +++ b/src/ge/graph/build/model_builder.cc @@ -15,10 +15,10 @@ */ #include "graph/build/model_builder.h" +#include #include #include #include -#include #include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "graph/anchor.h" @@ -27,6 +27,7 @@ #include "graph/build/label_allocator.h" #include "graph/build/stream_allocator.h" #include "graph/common/omg_util.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_attr_value.h" #include "graph/ge_context.h" @@ -41,10 +42,12 @@ #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/passes/memcpy_addr_async_pass.h" #include "init/gelib.h" #include "memory/memory_assigner.h" #include "omg/version.h" #include "register/op_registry.h" +#include "graph/passes/set_input_output_offset_pass.h" using std::map; using std::set; @@ -85,9 +88,11 @@ bool IsGeLocalOp(const ge::ConstOpDescPtr &op_desc) { } // namespace namespace ge { -ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, - const map &stream_max_parallel_num, bool hcom_parallel, int mode) - : mem_offset_(0), +ModelBuilder::ModelBuilder(uint64_t session_id, ge::ComputeGraphPtr compute_graph, + const Graph2SubGraphInfoList &subgraphs, const map &stream_max_parallel_num, + bool hcom_parallel, int mode) + : session_id_(session_id), + mem_offset_(0), weight_offset_(kWeightsStartOffset), compute_graph_(std::move(compute_graph)), subgraphs_(subgraphs), @@ -242,7 +247,7 @@ Status ModelBuilder::SetInputOutputDesc() { Status ret; GELOGI("Start to SetInputOutputDesc."); - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); @@ -291,7 +296,7 @@ Status ModelBuilder::SetInputOutputDesc() { } void ModelBuilder::AddNodeInputProperty() { - for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); vector src_name_list; @@ -318,7 +323,7 @@ void ModelBuilder::AddNodeInputProperty() { node_op_desc->SetSrcIndex(src_index_list); } - for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, GELOGW("node_op_desc is nullptr!"); return ); GE_IF_BOOL_EXEC(node_op_desc->GetType() == NETOUTPUT, continue); @@ -356,7 +361,7 @@ void ModelBuilder::AddNodeInputProperty() { Status ModelBuilder::AdjustInputTensorFlag() { GELOGI("Start to AdjustInputTensorFlag."); - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { if ((n->GetType() == DATA_TYPE) || (n->GetType() == AIPP_DATA_TYPE)) { GELOGD("Data node: %s.", n->GetName().c_str()); for (const auto &anchor : n->GetAllOutDataAnchors()) { @@ -432,6 +437,21 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(&model, ATTR_NAME_SWITCH_FOR_L1_FUSION, is_l1_fusion_enable_), GELOGE(FAILED, "SetBool of ATTR_NAME_SWITCH_FOR_L1_FUSION failed."); return FAILED); + const DumpProperties &dump_properties = PropertiesManager::Instance().GetDumpProperties(session_id_); + bool is_op_debug = dump_properties.IsOpDebugOpen(); + GELOGI("Get op debug:%d", is_op_debug); + if (is_op_debug) { + if (!ge::AttrUtils::SetBool(&model, ATTR_OP_DEBUG_FLAG, is_op_debug)) { + GELOGE(FAILED, "SetBool of ATTR_OP_DEBUG_FLAG failed."); + return FAILED; + } + uint32_t op_debug_mode = dump_properties.GetOpDebugMode(); + GELOGI("Get op debug mode:%d", op_debug_mode); + if (!ge::AttrUtils::SetInt(&model, ATTR_OP_DEBUG_MODE, op_debug_mode)) { + GELOGE(FAILED, "SetBool of ATTR_OP_DEBUG_MODE failed."); + return FAILED; + } + } model.SetName(compute_graph_->GetName()); model.SetGraph(ge::GraphUtils::CreateGraphFromComputeGraph(compute_graph_)); @@ -448,7 +468,7 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { } void ModelBuilder::ClearOriginalFormat() { - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = n->GetOpDesc(); if (node_op_desc != nullptr) { if (node_op_desc->HasAttr(ATTR_NAME_FORMAT)) { @@ -487,7 +507,7 @@ Status ModelBuilder::MergeWeights() { weight_buffer_ = buffer; auto base_addr = weight_buffer_.GetData(); - for (const ge::NodePtr &node : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, continue); if (node->GetType() != CONSTANT) { @@ -514,7 +534,7 @@ Status ModelBuilder::MergeWeights() { auto weight_data = weight->MutableData(); // copy const op weight data to buffer - GELOGI("Move weight data to buffer, name: %s offset: %ld", node->GetName().c_str(), offset); + GELOGI("Move to buffer, name: %s offset: %ld size: %zu", node->GetName().c_str(), offset, weight_data.size()); ge::TensorUtils::SetWeightSize(weight->MutableTensorDesc(), static_cast(weight_data.size())); if ((offset == 0) || (weight_data.size() == 0)) { GELOGI("Size or offset is 0. size: %lu offset: %ld", weight_data.size(), offset); @@ -527,8 +547,8 @@ Status ModelBuilder::MergeWeights() { weight_data.size()); return FAILED; } - uintptr_t dst_ptr = (uintptr_t)base_addr + offset; - uintptr_t src_ptr = (uintptr_t)weight_data.data(); + uintptr_t dst_ptr = reinterpret_cast(base_addr) + offset; + uintptr_t src_ptr = reinterpret_cast(weight_data.data()); size_t left_size = weight_data.size(); while (left_size > SECUREC_MEM_MAX_LEN) { auto err = memcpy_s(reinterpret_cast(dst_ptr), SECUREC_MEM_MAX_LEN, reinterpret_cast(src_ptr), @@ -565,7 +585,7 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { // Add TBE Kernels std::set name_set; - for (const ge::NodePtr &n : compute_graph_->GetAllNodes()) { + for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto node_op_desc = n->GetOpDesc(); GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); @@ -650,16 +670,40 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { GE_CHK_STATUS_RET(label_allocator.AssignFunctionalLabels(label_num_), "Assign label failed."); GE_TIMESTAMP_END(AssignFunctionalLabels, "ModelBuilder::AssignFunctionalLabels"); + // Add memcpy_addr_async node. + rtFeatureType_t feature_type = FEATURE_TYPE_MEMCPY; + int32_t feature_info = MEMCPY_INFO_SUPPORT_ZEROCOPY; + int64_t value = 0; + rtError_t rt_ret = rtGetRtCapability(feature_type, feature_info, &value); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtGetRtCapability failed."); + return RT_FAILED; + } else { + if (value == RT_CAPABILITY_SUPPORT) { + GE_TIMESTAMP_START(AddMemcpyAddrAsyncNode); + MemcpyAddrAsyncPass memcpy_addr; + GE_CHK_STATUS_RET(memcpy_addr.Run(compute_graph_), "Add memcpy_addr_async node failed."); + GE_TIMESTAMP_END(AddMemcpyAddrAsyncNode, "MemcpyAddrAsyncPass::Run."); + } else { + GELOGW("rtGetRtCapability not support memcpy_addr_async."); + } + } + GE_TIMESTAMP_START(AssignMemory); MemoryAssigner mem_assigner(compute_graph_); GE_CHK_STATUS_RET(mem_assigner.AssignMemory(is_loop_graph_, mem_offset_, zero_copy_mem_size_), "Assign Memory Failed!"); GE_TIMESTAMP_END(AssignMemory, "GraphBuilder::AssignMemory"); + GE_TIMESTAMP_START(SetInputOutputOffset); + SetInputOutputOffsetPass input_output_offset; + GE_CHK_STATUS_RET(input_output_offset.Run(compute_graph_), "Set input output offset failed."); + GE_TIMESTAMP_END(SetInputOutputOffset, "SetInputOutputOffsetPass::Run."); + // Compile single op in graph build stage GE_TIMESTAMP_START(CompileSingleOp); GE_CHK_STATUS_RET(CompileSingleOp(), "ATC builder CompileSingleOp() return fail."); - GE_TIMESTAMP_END(CompileSingleOp, "GraphBuilder::CompileSingleOp"); + GE_TIMESTAMP_EVENT_END(CompileSingleOp, "GraphBuilder::CompileSingleOp"); // Refresh real streams and insert event nodes. GE_TIMESTAMP_START(RefreshRealStream); @@ -700,7 +744,7 @@ Status ModelBuilder::CompileSingleOp() { GE_TIMESTAMP_CALLNUM_START(BatchCompileOp); std::unordered_map> node_vector_map; - for (auto &node : compute_graph_->GetAllNodes()) { + for (auto &node : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { auto op_desc = node->GetOpDesc(); if (op_desc == nullptr) { continue; @@ -737,7 +781,7 @@ Status ModelBuilder::CompileSingleOp() { GE_CHECK_NOTNULL(kernel_info); GE_TIMESTAMP_RESTART(BatchCompileOp); auto ret = kernel_info->CompileOp(node_vector); - GEEVENT("[GEPERFTRACE] The node size of compile op of %s is %zu", kernel_lib_name.c_str(), node_vector.size()); + GELOGI("[GEPERFTRACE] The node size of compile op of %s is %zu", kernel_lib_name.c_str(), node_vector.size()); GE_TIMESTAMP_ADD(BatchCompileOp); if (ret != ge::SUCCESS) { GELOGE(ret, "Compile op failed, kernel lib name is %s", kernel_lib_name.c_str()); diff --git a/src/ge/graph/build/model_builder.h b/src/ge/graph/build/model_builder.h index 21e611ee..86b34c6d 100644 --- a/src/ge/graph/build/model_builder.h +++ b/src/ge/graph/build/model_builder.h @@ -37,7 +37,7 @@ namespace ge { class ModelBuilder { public: - ModelBuilder(ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, + ModelBuilder(uint64_t session_id, ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, const std::map &stream_max_parallel_num, bool hcom_parallel, int mode = static_cast(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); @@ -82,6 +82,8 @@ class ModelBuilder { Status CompileSingleOp(); + uint64_t session_id_; + size_t mem_offset_; size_t weight_offset_; diff --git a/src/ge/graph/build/run_context.cc b/src/ge/graph/build/run_context.cc index f2a41271..cece31ea 100644 --- a/src/ge/graph/build/run_context.cc +++ b/src/ge/graph/build/run_context.cc @@ -173,5 +173,4 @@ Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &gra } RunContext &RunContextUtil::GetRunContext() { return run_context_; } - } // namespace ge diff --git a/src/ge/graph/build/stream_allocator.cc b/src/ge/graph/build/stream_allocator.cc index f6323434..5c82f461 100644 --- a/src/ge/graph/build/stream_allocator.cc +++ b/src/ge/graph/build/stream_allocator.cc @@ -146,12 +146,6 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu return status; } - status = AddActiveEntryStream(); - if (status != SUCCESS) { - GELOGE(status, "AddActiveEntryStream failed!"); - return status; - } - status = RefreshContinuousEvents(); if (status != SUCCESS) { GELOGE(status, "RefreshContinuousEvents failed!"); @@ -167,7 +161,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu DumpEvents(); GE_DUMP(whole_graph_, "AfterRefreshRealStream"); - for (const NodePtr &node : whole_graph_->GetAllNodes()) { + for (const NodePtr &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); auto stream_id = node->GetOpDesc()->GetStreamId(); if (stream_id == kInvalidStream) { @@ -199,7 +193,7 @@ Status StreamAllocator::AssignSingleStream() { } int64_t task_count = 0; - for (const NodePtr &node : whole_graph_->GetAllNodes()) { + for (const NodePtr &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { string op_type = node->GetType(); if (IsHcclOp(op_type)) { task_count += kTaskNumPerHcclNode; @@ -236,7 +230,7 @@ Status StreamAllocator::AssignSingleStream() { } Status StreamAllocator::SetActiveStreamsByLabel() { - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); string stream_label; @@ -248,7 +242,7 @@ Status StreamAllocator::SetActiveStreamsByLabel() { } } - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); vector activated_label_list; if (!AttrUtils::GetListStr(node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, activated_label_list) || @@ -326,7 +320,7 @@ Status StreamAllocator::SetActiveStreamsForSubgraphs() { // Insert the send/recv event id to the graph Status StreamAllocator::InsertSyncEvents() { - for (const auto &cur_node : whole_graph_->GetAllNodes()) { + for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { // Take the adjacent points, then judge whether need to insert the event for (const OutDataAnchorPtr &anchor : cur_node->GetAllOutDataAnchors()) { for (const InDataAnchorPtr &peer_in_anchor : anchor->GetPeerInDataAnchors()) { @@ -380,6 +374,11 @@ Status StreamAllocator::InsertOneEventInTwoNodes(const NodePtr &cur_node, const return SUCCESS; } + if ((cur_node->GetType() == ENTER) || (cur_node->GetType() == REFENTER)) { + GELOGD("No need to insert event after enter_node %s.", cur_node->GetName().c_str()); + return SUCCESS; + } + if (next_stream_id == kInvalidStream) { GELOGE(FAILED, "Stream id of next_node %s should not be %ld", next_node->GetName().c_str(), kInvalidStream); return FAILED; @@ -446,7 +445,7 @@ Status StreamAllocator::InsertEventsForSubgraph() { Status StreamAllocator::OptimizeSyncEvents() { map> stream_nodes; - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); int64_t stream_id = node->GetOpDesc()->GetStreamId(); stream_nodes[stream_id].emplace_back(node); @@ -613,6 +612,33 @@ bool StreamAllocator::IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr AttrUtils::HasAttr(activate_stream_node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE)) { return false; } + + /// + /// stream_0 --> stream_2 --> stream_3 --> stream_4 + /// /\ | + /// | \/ + /// | stream_1 --> stream_5 --> stream_6 --> stream_7 + /// | /\ | | + /// | | \/ | + /// | |---------- stream_8 | + /// | | + /// |-----------------------------------------------------------| + /// + /// Exit1(S7) Exit2(S7) Exit3(S7) + /// \ / | + /// AddN(S1) NextIteration(S7) + /// | | + /// NextIteration(S1) / + /// | / + /// | / + /// StreamActive(S7) + /// + /// Event between Exit1/Exit2 and AddN should not be optimized + /// + if (IsActiveAfterNextIteration(activate_stream_node)) { + continue; + } + visited_nodes.insert(activate_stream_node); // nodes in stream link to streamActivate no need to add event before activated node for (const auto &pre_activate_stream_node : activate_stream_node->GetInNodes()) { @@ -640,6 +666,18 @@ bool StreamAllocator::IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr return false; } +bool StreamAllocator::IsActiveAfterNextIteration(const NodePtr &active_node_ptr) const { + if ((active_node_ptr == nullptr) || active_node_ptr->GetInControlNodes().empty()) { + return false; + } + for (const auto &in_node : active_node_ptr->GetInControlNodes()) { + if ((in_node->GetType() != NEXTITERATION) && (in_node->GetType() != REFNEXTITERATION)) { + return false; + } + } + return true; +} + // Split the stream according to the maximum number of nodes in the stream. Status StreamAllocator::SplitStreams(vector> &split_streams) { if (enable_single_stream_ || stream_num_ == 0) { @@ -671,7 +709,7 @@ Status StreamAllocator::SplitStreams(vector> &split_streams) { GE_CHK_STATUS_RET(GetMaxStreamAndTask(false, max_stream_count, max_task_count), "Get max stream and task count failed."); - for (const auto &cur_node : whole_graph_->GetAllNodes()) { + for (const auto &cur_node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(cur_node); auto op_desc = cur_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -774,42 +812,23 @@ bool StreamAllocator::NeedSpiltNewStream(int64_t stream_node_num, int64_t max_no Status StreamAllocator::UpdateActiveStreams(const vector> &split_streams) { UpdateLabelStreams(split_streams); - for (auto &node : whole_graph_->GetAllNodes()) { + for (auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { if ((node->GetType() == STREAMSWITCH) || (node->GetType() == STREAMSWITCHN)) { - if (InsertActiveNodesAfterSwitch(node) != SUCCESS) { - GELOGE(FAILED, "Insert active nodes after switch node failed."); + if (UpdateActiveStreamsForSwitchNode(node) != SUCCESS) { + GELOGE(FAILED, "Update active streams for switch node: %s failed.", node->GetName().c_str()); return FAILED; } } else { - vector active_streams; - GE_CHECK_NOTNULL(node->GetOpDesc()); - if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { - vector new_active_streams = active_streams; - for (const uint32_t logical_stream : active_streams) { - if (static_cast(logical_stream) >= split_streams.size()) { - GELOGE(FAILED, "logical stream is out of range."); - return FAILED; - } - const set &new_split_streams = split_streams[logical_stream]; - if (!new_split_streams.empty()) { - for (int64_t split_stream : new_split_streams) { - new_active_streams.emplace_back(static_cast(split_stream)); - GELOGI("Add stream %ld to active_stream_list of node %s of graph %s", split_stream, - node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str()); - } - } - } - if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { - GELOGE(FAILED, "Set active streams for node %s failed.", node->GetName().c_str()); - return FAILED; - } + if (UpdateActiveStreamsForActiveNode(split_streams, node) != SUCCESS) { + GELOGE(FAILED, "Update active streams for active node: %s failed.", node->GetName().c_str()); + return FAILED; } } } Status status = UpdateActiveStreamsForSubgraphs(); if (status != SUCCESS) { - GELOGE(status, "SetActiveStreamsForSubgraph failed!"); + GELOGE(status, "Update active streams for subgraphs failed!"); return status; } @@ -840,7 +859,7 @@ void StreamAllocator::UpdateLabelStreams(const vector> &split_strea } } -Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node) { +Status StreamAllocator::UpdateActiveStreamsForSwitchNode(NodePtr &switch_node) { vector active_nodes; if (InsertActiveNodesAfterSwitch(switch_node, active_nodes) != SUCCESS) { GELOGE(FAILED, "Insert active nodes after node %s failed.", switch_node->GetName().c_str()); @@ -906,6 +925,38 @@ Status StreamAllocator::InsertActiveNodesAfterSwitch(NodePtr &switch_node, vecto return SUCCESS; } +Status StreamAllocator::UpdateActiveStreamsForActiveNode(const vector> &split_streams, NodePtr &node) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + vector active_streams; + if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + vector new_active_streams = active_streams; + for (uint32_t logical_stream : active_streams) { + if (static_cast(logical_stream) >= split_streams.size()) { + GELOGE(FAILED, "logical stream is out of range."); + return FAILED; + } + const set &new_split_streams = split_streams[logical_stream]; + for (int64_t split_stream : new_split_streams) { + for (const auto &node_stream : node_split_stream_map_) { + if (split_stream == node_stream.second) { + if (node_stream.first->GetOwnerComputeGraph() == node->GetOwnerComputeGraph()) { + new_active_streams.emplace_back(static_cast(split_stream)); + GELOGI("Add stream %ld to active_stream_list of node %s of graph %s", split_stream, + node->GetName().c_str(), node->GetOwnerComputeGraph()->GetName().c_str()); + } + break; + } + } + } + } + if (!AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, new_active_streams)) { + GELOGE(FAILED, "Set active streams for node %s failed.", node->GetName().c_str()); + return FAILED; + } + } + return SUCCESS; +} + Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { // Update active stream list for active nodes for (auto &node_stream_pair : node_split_stream_map_) { @@ -926,14 +977,19 @@ Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { } const auto &active_node = it->second; GE_CHECK_NOTNULL(active_node); - auto op_desc = active_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); + auto active_op = active_node->GetOpDesc(); + GE_CHECK_NOTNULL(active_op); vector active_streams; - (void)AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); + (void)AttrUtils::GetListInt(active_op, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); set new_active_streams(active_streams.begin(), active_streams.end()); - new_active_streams.emplace(static_cast(node_stream_pair.second)); + // specific_activated_streams_ has already contained new split activated stream + int64_t new_split_stream = node_stream_pair.second; + if (IsActivated(new_split_stream)) { + continue; + } + new_active_streams.emplace(static_cast(new_split_stream)); active_streams.assign(new_active_streams.begin(), new_active_streams.end()); - if (!AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + if (!AttrUtils::SetListInt(active_op, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { GELOGE(FAILED, "Set active streams for node %s failed.", active_node->GetName().c_str()); return FAILED; } @@ -942,6 +998,20 @@ Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { return SUCCESS; } +bool StreamAllocator::IsActivated(int64_t stream_id) const { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { + auto op_desc = node->GetOpDesc(); + vector active_streams; + if (op_desc == nullptr || !AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { + continue; + } + if (std::find(active_streams.begin(), active_streams.end(), stream_id) != active_streams.end()) { + return true; + } + } + return false; +} + Status StreamAllocator::SetActiveStreamsForLoop() { vector loop_active_streams; for (int64_t stream_id = 0; stream_id < stream_num_; stream_id++) { @@ -950,7 +1020,7 @@ Status StreamAllocator::SetActiveStreamsForLoop() { } } // Set the stream that needs to be activated - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); bool is_loop_active = false; if (AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_IS_LOOP_ACTIVE, is_loop_active) && is_loop_active) { @@ -973,7 +1043,7 @@ Status StreamAllocator::SetActiveStreamsForLoop() { } Status StreamAllocator::CheckStreamActived() const { - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_CHECK_NOTNULL(node->GetOpDesc()); vector active_streams; if (AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { @@ -989,108 +1059,6 @@ Status StreamAllocator::CheckStreamActived() const { return SUCCESS; } -// Add active entry stream for special env. -Status StreamAllocator::AddActiveEntryStream() { - auto gelib = GELib::GetInstance(); - bool head_stream = (gelib == nullptr) ? false : gelib->HeadStream(); - GELOGI("Configured head stream: %u", head_stream); - if (!head_stream) { - return SUCCESS; - } - - // Collect streams active by StreamSwitch/StreamActive node. - std::set deactive_stream; - for (ge::NodePtr &node : whole_graph_->GetAllNodes()) { - GE_CHECK_NOTNULL(node->GetOpDesc()); - Status ret = CollectDeactiveStream(node->GetOpDesc(), deactive_stream); - if (ret != SUCCESS) { - return ret; - } - } - - // Collect default active stream, Add to active entry stream. - std::vector active_stream_list; - for (int64_t stream_id = 0; stream_id < stream_num_; ++stream_id) { - if (deactive_stream.count(stream_id) == 0) { - active_stream_list.push_back(stream_id); - } - } - - int64_t new_stream_id = stream_num_; - stream_num_++; - return InsertActiveEntryStream(active_stream_list, new_stream_id); -} - -// Collect deactive stream from flowctrl op. -Status StreamAllocator::CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const { - GE_CHECK_NOTNULL(op_desc); - std::string op_type = op_desc->GetType(); - if (op_type == STREAMSWITCH) { - std::vector active_stream_list; - // If GetListInt fail, active_stream_list is empty. - (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list); - if (active_stream_list.size() != kMaxSwitchStreamNum) { - GELOGE(INTERNAL_ERROR, "Stream num of switch true branch must be %u.", kMaxSwitchStreamNum); - return INTERNAL_ERROR; - } - - deactive_streams.insert(active_stream_list[0]); - GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), active_stream_list[0]); - } else if (op_type == STREAMACTIVE) { - if (op_desc->HasAttr(ATTR_NAME_SWITCH_BRANCH_NODE_LABEL)) { - std::vector active_stream_list; - if (!AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_stream_list)) { - GELOGE(INTERNAL_ERROR, "StreamActiveOp get attr ACTIVE_STREAM fail."); - return INTERNAL_ERROR; - } - - for (uint32_t deactive_stream : active_stream_list) { - deactive_streams.insert(deactive_stream); - GELOGI("Flowctrl_op node:%s, flowctrl stream id:%u.", op_desc->GetName().c_str(), deactive_stream); - } - } - } - - return SUCCESS; -} - -// Insert StreamActive Op for Entry Stream. -Status StreamAllocator::InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id) { - string node_name = whole_graph_->GetName() + "_ActiveEntryStream_" + string(STREAMACTIVE); - OpDescPtr op_desc = ge::MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Failed to new opdesc."); - return FAILED; - } - GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); - - GE_CHK_BOOL_EXEC( - AttrUtils::SetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())), - GELOGE(FAILED, "SetListStr failed."); - return FAILED); - - NodePtr active_node = whole_graph_->AddNodeFront(op_desc); - GE_IF_BOOL_EXEC(active_node == nullptr, - GELOGE(FAILED, "Create StreamActive op: %s failed.", op_desc->GetName().c_str()); - return INTERNAL_ERROR); - GE_CHECK_NOTNULL(active_node->GetOpDesc()); - // Add one stream for ActiveEntryStream Task. - active_node->GetOpDesc()->SetStreamId(stream_id); - - GE_CHK_BOOL_EXEC(AttrUtils::SetBool(op_desc, "is_aicpu_stream", true), GELOGE(FAILED, "SetBool failed."); - return FAILED); - GE_CHK_BOOL_EXEC(AttrUtils::SetListInt(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_STREAM_LIST, active_streams), - GELOGE(FAILED, "SetListInt failed."); - return FAILED); - - std::vector group_names; - GE_CHK_BOOL_EXEC(AttrUtils::SetListStr(active_node->GetOpDesc(), ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, group_names), - GELOGE(FAILED, "SetLisStr failed."); - return FAILED); - - return SUCCESS; -} - // Refresh events to continuous events Status StreamAllocator::RefreshContinuousEvents() { // Establish a mapping relationship from old to new event id @@ -1136,7 +1104,7 @@ Status StreamAllocator::RefreshContinuousEvents() { // Insert the real send/recv node in the graph Status StreamAllocator::InsertSyncEventNodes() { - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { // Add the node corresponding to the recv event vector recv_event_id_list; GetRecvEventIdList(node, recv_event_id_list); @@ -1223,7 +1191,7 @@ Status StreamAllocator::ReorderEventNodes() const { void StreamAllocator::DumpEvents() { map> after_refresh_stream_nodes; - for (const auto &node : whole_graph_->GetAllNodes()) { + for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); int64_t stream_id = node->GetOpDesc()->GetStreamId(); after_refresh_stream_nodes[stream_id].emplace_back(node); diff --git a/src/ge/graph/build/stream_allocator.h b/src/ge/graph/build/stream_allocator.h index ae79430a..a5326a39 100644 --- a/src/ge/graph/build/stream_allocator.h +++ b/src/ge/graph/build/stream_allocator.h @@ -55,22 +55,21 @@ class StreamAllocator { Status OptimizeByStreamActivate(); // Determine if the successor node of RecvNode is directly or indirectly activated by the SendNode precursor node bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; + bool IsActiveAfterNextIteration(const NodePtr &active_node_ptr) const; Status SplitStreams(std::vector> &split_streams); bool NeedSpiltNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, const OpDescPtr &op_desc) const; - Status UpdateActiveStreams(const std::vector> &splited_streams); + Status UpdateActiveStreams(const std::vector> &split_streams); void UpdateLabelStreams(const std::vector> &split_streams); - Status InsertActiveNodesAfterSwitch(NodePtr &switch_node); + Status UpdateActiveStreamsForSwitchNode(NodePtr &switch_node); Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector &switch_active_nodes); + Status UpdateActiveStreamsForActiveNode(const std::vector> &split_streams, NodePtr &node); Status UpdateActiveStreamsForSubgraphs() const; + bool IsActivated(int64_t stream_id) const; Status SetActiveStreamsForLoop(); Status CheckStreamActived() const; - Status AddActiveEntryStream(); - Status CollectDeactiveStream(const OpDescPtr &op_desc, std::set &deactive_streams) const; - Status InsertActiveEntryStream(const std::vector &active_streams, int64_t stream_id); - Status RefreshContinuousEvents(); Status InsertSyncEventNodes(); diff --git a/src/ge/graph/build/task_generator.cc b/src/ge/graph/build/task_generator.cc index 2ce4e89d..41a845a2 100644 --- a/src/ge/graph/build/task_generator.cc +++ b/src/ge/graph/build/task_generator.cc @@ -29,6 +29,7 @@ #include "graph/utils/node_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "init/gelib.h" using domi::LogTimeStampDef; @@ -47,7 +48,6 @@ const char *const kIsOutputVar = "OUTPUT_IS_VAR"; const char *const kProfilingMode = "PROFILING_MODE"; const char *const kProfilingFpPoint = "FP_POINT"; const char *const kProfilingBpPoint = "BP_POINT"; -const char *const kOffOptimize = "off_optimize"; const uint32_t kProfilingArStep = 2; const uint64_t kProfilingFpStartLogid = 1; const uint64_t kProfilingBpEndLogid = 2; @@ -75,21 +75,7 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t std::vector task_def_list; std::map op_name_map; GE_DUMP(graph, "GenerateTaskBefore"); - bool is_unknown_shape = false; - NodePtr parent_node = graph->GetParentNode(); - if (parent_node != nullptr) { - auto op_desc = parent_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - (void)AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); - } - Status ret = SUCCESS; - if (is_unknown_shape) { - GELOGI("Beign to generate unknown shape task. Graph name is %s.", graph->GetName().c_str()); - ret = GenerateUnknownShapeTask(run_context, graph, task_def_list, op_name_map); - } else { - GELOGI("Beign to generate known shape task. Graph name is %s.", graph->GetName().c_str()); - ret = GenerateTask(run_context, graph, task_def_list, op_name_map); - } + Status ret = GenerateTask(run_context, graph, task_def_list, op_name_map); GE_DUMP(graph, "GenerateTaskAfter"); if (ret != SUCCESS) { @@ -109,7 +95,7 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t GELOGE(FAILED, "SetListStr failed."); return FAILED); - GELOGI("Generate task success, task_def_list.size:%zu, op_name_map.size:%zu", task_def_list.size(), + GELOGI("Call GenerateTask Success, task_def_list.size:%zu, op_name_map.size:%zu", task_def_list.size(), op_name_map.size()); // Init and serialize model_task_def @@ -131,7 +117,7 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t return ret; } - GELOGI("Get TaskInfo success. session id is %lu", session_id); + GELOGI("Get TaskInfo success. session_id=%lu", session_id); return SUCCESS; } @@ -198,7 +184,7 @@ Status TaskGenerator::UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t sessi Status TaskGenerator::SaveFusionNodes(map> &fusion_nodes, ComputeGraphPtr &graph) { std::map nodes_with_group_attr; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); int64_t group_id = kInvalidGroupId; @@ -249,12 +235,13 @@ Status TaskGenerator::SaveFusionNodes(map> &fusion Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, vector &task_def_list, map &op_name_map) { + GELOGD("Beign to generate task, graph name is %s.", graph->GetName().c_str()); std::shared_ptr ge_lib = GELib::GetInstance(); if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); return GE_CLI_GE_NOT_INITIALIZED; } - GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "Mark node and set index failed."); + GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "MarkNodeAndSetIndex failed."); ProfilingPoint profiling_point; vector all_reduce_nodes; GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes)); @@ -264,15 +251,21 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_TIMESTAMP_CALLNUM_START(GenerateTask); // map store fusion nodes map> fusion_nodes; - string buffer_optimize = kOffOptimize; + string buffer_optimize = "off_optimize"; (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); - if (buffer_optimize != kOffOptimize) { + if (buffer_optimize != "off_optimize") { GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); } std::unordered_set fusion_nodes_seen; int64_t group_key; uint32_t node_index = 0; - for (auto &node : graph->GetAllNodes()) { + rtStream_t stream = nullptr; + bool is_unknown_shape = graph->GetGraphUnknownFlag(); + if (is_unknown_shape) { + GE_CHK_STATUS_RET(SetUnknownShapeStream(run_context, stream), "Set unknown shape stream failed."); + } + + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); node_index++; @@ -302,7 +295,6 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); continue; } - OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); if (kernel_info_store == nullptr) { GELOGE(INTERNAL_ERROR, "No ops kernel store found. node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), @@ -311,18 +303,17 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra } GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "Call UpdateAnchorStatus node:%s(%s) failed", name.c_str(), type.c_str()); - int64_t op_id = op_desc->GetId(); - int64_t stream_id = op_desc->GetStreamId(); - if (stream_id < 0 || stream_id >= static_cast(run_context.graphStreamList.size())) { - GELOGE(INTERNAL_ERROR, "node[name:%s(%s), id:%ld] stream id is invalid, stream list size=%zu", name.c_str(), - type.c_str(), op_id, run_context.graphStreamList.size()); - return INTERNAL_ERROR; - } - // Profiling task size_t task_list_size_before = task_def_list.size(); GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); - run_context.stream = run_context.graphStreamList[stream_id]; + int64_t op_id = op_desc->GetId(); + // Compatible with dynamic shape scenes, the default is 0 + int64_t stream_id = 0; + if (!is_unknown_shape) { + stream_id = op_desc->GetStreamId(); + GE_CHK_STATUS_RET(SetKnownShapeStream(run_context, stream_id), "node[name:%s(%s), id:%ld] stream id is invalid.", + name.c_str(), type.c_str(), op_id); + } GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); GE_TIMESTAMP_RESTART(GenerateTask); @@ -355,131 +346,14 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra GE_CHECK_NOTNULL(task_def_ptr); task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); } - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_after - task_list_size_before); } - GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); - return SUCCESS; -} - -Status TaskGenerator::GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, - vector &task_def_list, - map &op_name_map) { - std::shared_ptr ge_lib = GELib::GetInstance(); - if ((ge_lib == nullptr) || !ge_lib->InitFlag()) { - GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed."); - return GE_CLI_GE_NOT_INITIALIZED; - } - GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "Mark node and set index failed."); - ProfilingPoint profiling_point; - vector all_reduce_nodes; - GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes)); - - const OpsKernelManager &ops_kernel_manager = ge_lib->OpsKernelManagerObj(); - - GE_TIMESTAMP_CALLNUM_START(GenerateTask); - // map store fusion nodes - map> fusion_nodes; - string buffer_optimize = kOffOptimize; - (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); - if (buffer_optimize != kOffOptimize) { - GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); - } - std::unordered_set fusion_nodes_seen; - int64_t group_key; - uint32_t node_index = 0; - rtStream_t stream = nullptr; - GE_CHK_RT_RET(rtStreamCreate(&stream, 0)); - run_context.stream = stream; - if (rtModelBindStream(run_context.model, stream, 0) != RT_ERROR_NONE) { - GELOGE(FAILED, "Call rt api failed."); - GE_CHK_RT(rtStreamDestroy(stream)); - return FAILED; - } - for (auto &node : graph->GetAllNodes()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - node_index++; - string name = node->GetName(); - string type = node->GetType(); - bool attr_notask = false; - bool get_attr_notask_flag = ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask); - GE_IF_BOOL_EXEC(get_attr_notask_flag && attr_notask, - GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); - continue); - - GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, graph->GetSessionID())); - string op_kernel_lib_name = op_desc->GetOpKernelLibName(); - // For fusion ddb pass, task def must be continuous. - // Part2: Call - auto fusion_task_info = - FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib, - ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes}; - GE_CHK_STATUS_RET(GenerateTaskForFusionNode(fusion_task_info, fusion_nodes, fusion_nodes_seen), - "Call GenerateTaskForFusionNode node:%s(%s) failed", name.c_str(), type.c_str()); - // continue directly - if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key)) { - GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str()); - continue; - } - if (op_kernel_lib_name.empty()) { - GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str()); - continue; - } - OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); - if (kernel_info_store == nullptr) { - GELOGE(INTERNAL_ERROR, "No ops kernel store found. node:%s(%s), op_kernel_lib_name=%s.", name.c_str(), - type.c_str(), op_kernel_lib_name.c_str()); - return INTERNAL_ERROR; - } - GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "Call UpdateAnchorStatus node:%s(%s) failed", name.c_str(), - type.c_str()); - int64_t op_id = op_desc->GetId(); - int64_t stream_id = op_desc->GetStreamId(); - // Profiling task - size_t task_list_size_before = task_def_list.size(); - GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); - - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(), - name.c_str(), type.c_str(), op_id, stream_id); - GE_TIMESTAMP_RESTART(GenerateTask); - auto ret = kernel_info_store->GenerateTask(*node, run_context, task_def_list); - GE_TIMESTAMP_ADD(GenerateTask); - if (ret != SUCCESS) { - GELOGE(ret, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task failed.", - op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id); - return ret; - } - // Profiling task - GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list)); - size_t task_list_size_after = task_def_list.size(); - // If tasks is reduced - if (task_list_size_after < task_list_size_before) { - GELOGE(FAILED, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task. but task num from %zu to %zu.", - op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_before, - task_list_size_after); - return FAILED; - } - - // Reset stream id to ge stream id, as graph load must use ge stream to reassign stream - void *ops_kernel_info_store_ptr = kernel_info_store.get(); - for (size_t idx = task_list_size_before; idx < task_list_size_after; ++idx) { - op_name_map[idx] = name; - // Set opsKernelInfoStorePtr and op_index, the two fields be use in DistributeTask and InitTaskInfo - TaskDef *task_def_ptr = &task_def_list[idx]; - GE_CHECK_NOTNULL(task_def_ptr); - task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast(ops_kernel_info_store_ptr)); - } - - GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", - op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, - task_list_size_after - task_list_size_before); + if (is_unknown_shape) { + GE_CHK_STATUS_RET(DestroyUnknownShapeStream(run_context, stream), "Destory unknown shape stream failed."); } - GE_CHK_RT(rtModelUnbindStream(run_context.model, stream)); - GE_CHK_RT(rtStreamDestroy(stream)); - GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); + GE_TIMESTAMP_CALLNUM_EVENT_END(GenerateTask, "GraphBuild::GenerateTask"); return SUCCESS; } @@ -628,7 +502,11 @@ Status TaskGenerator::MarkNodeAndSetIndex(ComputeGraphPtr &graph) { return GE_CLI_GE_NOT_INITIALIZED; } - const auto all_nodes = graph->GetAllNodes(); + const auto all_nodes = graph->GetNodes(graph->GetGraphUnknownFlag()); + if (all_nodes.empty()) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "Graph's node is empty"); + return GE_GRAPH_GRAPH_NODE_NULL; + } int64_t node_index = 0; for (auto &node : all_nodes) { @@ -715,7 +593,7 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP OpDescPtr fp_op_desc = nullptr; uint32_t current_idx = 0; uint32_t first_fp = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); string op_kernel_lib_name = op_desc->GetOpKernelLibName(); @@ -742,7 +620,7 @@ Status TaskGenerator::AutoFindFpOpIndex(const ComputeGraphPtr &graph, ProfilingP return SUCCESS; } GELOGI("Find fp_op_desc is %s, id is %ld", fp_op_desc->GetName().c_str(), fp_op_desc->GetId()); - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); current_idx++; @@ -763,7 +641,7 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP uint32_t last_bp = 0; uint32_t iter_end = 0; uint32_t current_idx = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); current_idx++; @@ -807,7 +685,7 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP GE_CHECK_NOTNULL(bp_op_desc); current_idx = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); current_idx++; @@ -826,7 +704,7 @@ Status TaskGenerator::FindFpOfEnv(const ComputeGraphPtr &graph, const std::strin GELOGI("Start FindFpOfEnv"); uint32_t current_idx = 0; uint32_t first_fp = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(node->GetOpDesc()); current_idx++; @@ -851,7 +729,7 @@ Status TaskGenerator::FindBpOfEnv(const ComputeGraphPtr &graph, const std::strin uint32_t current_idx = 0; uint32_t iter_end = 0; uint32_t last_bp = 0; - for (auto &node : graph->GetAllNodes()) { + for (auto &node : graph->GetNodes(graph->GetGraphUnknownFlag())) { OpDescPtr op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(node->GetOpDesc()); current_idx++; @@ -927,10 +805,10 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi bool train_graph = graph->GetNeedIteration(); if (profiling_point.fp_index == 0 && train_graph) { - GELOGE(FAILED, "First forward op name can't be found in graph for training trace."); + GELOGW("First forward op name can't be found in graph for training trace."); } if (profiling_point.bp_index == 0 && train_graph) { - GELOGE(FAILED, "Last backward op name can't be found in graph for training trace."); + GELOGW("Last backward op name can't be found in graph for training trace."); } return SUCCESS; } @@ -1068,4 +946,31 @@ bool TaskGenerator::IsProfPoint(const OpDescPtr &op, const std::string &name) { return false; } +Status TaskGenerator::SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream) { + GE_CHK_RT_RET(rtStreamCreate(&stream, 0)); + run_context.stream = stream; + rtError_t rt_ret = rtModelBindStream(run_context.model, stream, 0); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + GE_CHK_RT_RET(rtStreamDestroy(stream)); + return FAILED; + } + return SUCCESS; +} + +Status TaskGenerator::DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream) { + GE_CHK_RT(rtModelUnbindStream(run_context.model, stream)); + GE_CHK_RT_RET(rtStreamDestroy(stream)); + return SUCCESS; +} + +Status TaskGenerator::SetKnownShapeStream(RunContext &run_context, int64_t stream_id) { + if (stream_id < 0 || stream_id >= static_cast(run_context.graphStreamList.size())) { + GELOGE(INTERNAL_ERROR, "Stream id[%ld] is invalid, stream list size=%zu", stream_id, + run_context.graphStreamList.size()); + return INTERNAL_ERROR; + } + run_context.stream = run_context.graphStreamList[stream_id]; + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/build/task_generator.h b/src/ge/graph/build/task_generator.h index 02721e00..b2ca4470 100644 --- a/src/ge/graph/build/task_generator.h +++ b/src/ge/graph/build/task_generator.h @@ -93,18 +93,6 @@ class TaskGenerator { Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector &task_def_list, std::map &op_name_map); - /// - /// call engine to generate unknown shape task. - /// @param run_context run context - /// @param graph compute graph - /// @param task_def_list task def list generate by engine - /// @param op_name_map relation of task index and op - /// @return SUCCESS:seccess - /// Other: failed - /// - Status GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, - std::vector &task_def_list, std::map &op_name_map); - /// /// AddModelTaskToModel /// @param model_task_def model task @@ -154,6 +142,12 @@ class TaskGenerator { Status SaveFusionNodes(map> &fusion_nodes, ComputeGraphPtr &graph); + Status SetUnknownShapeStream(RunContext &run_context, rtStream_t &stream); + + Status DestroyUnknownShapeStream(RunContext &run_context, rtStream_t &stream); + + Status SetKnownShapeStream(RunContext &run_context, int64_t stream_id); + uint8_t *var_mem_base_ = nullptr; uint64_t var_mem_size_ = 0; }; diff --git a/src/ge/graph/common/ge_call_wrapper.h b/src/ge/graph/common/ge_call_wrapper.h index a21d642e..305c6c15 100644 --- a/src/ge/graph/common/ge_call_wrapper.h +++ b/src/ge/graph/common/ge_call_wrapper.h @@ -18,6 +18,43 @@ #define GE_GE_CALL_WRAPPER_H_ #include "framework/common/debug/ge_log.h" +/*lint --emacro((773),GE_TIMESTAMP_START)*/ +/*lint -esym(773,GE_TIMESTAMP_START)*/ +#define GE_TIMESTAMP_START(stage) uint64_t startUsec_##stage = ge::GetCurrentTimestap() + +#define GE_TIMESTAMP_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ + GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#define GE_TIMESTAMP_EVENT_END(stage, stage_name) \ + do { \ + uint64_t endUsec_##stage = ge::GetCurrentTimestap(); \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second.", (stage_name), \ + (endUsec_##stage - startUsec_##stage)); \ + } while (0); + +#define GE_TIMESTAMP_CALLNUM_START(stage) \ + uint64_t startUsec_##stage = ge::GetCurrentTimestap(); \ + uint64_t call_num_of##stage = 0; \ + uint64_t time_of##stage = 0 + +#define GE_TIMESTAMP_RESTART(stage) (startUsec_##stage = ge::GetCurrentTimestap()) + +#define GE_TIMESTAMP_ADD(stage) \ + time_of##stage += ge::GetCurrentTimestap() - startUsec_##stage; \ + call_num_of##stage++ + +#define GE_TIMESTAMP_CALLNUM_END(stage, stage_name) \ + GELOGI("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ + call_num_of##stage) + +#define GE_TIMESTAMP_CALLNUM_EVENT_END(stage, stage_name) \ + GEEVENT("[GEPERFTRACE] The time cost of %s is [%lu] micro second, call num is %lu", (stage_name), time_of##stage, \ + call_num_of##stage) + #define RUN_WITH_TIMESTAMP_NAME(var_name, prefix, func, ...) \ do { \ GE_TIMESTAMP_START(var_name); \ @@ -29,10 +66,23 @@ } \ } while (0) +#define RUN_WITH_PERF_TIMESTAMP_NAME(var_name, prefix, func, ...) \ + do { \ + GE_TIMESTAMP_START(var_name); \ + auto ret_inner_macro = func(__VA_ARGS__); \ + GE_TIMESTAMP_EVENT_END(var_name, #prefix "::" #func) \ + if (ret_inner_macro != ge::SUCCESS) { \ + GELOGE(ret_inner_macro, "Failed to process " #prefix "_" #func); \ + return ret_inner_macro; \ + } \ + } while (0) + #define JOIN_NAME_INNER(a, b) a##b #define JOIN_NAME(a, b) JOIN_NAME_INNER(a, b) #define COUNTER_NAME(a) JOIN_NAME(a, __COUNTER__) #define GE_RUN(prefix, func, ...) \ RUN_WITH_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) +#define GE_RUN_PERF(prefix, func, ...) \ + RUN_WITH_PERF_TIMESTAMP_NAME(COUNTER_NAME(ge_timestamp_##prefix), prefix, func, __VA_ARGS__) #endif // GE_GE_CALL_WRAPPER_H_ diff --git a/src/ge/graph/execute/graph_execute.cc b/src/ge/graph/execute/graph_execute.cc index b021ce55..1bebd382 100644 --- a/src/ge/graph/execute/graph_execute.cc +++ b/src/ge/graph/execute/graph_execute.cc @@ -86,10 +86,10 @@ Status GraphExecutor::SetGraphContext(GraphContextPtr graph_context_ptr) { return SUCCESS; } -Status GraphExecutor::SetDynamicSize(uint32_t model_id, const std::vector &batch_num) { +Status GraphExecutor::SetDynamicSize(uint32_t model_id, const std::vector &batch_num, int32_t dynamic_type) { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->SetDynamicSize(model_id, batch_num); + Status ret = model_manager->SetDynamicSize(model_id, batch_num, dynamic_type); if (ret != SUCCESS) { GELOGE(FAILED, "SetDynamicSize failed"); return ret; @@ -120,7 +120,7 @@ Status GraphExecutor::FreeInOutBuffer() { } } -Status GraphExecutor::MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr) { +Status GraphExecutor::MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr) { if (malloc_flag_) { auto all_size_same = true; if (buffer_size.size() == buffer_size_.size()) { @@ -169,7 +169,7 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor graph_input_data.timestamp = 0; std::size_t inputSize = input_tensor.size(); std::size_t output_size = output_desc.size(); - std::vector bufferSizeVec; + std::vector bufferSizeVec; std::vector addrVec; for (std::size_t i = 0; i < inputSize; ++i) { @@ -211,7 +211,7 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor for (std::size_t j = 0; j < output_size; j++) { auto desc = output_desc[j]; - uint32_t buffer_size = desc.size; + uint64_t buffer_size = desc.size; DataBuffer out_data_buf; out_data_buf.data = reinterpret_cast(addrVec[inputSize + j]); @@ -225,6 +225,13 @@ Status GraphExecutor::PrepareInputData(const std::vector &input_tensor Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector &input_tensor, std::vector &output_tensor) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + if (model_manager->IsDynamicShape(model_id)) { + GELOGI("[ExecuteGraph] GetInputOutputDescInfo via dynamic shape model executor, modelId=%u", model_id); + return model_manager->SyncExecuteModel(model_id, input_tensor, output_tensor); + } + // Prepare input and output std::vector inputs_desc; std::vector output_desc; @@ -479,12 +486,14 @@ Status GraphExecutor::GetInputOutputDescInfo(const uint32_t model_id, vector> &batch_info) { +Status GraphExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type) { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->GetDynamicBatchInfo(model_id, batch_info); + Status ret = model_manager->GetDynamicBatchInfo(model_id, batch_info, dynamic_type); if (ret != SUCCESS) { GELOGE(ret, "GetDynamicBatchInfo failed."); return ret; @@ -492,12 +501,30 @@ Status GraphExecutor::GetDynamicBatchInfo(uint32_t model_id, std::vector &batch_info) { +/// +/// @ingroup ge +/// @brief Get combined dynamic dims info +/// @param [in] model_id +/// @param [out] batch_info +/// @return execute result +/// +Status GraphExecutor::GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info) { auto model_manager = ge::ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - Status ret = model_manager->GetCurShape(model_id, batch_info); + Status ret = model_manager->GetCombinedDynamicDims(model_id, batch_info); if (ret != SUCCESS) { - GELOGE(FAILED, "GetCurShape failed"); + GELOGE(ret, "GetCombinedDynamicDims failed."); + return ret; + } + return SUCCESS; +} + +Status GraphExecutor::GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type) { + auto model_manager = ge::ModelManager::GetInstance(); + GE_CHECK_NOTNULL(model_manager); + Status ret = model_manager->GetCurShape(model_id, batch_info, dynamic_type); + if (ret != SUCCESS) { + GELOGE(ret, "GetCurShape failed"); return ret; } return SUCCESS; @@ -575,5 +602,4 @@ Status GraphExecutor::GetAllAippInputOutputDims(uint32_t model_id, uint32_t inde return SUCCESS; } - } // namespace ge diff --git a/src/ge/graph/execute/graph_execute.h b/src/ge/graph/execute/graph_execute.h index 0518cf11..f79a2e29 100644 --- a/src/ge/graph/execute/graph_execute.h +++ b/src/ge/graph/execute/graph_execute.h @@ -56,7 +56,7 @@ class GraphExecutor { Status SetGraphContext(GraphContextPtr graph_context_ptr); - static Status SetDynamicSize(uint32_t model_id, const std::vector &batch_num); + static Status SetDynamicSize(uint32_t model_id, const std::vector &batch_num, int32_t dynamic_type); void SetTrainFlag(bool is_train_graph); @@ -80,11 +80,22 @@ class GraphExecutor { /// @brief Get dynamic batch_info /// @param [in] model_id /// @param [out] batch_info + /// @param [out] dynamic_type /// @return execute result /// - static Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info); + static Status GetDynamicBatchInfo(uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type); - static Status GetCurShape(const uint32_t model_id, std::vector &batch_info); + /// + /// @ingroup ge + /// @brief Get combined dynamic dims info + /// @param [in] model_id + /// @param [out] batch_info + /// @return execute result + /// + static Status GetCombinedDynamicDims(uint32_t model_id, std::vector> &batch_info); + + static Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); static Status GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info); @@ -110,7 +121,7 @@ class GraphExecutor { Status FreeInOutBuffer(); - Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); + Status MallocInOutBuffer(const std::vector &buffer_size, std::vector &data_addr); bool init_flag_; @@ -129,7 +140,7 @@ class GraphExecutor { bool malloc_flag_; std::vector buffer_addr_; - std::vector buffer_size_; + std::vector buffer_size_; }; } // namespace ge diff --git a/src/ge/graph/label/while_label_maker.cc b/src/ge/graph/label/while_label_maker.cc index 6601abd1..c5e0abb7 100644 --- a/src/ge/graph/label/while_label_maker.cc +++ b/src/ge/graph/label/while_label_maker.cc @@ -104,12 +104,11 @@ Status WhileOpLabelMaker::Run(uint32_t &label_index) { GE_CHECK_NOTNULL(cond_out_desc); GeTensorDesc pred_desc = cond_out_desc->GetInputDesc(kCondOutputIndex); - GeTensorDesc cond_desc(GeShape(pred_desc.GetShape().GetDims()), pred_desc.GetFormat(), DT_INT32); // false ==> 0 ==> switch_labels[0] ==> body_leave_index // true ==> 1 ==> switch_labels[1] ==> body_enter_name const std::vector switch_labels = {body_leave_index, body_enter_index}; - NodePtr switch_node = AddLabelSwitchLeave(cond_graph, cond_leave_name, cond_desc, switch_labels); + NodePtr switch_node = AddLabelSwitchLeave(cond_graph, cond_leave_name, pred_desc, switch_labels); if (switch_node == nullptr) { GELOGE(INTERNAL_ERROR, "Subgraph: %s add label switch failed.", cond_graph->GetName().c_str()); return FAILED; diff --git a/src/ge/graph/load/graph_loader.cc b/src/ge/graph/load/graph_loader.cc index 1f4cbcf9..d181f3a5 100644 --- a/src/ge/graph/load/graph_loader.cc +++ b/src/ge/graph/load/graph_loader.cc @@ -36,20 +36,20 @@ GraphLoader::~GraphLoader() = default; Status GraphLoader::UnloadModel(uint32_t model_id) { auto model_manager = ModelManager::GetInstance(); GE_CHECK_NOTNULL(model_manager); - GELOGI("UnLoad model begin, model_id:%u.", model_id); + GELOGI("UnLoad model begin, model id:%u.", model_id); Status ret = model_manager->Stop(model_id); if (ret != SUCCESS) { - GELOGE(ret, "UnloadModel: Stop failed."); + GELOGE(ret, "UnloadModel: Stop failed. model id:%u", model_id); } ret = model_manager->Unload(model_id); if (ret != SUCCESS) { - GELOGE(ret, "UnloadModel: Unload failed."); + GELOGE(ret, "UnloadModel: Unload failed. model id:%u", model_id); CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_UNLOAD); return ret; } - GELOGI("UnLoad model success, model_id:%u.", model_id); + GELOGI("UnLoad model success, model id:%u.", model_id); return SUCCESS; } @@ -123,14 +123,14 @@ Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string Status ret; try { if (!CheckInputPathValid(path)) { - GELOGE(PARAM_INVALID, "model path is invalid: %s", path.c_str()); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str()); + return GE_EXEC_MODEL_PATH_INVALID; } GELOGI("Load model begin, model path is: %s", path.c_str()); if (!key_path.empty() && !CheckInputPathValid(key_path)) { - GELOGE(PARAM_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_KEY_PATH_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); + return GE_EXEC_MODEL_KEY_PATH_INVALID; } ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); @@ -350,7 +350,8 @@ Status GraphLoader::GetMemoryInfo(int64_t &free) { return RT_FAILED; } // Add small page memory size - free = static_cast(free_mem + VarManager::Instance(0)->GetUseMaxMemorySize() - total_mem); + free = + static_cast(free_mem + VarManager::Instance(GetContext().SessionId())->GetUseMaxMemorySize() - total_mem); GELOGI("GetMemoryInfo free[%zu], total[%zu], return free[%ld]", free_mem, total_mem, free); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc index 06111015..01e1cfa8 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.cc @@ -16,6 +16,7 @@ #include "graph/load/new_model_manager/cpu_queue_schedule.h" #include "common/debug/ge_log.h" +#include "common/debug/log.h" namespace { const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch @@ -58,7 +59,7 @@ Status CpuTaskModelDequeue::Init(uint32_t queue_id, uintptr_t &in_mbuf) { rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } in_mbuf = reinterpret_cast(args_) + sizeof(MbufQueueInfo); GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) @@ -69,7 +70,7 @@ Status CpuTaskModelDequeue::Init(uint32_t queue_id, uintptr_t &in_mbuf) { status = rtMemcpy(args_, args_size_, &queue_info, sizeof(MbufQueueInfo), RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } return SUCCESS; @@ -84,7 +85,7 @@ Status CpuTaskModelDequeue::Distribute() { rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskModelDequeue, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ModelDequeue failed, status: 0x%X", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GELOGI("Cpu kernel launch model dequeue task success."); @@ -98,24 +99,24 @@ Status CpuTaskModelDequeue::Distribute() { /// @param [in] outside_addrs: model input/output memory addr /// @return: 0 for success / others for failed /// -Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, - std::map> &outside_addrs) { +Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, std::map &outside_addrs) { if ((args_ != nullptr) || (args_size_ > 0)) { GELOGE(FAILED, "Task already initialized, size: %u", args_size_); return FAILED; } args_size_ = sizeof(AddrMapInfo); - rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; - } + GE_CHK_RT_RET(rtMalloc(&args_, args_size_, RT_MEMORY_HBM)); GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) AddrMapInfo addr_map_info; - for (const auto &addrs : outside_addrs) { - addr_map_info.addr_num += addrs.second.size(); + for (auto &addrs : outside_addrs) { + auto &addrs_mapping_list = addrs.second.GetOutsideAddrs(); + GE_CHK_BOOL_EXEC(!addrs_mapping_list.empty(), return PARAM_INVALID, "not set outside_addrs"); + std::map> virtual_args_addrs = addrs_mapping_list[0]; + for (const auto &virtual_args_addr : virtual_args_addrs) { + addr_map_info.addr_num += virtual_args_addr.second.size(); + } } GELOGI("addr_map_info.addr_num is %u", addr_map_info.addr_num); @@ -123,38 +124,31 @@ Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, size_t index = 0; vector src_addrs; vector dst_addrs; - for (const auto &addrs : outside_addrs) { - for (size_t i = 0; i < addrs.second.size(); ++i) { - src_addrs.push_back(mbuf_list.at(index)); - dst_addrs.push_back(reinterpret_cast(reinterpret_cast(addrs.second.at(i)))); + for (auto &addrs : outside_addrs) { + auto &addrs_mapping_list = addrs.second.GetOutsideAddrs(); + GE_CHK_BOOL_EXEC(!addrs_mapping_list.empty(), return PARAM_INVALID, "not set outside_addrs"); + std::map> virtual_args_addrs = addrs_mapping_list[0]; + for (const auto &virtual_args_addr : virtual_args_addrs) { + for (size_t i = 0; i < virtual_args_addr.second.size(); ++i) { + src_addrs.push_back(mbuf_list.at(index)); + dst_addrs.push_back(reinterpret_cast(reinterpret_cast(virtual_args_addr.second.at(i)))); + } } index++; } // malloc mem for src_addrs/dst_addrs, and copy data of src_addrs/dst_addrs - status = rtMalloc(&src_addr_, src_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; - } - status = rtMemcpy(src_addr_, src_addrs.size() * sizeof(uint64_t), src_addrs.data(), - src_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; - } + GE_CHK_RT_RET(rtMalloc(&src_addr_, src_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM)); + rtError_t status = rtMemcpy(src_addr_, src_addrs.size() * sizeof(uint64_t), src_addrs.data(), + src_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(status != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", status); + return RT_ERROR_TO_GE_STATUS(status);) - status = rtMalloc(&dst_addr_, dst_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; - } + GE_CHK_RT_RET(rtMalloc(&dst_addr_, dst_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM)); status = rtMemcpy(dst_addr_, dst_addrs.size() * sizeof(uint64_t), dst_addrs.data(), dst_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; - } + GE_IF_BOOL_EXEC(status != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", status); + return RT_ERROR_TO_GE_STATUS(status);) // src_addr_list is init to src_addr, which is the point to src_addrs if (!src_addrs.empty() && !dst_addrs.empty()) { @@ -164,10 +158,8 @@ Status CpuTaskZeroCopy::Init(std::vector &mbuf_list, } status = rtMemcpy(args_, args_size_, &addr_map_info, sizeof(AddrMapInfo), RT_MEMCPY_HOST_TO_DEVICE); - if (status != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; - } + GE_IF_BOOL_EXEC(status != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", status); + return RT_ERROR_TO_GE_STATUS(status);) return SUCCESS; } @@ -180,7 +172,7 @@ Status CpuTaskZeroCopy::Distribute() { rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskZeroCopy, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ZeroCopy failed, status: 0x%X", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GELOGI("Cpu kernel launch zero copy task success."); @@ -225,7 +217,7 @@ Status CpuTaskPrepareOutput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mb rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } out_mbuf = reinterpret_cast(args_) + sizeof(PrepareOutputInfo); GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) @@ -239,7 +231,7 @@ Status CpuTaskPrepareOutput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mb status = rtMemcpy(args_, args_size_, &prepare, sizeof(PrepareOutputInfo), RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } return SUCCESS; @@ -254,7 +246,7 @@ Status CpuTaskPrepareOutput::Distribute() { rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskPrepareOutput, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt CpuKernelLaunch PrepareOutput failed, status: 0x%X", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GELOGI("Cpu kernel launch prepare output task success."); @@ -279,7 +271,7 @@ Status CpuTaskModelEnqueue::Init(uint32_t queue_id, uintptr_t out_mbuf) { rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) @@ -289,7 +281,7 @@ Status CpuTaskModelEnqueue::Init(uint32_t queue_id, uintptr_t out_mbuf) { status = rtMemcpy(args_, args_size_, &queue_info, args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } return SUCCESS; @@ -304,7 +296,7 @@ Status CpuTaskModelEnqueue::Distribute() { rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskModelEnqueue, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ModelEnqueue failed, status: 0x%X", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GELOGI("Cpu kernel launch model enqueue task success."); @@ -336,10 +328,10 @@ Status CpuTaskActiveEntry::Distribute() { rtError_t ret = rtStreamActive(active_stream_, stream_); if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt StreamActive failed, ret: 0x%X", ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(ret); } - GELOGI("Cpu kernel launch wait end task success."); + GELOGI("Cpu kernel launch active entry task success."); return SUCCESS; } @@ -359,14 +351,14 @@ Status CpuTaskWaitEndGraph::Init(uint32_t model_id) { rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) status = rtMemcpy(args_, args_size_, &model_id, args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } return SUCCESS; @@ -381,7 +373,7 @@ Status CpuTaskWaitEndGraph::Distribute() { rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskWaitEndGraph, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt CpuKernelLaunch WaitEndGraph failed, status: 0x%X", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GELOGI("Cpu kernel launch wait end task success."); @@ -404,14 +396,14 @@ Status CpuTaskModelRepeat::Init(uint32_t model_id) { rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) status = rtMemcpy(args_, args_size_, &model_id, args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } return SUCCESS; @@ -426,7 +418,7 @@ Status CpuTaskModelRepeat::Distribute() { rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskModelRepeat, kCoreDim, args_, args_size_, nullptr, stream_); if (status != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ModelRepeat failed, status: 0x%x", status); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(status); } GELOGI("Cpu kernel launch repeat task success."); diff --git a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h index c4ae4df5..cea00613 100644 --- a/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h +++ b/src/ge/graph/load/new_model_manager/cpu_queue_schedule.h @@ -22,6 +22,7 @@ #include "common/ge_inner_error_codes.h" #include "graph/load/new_model_manager/task_info/task_info.h" +#include "graph/load/new_model_manager/zero_copy_offset.h" #include "runtime/kernel.h" namespace ge { @@ -93,7 +94,7 @@ class CpuTaskZeroCopy : public CpuTaskInfo { ~CpuTaskZeroCopy() override; Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } - Status Init(std::vector &mbuf_list, std::map> &outside_addrs); + Status Init(std::vector &mbuf_list, std::map &outside_addrs); Status Distribute() override; diff --git a/src/ge/graph/load/new_model_manager/data_dumper.cc b/src/ge/graph/load/new_model_manager/data_dumper.cc index 653a3fa1..7194264d 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.cc +++ b/src/ge/graph/load/new_model_manager/data_dumper.cc @@ -21,7 +21,6 @@ #include #include -#include "common/debug/log.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" @@ -37,9 +36,36 @@ namespace { const uint32_t kAicpuLoadFlag = 1; const uint32_t kAicpuUnloadFlag = 0; +const int64_t kOpDebugSize = 2048; +const int64_t kOpDebugShape = 2048; +const int8_t kDecimal = 10; +const uint32_t kAddrLen = sizeof(void *); const char *const kDumpOutput = "output"; const char *const kDumpInput = "input"; const char *const kDumpAll = "all"; + +// parse for format like nodename:input:index +static bool ParseNameIndex(const std::string &node_name_index, std::string &node_name, std::string &input_or_output, + size_t &index) { + auto sep = node_name_index.rfind(':'); + if (sep == std::string::npos) { + return false; + } + auto index_str = node_name_index.substr(sep + 1); + index = static_cast(std::strtol(index_str.c_str(), nullptr, kDecimal)); + auto node_name_without_index = node_name_index.substr(0, sep); + sep = node_name_without_index.rfind(':'); + if (sep == std::string::npos) { + return false; + } + node_name = node_name_without_index.substr(0, sep); + input_or_output = node_name_without_index.substr(sep + 1); + return !(input_or_output != kDumpInput && input_or_output != kDumpOutput); +} + +static bool IsTensorDescWithSkipDumpAddrType(bool has_mem_type_attr, vector v_memory_type, size_t i) { + return has_mem_type_attr && (v_memory_type[i] == RT_MEMORY_L1); +} } // namespace static int32_t GetIrDataType(ge::DataType data_type) { @@ -138,6 +164,13 @@ void DataDumper::SaveEndGraphId(uint32_t task_id, uint32_t stream_id) { end_graph_stream_id_ = stream_id; } +void DataDumper::SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, bool is_op_debug) { + op_debug_task_id_ = task_id; + op_debug_stream_id_ = stream_id; + op_debug_addr_ = op_debug_addr; + is_op_debug_ = is_op_debug; +} + void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::shared_ptr &op_desc, uintptr_t args) { if (op_desc == nullptr) { @@ -202,56 +235,121 @@ static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uin } } -Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { - GELOGI("Start dump output"); - if (inner_dump_info.is_task) { - // tbe or aicpu op - const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); - const auto input_size = inner_dump_info.op->GetAllInputsDesc().size(); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); - if (output_descs.size() != output_addrs.size()) { - GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), - inner_dump_info.op->GetName().c_str(), output_descs.size()); - return PARAM_INVALID; - } +Status DataDumper::GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index) { + output.set_data_type(static_cast(GetIrDataType(tensor_descs.at(index).GetDataType()))); + output.set_format(static_cast(tensor_descs.at(index).GetFormat())); - for (size_t i = 0; i < output_descs.size(); ++i) { - aicpu::dump::Output output; - output.set_data_type(static_cast(GetIrDataType(output_descs.at(i).GetDataType()))); - output.set_format(static_cast(output_descs.at(i).GetFormat())); + for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { + output.mutable_shape()->add_dim(dim); + } + int64_t output_size = 0; + if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), output_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get output size filed"); + return PARAM_INVALID; + } + GELOGD("Get output size in dump is %ld", output_size); + std::string origin_name; + int32_t origin_output_index = -1; + (void)AttrUtils::GetStr(&tensor_descs.at(index), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); + (void)AttrUtils::GetInt(&tensor_descs.at(index), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); + output.set_size(output_size); + output.set_original_name(origin_name); + output.set_original_output_index(origin_output_index); + output.set_original_output_format(static_cast(tensor_descs.at(index).GetOriginFormat())); + output.set_original_output_data_type(static_cast(tensor_descs.at(index).GetOriginDataType())); + output.set_address(static_cast(addr)); + return SUCCESS; +} - for (auto dim : output_descs.at(i).GetShape().GetDims()) { - output.mutable_shape()->add_dim(dim); - } +Status DataDumper::DumpRefOutput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Output &output, + size_t i, const std::string &node_name_index) { + std::string dump_op_name; + std::string input_or_output; + size_t index; + // parser and find which node's input or output tensor desc is chosen for dump info + if (!ParseNameIndex(node_name_index, dump_op_name, input_or_output, index)) { + GELOGE(PARAM_INVALID, "Op [%s] output desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str()); + return PARAM_INVALID; + } + GE_CHECK_NOTNULL(compute_graph_); + auto replace_node = compute_graph_->FindNode(dump_op_name); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(replace_node == nullptr, + "Op [%s] output desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s]," + " cannot find redirect node[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str(), + dump_op_name.c_str()); + auto replace_opdesc = replace_node->GetOpDesc(); + GE_CHECK_NOTNULL(replace_opdesc); + auto iter = ref_info_.find(replace_opdesc); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(iter == ref_info_.end(), + "Op [%s] output desc[%zu] cannot find any saved redirect node[%s]'s info.", + inner_dump_info.op->GetName().c_str(), i, replace_opdesc->GetName().c_str()); + GE_CHECK_NOTNULL(iter->second); + auto addr = reinterpret_cast(iter->second); + if (input_or_output == kDumpInput) { + const auto &replace_input_descs = replace_opdesc->GetAllInputsDesc(); + addr += kAddrLen * index; + GE_CHK_STATUS_RET(GenerateOutput(output, replace_input_descs, addr, index), "Generate output failed"); + } else if (input_or_output == kDumpOutput) { + const auto &replace_output_descs = replace_opdesc->GetAllOutputsDesc(); + const auto replace_input_size = replace_opdesc->GetAllInputsDesc().size(); + addr += (index + replace_input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateOutput(output, replace_output_descs, addr, index), "Generate output failed"); + } + GELOGD("Op [%s] output desc[%zu] dump info is replaced by node[%s] [%s] tensor_desc [%zu]", + inner_dump_info.op->GetName().c_str(), i, dump_op_name.c_str(), input_or_output.c_str(), index); + return SUCCESS; +} - int64_t output_size = 0; - if (TensorUtils::GetTensorSizeInBytes(output_descs.at(i), output_size) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get output size filed"); - return PARAM_INVALID; - } - GELOGI("Get output size in dump is %ld", output_size); - std::string origin_name; - int32_t origin_output_index = -1; - (void)AttrUtils::GetStr(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); - (void)AttrUtils::GetInt(&output_descs.at(i), ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - GE_IF_BOOL_EXEC(output_size <= 0, GELOGE(PARAM_INVALID, "Output size %ld is less than zero", output_size); - return PARAM_INVALID) - output.set_size(output_size); - output.set_original_name(origin_name); - output.set_original_output_index(origin_output_index); - output.set_original_output_format(static_cast(output_descs.at(i).GetOriginFormat())); - output.set_original_output_data_type(static_cast(output_descs.at(i).GetOriginDataType())); - output.set_address(static_cast(inner_dump_info.args + (i + input_size) * sizeof(void *))); - - task.mutable_output()->Add(std::move(output)); +Status DataDumper::DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op); + if (output_descs.size() != output_addrs.size()) { + GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), + inner_dump_info.op->GetName().c_str(), output_descs.size()); + return PARAM_INVALID; + } + std::vector v_memory_type; + bool has_mem_type_attr = ge::AttrUtils::GetListInt(inner_dump_info.op, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(has_mem_type_attr && (v_memory_type.size() != output_descs.size()), + "DumpOutputWithTask[%s], output size[%zu], output memory type size[%zu]", + inner_dump_info.op->GetName().c_str(), output_descs.size(), + v_memory_type.size()); + + for (size_t i = 0; i < output_descs.size(); ++i) { + aicpu::dump::Output output; + std::string node_name_index; + const auto &output_desc = output_descs.at(i); + // check dump output tensor desc is redirected by attr ATTR_DATA_DUMP_REF + if (AttrUtils::GetStr(&output_desc, ATTR_DATA_DUMP_REF, node_name_index)) { + GE_CHK_STATUS_RET(DumpRefOutput(inner_dump_info, output, i, node_name_index), "DumpRefOutput failed"); + } else { + GE_IF_BOOL_EXEC( + IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), + GELOGD("DumpOutputWithTask[%s] output[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); + continue;); + + const auto input_size = inner_dump_info.op->GetInputsSize(); + auto addr = inner_dump_info.args + (i + input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateOutput(output, output_descs, addr, i), "Generate output failed"); } - return SUCCESS; + task.mutable_output()->Add(std::move(output)); } + return SUCCESS; +} +Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { + GELOGI("Start dump output"); + if (inner_dump_info.is_task) { + // tbe or aicpu op, these ops are with task + return DumpOutputWithTask(inner_dump_info, task); + } // else data, const or variable op aicpu::dump::Output output; auto output_tensor = inner_dump_info.op->GetOutputDescPtr(inner_dump_info.output_anchor_index); - const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); + const std::vector output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op); if (output_tensor == nullptr) { GELOGE(PARAM_INVALID, "output_tensor is null, index: %d, size: %zu.", inner_dump_info.output_anchor_index, inner_dump_info.op->GetOutputsSize()); @@ -269,9 +367,6 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: int32_t origin_output_index = -1; (void)AttrUtils::GetStr(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_NAME, origin_name); (void)AttrUtils::GetInt(output_tensor, ATTR_NAME_DATA_DUMP_ORIGIN_OUTPUT_INDEX, origin_output_index); - GE_IF_BOOL_EXEC(inner_dump_info.data_size <= 0, - GELOGE(PARAM_INVALID, "The size of data %ld is less than zero", inner_dump_info.data_size); - return PARAM_INVALID) output.set_size(inner_dump_info.data_size); output.set_original_name(origin_name); output.set_original_output_index(origin_output_index); @@ -282,7 +377,7 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: GELOGE(FAILED, "Index is out of range."); return FAILED; } - auto data_addr = inner_dump_info.args + sizeof(void *) * static_cast(inner_dump_info.input_anchor_index); + auto data_addr = inner_dump_info.args + kAddrLen * static_cast(inner_dump_info.input_anchor_index); output.set_address(static_cast(data_addr)); task.mutable_output()->Add(std::move(output)); @@ -290,37 +385,98 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: return SUCCESS; } +Status DataDumper::GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index) { + input.set_data_type(static_cast(GetIrDataType(tensor_descs.at(index).GetDataType()))); + input.set_format(static_cast(tensor_descs.at(index).GetFormat())); + + for (auto dim : tensor_descs.at(index).GetShape().GetDims()) { + input.mutable_shape()->add_dim(dim); + } + int64_t input_size = 0; + if (AttrUtils::GetInt(tensor_descs.at(index), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { + GELOGI("Get aipp input size according to attr is %ld", input_size); + } else if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(index), input_size) != SUCCESS) { + GELOGE(PARAM_INVALID, "Get input size filed"); + return PARAM_INVALID; + } + GELOGD("Get input size in dump is %ld", input_size); + input.set_size(input_size); + input.set_address(static_cast(addr)); + return SUCCESS; +} + +Status DataDumper::DumpRefInput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Input &input, size_t i, + const std::string &node_name_index) { + std::string dump_op_name; + std::string input_or_output; + size_t index; + // parser and find which node's input or output tensor desc is chosen for dump info + if (!ParseNameIndex(node_name_index, dump_op_name, input_or_output, index)) { + GELOGE(PARAM_INVALID, "Op [%s] input desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str()); + return PARAM_INVALID; + } + GE_CHECK_NOTNULL(compute_graph_); + auto replace_node = compute_graph_->FindNode(dump_op_name); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(replace_node == nullptr, + "Op [%s] input desc[%zu] with invalid ATTR_DATA_DUMP_REF attr[%s]," + " cannot find redirect node[%s].", + inner_dump_info.op->GetName().c_str(), i, node_name_index.c_str(), + dump_op_name.c_str()); + auto replace_opdesc = replace_node->GetOpDesc(); + GE_CHECK_NOTNULL(replace_opdesc); + auto iter = ref_info_.find(replace_opdesc); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(iter == ref_info_.end(), + "Op [%s] input desc[%zu] cannot find any saved redirect node[%s]'s info.", + inner_dump_info.op->GetName().c_str(), i, replace_opdesc->GetName().c_str()); + GE_CHECK_NOTNULL(iter->second); + auto addr = reinterpret_cast(iter->second); + if (input_or_output == kDumpInput) { + const auto &replace_input_descs = replace_opdesc->GetAllInputsDesc(); + addr += kAddrLen * index; + GE_CHK_STATUS_RET(GenerateInput(input, replace_input_descs, addr, index), "Generate input failed"); + } else if (input_or_output == kDumpOutput) { + const auto &replace_output_descs = replace_opdesc->GetAllOutputsDesc(); + const auto replace_input_size = replace_opdesc->GetAllInputsDesc().size(); + addr += (index + replace_input_size) * kAddrLen; + GE_CHK_STATUS_RET(GenerateInput(input, replace_output_descs, addr, index), "Generate input failed"); + } + GELOGD("Op [%s] input desc[%zu] dump info is replaced by node[%s] [%s] tensor_desc [%zu]", + inner_dump_info.op->GetName().c_str(), i, dump_op_name.c_str(), input_or_output.c_str(), index); + return SUCCESS; +} + Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { GELOGI("Start dump input"); const auto &input_descs = inner_dump_info.op->GetAllInputsDesc(); - const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op, false); + const std::vector input_addrs = ModelUtils::GetInputDataAddrs(runtime_param_, inner_dump_info.op); if (input_descs.size() != input_addrs.size()) { GELOGE(PARAM_INVALID, "Invalid input desc addrs size %zu, op %s has %zu input desc.", input_addrs.size(), inner_dump_info.op->GetName().c_str(), input_descs.size()); return PARAM_INVALID; } + std::vector v_memory_type; + bool has_mem_type_attr = ge::AttrUtils::GetListInt(inner_dump_info.op, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); + GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(has_mem_type_attr && (v_memory_type.size() != input_descs.size()), + "DumpInput[%s], input size[%zu], input memory type size[%zu]", + inner_dump_info.op->GetName().c_str(), input_descs.size(), v_memory_type.size()); for (size_t i = 0; i < input_descs.size(); ++i) { aicpu::dump::Input input; - input.set_data_type(static_cast(GetIrDataType(input_descs.at(i).GetDataType()))); - input.set_format(static_cast(input_descs.at(i).GetFormat())); - - for (auto dim : input_descs.at(i).GetShape().GetDims()) { - input.mutable_shape()->add_dim(dim); + std::string node_name_index; + // check dump input tensor desc is redirected by attr ATTR_DATA_DUMP_REF + if (AttrUtils::GetStr(&input_descs.at(i), ATTR_DATA_DUMP_REF, node_name_index)) { + GE_CHK_STATUS_RET(DumpRefInput(inner_dump_info, input, i, node_name_index), "DumpRefInput failed"); + // normal dump without attr + } else { + GE_IF_BOOL_EXEC(IsTensorDescWithSkipDumpAddrType(has_mem_type_attr, v_memory_type, i), + GELOGD("DumpInput[%s] input[%zu] is l1 addr, skip it", inner_dump_info.op->GetName().c_str(), i); + continue;); + + auto addr = inner_dump_info.args + kAddrLen * i; + GE_CHK_STATUS_RET(GenerateInput(input, input_descs, addr, i), "Generate input failed"); } - - int64_t input_size = 0; - if (AttrUtils::GetInt(&input_descs.at(i), ATTR_NAME_INPUT_ORIGIN_SIZE, input_size)) { - GELOGI("Get aipp input size according to attr is %ld", input_size); - } else if (TensorUtils::GetTensorSizeInBytes(input_descs.at(i), input_size) != SUCCESS) { - GELOGE(PARAM_INVALID, "Get input size filed"); - return PARAM_INVALID; - } - GELOGI("Get input size in dump is %ld", input_size); - GE_IF_BOOL_EXEC(input_size <= 0, GELOGE(PARAM_INVALID, "Input size %ld is less than zero", input_size); - return PARAM_INVALID;) - input.set_size(input_size); - input.set_address(static_cast(inner_dump_info.args + sizeof(void *) * i)); task.mutable_input()->Add(std::move(input)); } return SUCCESS; @@ -331,8 +487,8 @@ Status DataDumper::ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_in size_t proto_size = op_mapping_info.ByteSizeLong(); bool ret = op_mapping_info.SerializeToString(&proto_str); if (!ret || proto_size == 0) { - GELOGE(FAILED, "Protobuf SerializeToString failed, proto size %zu.", proto_size); - return FAILED; + GELOGE(PARAM_INVALID, "Protobuf SerializeToString failed, proto size %zu.", proto_size); + return PARAM_INVALID; } if (dev_mem_load_ != nullptr) { @@ -343,20 +499,20 @@ Status DataDumper::ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_in rtError_t rt_ret = rtMalloc(&dev_mem_load_, proto_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "load dump information.", proto_size) rt_ret = rtMemcpy(dev_mem_load_, proto_size, proto_str.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMemcpy failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtDatadumpInfoLoad(dev_mem_load_, proto_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtDatadumpInfoLoad failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } load_flag_ = true; @@ -369,8 +525,8 @@ Status DataDumper::ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_ size_t proto_size = op_mapping_info.ByteSizeLong(); bool ret = op_mapping_info.SerializeToString(&proto_str); if (!ret || proto_size == 0) { - GELOGE(FAILED, "Protobuf SerializeToString failed, proto size %zu.", proto_size); - return FAILED; + GELOGE(PARAM_INVALID, "Protobuf SerializeToString failed, proto size %zu.", proto_size); + return PARAM_INVALID; } if (dev_mem_unload_ != nullptr) { @@ -381,83 +537,87 @@ Status DataDumper::ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_ rtError_t rt_ret = rtMalloc(&dev_mem_unload_, proto_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "unload dump information.", proto_size) rt_ret = rtMemcpy(dev_mem_unload_, proto_size, proto_str.c_str(), proto_size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMemcpy failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtDatadumpInfoLoad(dev_mem_unload_, proto_size); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtDatadumpInfoLoad failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } load_flag_ = false; GELOGI("UnloadDumpInfo success, proto size is: %zu.", proto_size); return SUCCESS; } + Status DataDumper::LoadDumpInfo() { std::string dump_list_key; PrintCheckLog(dump_list_key); if (op_list_.empty()) { - return SUCCESS; + GELOGW("op_list_ is empty"); } aicpu::dump::OpMappingInfo op_mapping_info; - auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); - op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); + auto dump_path = dump_properties_.GetDumpPath() + std::to_string(device_id_) + "/"; + op_mapping_info.set_dump_path(dump_path); op_mapping_info.set_model_name(dump_list_key); op_mapping_info.set_model_id(model_id_); op_mapping_info.set_flag(kAicpuLoadFlag); - op_mapping_info.set_dump_step(PropertiesManager::Instance().GetDumpStep()); + op_mapping_info.set_dump_step(dump_properties_.GetDumpStep()); SetOpMappingLoopAddr(global_step_, loop_per_iter_, loop_cond_, op_mapping_info); - GELOGI("Dump step is %s and dump path is %s in load dump info", PropertiesManager::Instance().GetDumpStep().c_str(), + GELOGI("Dump step is %s and dump path is %s in load dump info", dump_properties_.GetDumpStep().c_str(), dump_path.c_str()); for (const auto &op_iter : op_list_) { - aicpu::dump::Task task; auto op_desc = op_iter.op; + GELOGD("Op %s in model %s begin to add task in op_mapping_info", op_desc->GetName().c_str(), dump_list_key.c_str()); + aicpu::dump::Task task; task.set_end_graph(false); task.set_task_id(op_iter.task_id); task.set_stream_id(op_iter.stream_id); task.mutable_op()->set_op_name(op_desc->GetName()); task.mutable_op()->set_op_type(op_desc->GetType()); - if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput) { - if (DumpOutput(op_iter, task) != SUCCESS) { - GELOGE(FAILED, "Dump output failed"); - return FAILED; + if (dump_properties_.GetDumpMode() == kDumpOutput) { + Status ret = DumpOutput(op_iter, task); + if (ret != SUCCESS) { + GELOGE(ret, "Dump output failed"); + return ret; } op_mapping_info.mutable_task()->Add(std::move(task)); continue; } - if (PropertiesManager::Instance().GetDumpMode() == kDumpInput) { + if (dump_properties_.GetDumpMode() == kDumpInput) { if (op_iter.is_task) { - if (DumpInput(op_iter, task) != SUCCESS) { - GELOGE(FAILED, "Dump input failed"); - return FAILED; + Status ret = DumpInput(op_iter, task); + if (ret != SUCCESS) { + GELOGE(ret, "Dump input failed"); + return ret; } } op_mapping_info.mutable_task()->Add(std::move(task)); continue; } - if (PropertiesManager::Instance().GetDumpMode() == kDumpAll) { + if (dump_properties_.GetDumpMode() == kDumpAll) { auto ret = DumpOutput(op_iter, task); if (ret != SUCCESS) { - GELOGE(FAILED, "Dump output failed when in dumping all"); - return FAILED; + GELOGE(ret, "Dump output failed when in dumping all"); + return ret; } if (op_iter.is_task) { ret = DumpInput(op_iter, task); if (ret != SUCCESS) { - GELOGE(FAILED, "Dump input failed when in dumping all"); - return FAILED; + GELOGE(ret, "Dump input failed when in dumping all"); + return ret; } } op_mapping_info.mutable_task()->Add(std::move(task)); @@ -467,19 +627,22 @@ Status DataDumper::LoadDumpInfo() { SetEndGraphIdToAicpu(end_graph_task_id_, end_graph_stream_id_, op_mapping_info); - auto ret = ExecuteLoadDumpInfo(op_mapping_info); - if (ret != SUCCESS) { - GELOGE(FAILED, "Execute load dump info failed"); - return FAILED; + SetOpDebugIdToAicpu(op_debug_task_id_, op_debug_stream_id_, op_debug_addr_, op_mapping_info); + + if (!op_list_.empty() || is_op_debug_) { + auto ret = ExecuteLoadDumpInfo(op_mapping_info); + if (ret != SUCCESS) { + GELOGE(ret, "Execute load dump info failed"); + return ret; + } } return SUCCESS; } void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info) { - if (PropertiesManager::Instance().GetDumpMode() == kDumpOutput || - PropertiesManager::Instance().GetDumpMode() == kDumpInput || - PropertiesManager::Instance().GetDumpMode() == kDumpAll) { + if (dump_properties_.GetDumpMode() == kDumpOutput || dump_properties_.GetDumpMode() == kDumpInput || + dump_properties_.GetDumpMode() == kDumpAll) { GELOGI("Add end_graph_info to aicpu, task_id is %u, stream_id is %u", end_graph_task_id_, end_graph_stream_id_); aicpu::dump::Task task; task.set_end_graph(true); @@ -491,6 +654,37 @@ void DataDumper::SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, } } +void DataDumper::SetOpDebugIdToAicpu(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, + aicpu::dump::OpMappingInfo &op_mapping_info) { + if (is_op_debug_) { + GELOGI("add op_debug_info to aicpu, task_id is %u, stream_id is %u", task_id, stream_id); + aicpu::dump::Task task; + task.set_end_graph(false); + task.set_task_id(task_id); + task.set_stream_id(stream_id); + task.mutable_op()->set_op_name(NODE_NAME_OP_DEBUG); + task.mutable_op()->set_op_type(OP_TYPE_OP_DEBUG); + + // set output + aicpu::dump::Output output; + output.set_data_type(DT_UINT8); + output.set_format(FORMAT_ND); + + output.mutable_shape()->add_dim(kOpDebugShape); + + output.set_original_name(NODE_NAME_OP_DEBUG); + output.set_original_output_index(0); + output.set_original_output_format(FORMAT_ND); + output.set_original_output_data_type(DT_UINT8); + // due to lhisi virtual addr bug, cannot use args now + output.set_address(static_cast(reinterpret_cast(op_debug_addr))); + output.set_size(kOpDebugSize); + + task.mutable_output()->Add(std::move(output)); + op_mapping_info.mutable_task()->Add(std::move(task)); + } +} + Status DataDumper::UnloadDumpInfo() { if (!load_flag_) { GELOGI("No need to UnloadDumpInfo."); @@ -510,22 +704,24 @@ Status DataDumper::UnloadDumpInfo() { } auto ret = ExecuteUnLoadDumpInfo(op_mapping_info); if (ret != SUCCESS) { - GELOGE(FAILED, "Execute unload dump info failed"); - return FAILED; + GELOGE(ret, "Execute unload dump info failed"); + return ret; } return SUCCESS; } void DataDumper::PrintCheckLog(string &dump_list_key) { - std::set model_list = PropertiesManager::Instance().GetAllDumpModel(); + std::set model_list = dump_properties_.GetAllDumpModel(); if (model_list.empty()) { GELOGI("No model need dump."); return; } - GELOGI("%zu op need dump in %s.", op_list_.size(), model_name_.c_str()); bool not_find_by_omname = model_list.find(om_name_) == model_list.end(); bool not_find_by_modelname = model_list.find(model_name_) == model_list.end(); + dump_list_key = not_find_by_omname ? model_name_ : om_name_; + GELOGI("%zu op need dump in %s.", op_list_.size(), dump_list_key.c_str()); + if (model_list.find(DUMP_ALL_MODEL) == model_list.end()) { if (not_find_by_omname && not_find_by_modelname) { std::string model_list_str; @@ -533,12 +729,12 @@ void DataDumper::PrintCheckLog(string &dump_list_key) { model_list_str += "[" + model + "]."; } - GELOGW("Model %s will not be set to dump, dump list: %s", model_name_.c_str(), model_list_str.c_str()); + GELOGW("Model %s will not be set to dump, dump list: %s", dump_list_key.c_str(), model_list_str.c_str()); return; } } - dump_list_key = not_find_by_omname ? model_name_ : om_name_; - std::set config_dump_op_list = PropertiesManager::Instance().GetDumpPropertyValue(dump_list_key); + + std::set config_dump_op_list = dump_properties_.GetPropertyValue(dump_list_key); std::set dump_op_list; for (auto &inner_dump_info : op_list_) { // oplist value OpDescPtr is not nullptr diff --git a/src/ge/graph/load/new_model_manager/data_dumper.h b/src/ge/graph/load/new_model_manager/data_dumper.h index ee5b3241..0648a8ce 100644 --- a/src/ge/graph/load/new_model_manager/data_dumper.h +++ b/src/ge/graph/load/new_model_manager/data_dumper.h @@ -23,7 +23,9 @@ #include #include "framework/common/ge_inner_error_codes.h" +#include "common/properties_manager.h" #include "graph/node.h" +#include "graph/compute_graph.h" #include "proto/ge_ir.pb.h" #include "proto/op_mapping_info.pb.h" #include "runtime/mem.h" @@ -44,7 +46,9 @@ class DataDumper { device_id_(0), global_step_(0), loop_per_iter_(0), - loop_cond_(0) {} + loop_cond_(0), + compute_graph_(nullptr), + ref_info_() {} ~DataDumper(); @@ -56,6 +60,10 @@ class DataDumper { void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + void SetComputeGraph(const ComputeGraphPtr &compute_graph) { compute_graph_ = compute_graph; }; + + void SetRefInfo(const std::map &ref_info) { ref_info_ = ref_info; }; + void SetLoopAddr(void *global_step, void *loop_per_iter, void *loop_cond); void SaveDumpInput(const std::shared_ptr &node); @@ -65,11 +73,15 @@ class DataDumper { void SaveEndGraphId(uint32_t task_id, uint32_t stream_id); void SetOmName(const std::string &om_name) { om_name_ = om_name; } + void SaveOpDebugId(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, bool is_op_debug); Status LoadDumpInfo(); Status UnloadDumpInfo(); + void SetDumpProperties(const DumpProperties &dump_properties) { dump_properties_ = dump_properties; } + const DumpProperties &GetDumpProperties() const { return dump_properties_; } + private: void ReleaseDevMem(void **ptr) noexcept; @@ -97,12 +109,32 @@ class DataDumper { uintptr_t global_step_; uintptr_t loop_per_iter_; uintptr_t loop_cond_; + ComputeGraphPtr compute_graph_; + std::map ref_info_; + + uint32_t op_debug_task_id_ = 0; + uint32_t op_debug_stream_id_ = 0; + void *op_debug_addr_ = nullptr; + bool is_op_debug_ = false; + + DumpProperties dump_properties_; Status DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); + Status DumpRefOutput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Output &output, size_t i, + const std::string &node_name_index); + Status DumpOutputWithTask(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); Status DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task); + Status DumpRefInput(const DataDumper::InnerDumpInfo &inner_dump_info, aicpu::dump::Input &input, size_t i, + const std::string &node_name_index); Status ExecuteLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); void SetEndGraphIdToAicpu(uint32_t task_id, uint32_t stream_id, aicpu::dump::OpMappingInfo &op_mapping_info); + void SetOpDebugIdToAicpu(uint32_t task_id, uint32_t stream_id, void *op_debug_addr, + aicpu::dump::OpMappingInfo &op_mapping_info); Status ExecuteUnLoadDumpInfo(aicpu::dump::OpMappingInfo &op_mapping_info); + Status GenerateInput(aicpu::dump::Input &input, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index); + Status GenerateOutput(aicpu::dump::Output &output, const OpDesc::Vistor &tensor_descs, + const uintptr_t &addr, size_t index); }; struct DataDumper::InnerDumpInfo { uint32_t task_id; diff --git a/src/ge/graph/load/new_model_manager/davinci_model.cc b/src/ge/graph/load/new_model_manager/davinci_model.cc index a8a11fd9..5af366a5 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.cc +++ b/src/ge/graph/load/new_model_manager/davinci_model.cc @@ -36,13 +36,13 @@ #include "common/scope_guard.h" #include "common/thread_pool.h" #include "framework/common/debug/ge_log.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" #include "graph/graph.h" #include "graph/load/new_model_manager/cpu_queue_schedule.h" #include "graph/load/new_model_manager/tbe_handle_store.h" -#include "graph/load/output/output.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/trans_var_data_utils.h" @@ -58,6 +58,7 @@ #include "runtime/dev.h" #include "runtime/event.h" #include "runtime/mem.h" +#include "runtime/rt_model.h" #include "runtime/stream.h" #include "securec.h" @@ -78,9 +79,8 @@ namespace { const uint32_t kDataIndex = 0; const uint32_t kOutputNum = 1; const uint32_t kTrueBranchStreamNum = 1; -const uint32_t kThreadNum = 1; +const uint32_t kThreadNum = 16; const uint32_t kAddrLen = sizeof(void *); -const char *const kNeedDestroySpecifiedAicpuKernel = "need_destroy_specified_aicpu_kernel"; const int kDecimal = 10; const int kBytes = 8; const uint32_t kDataMemAlignSizeCompare = 64; @@ -89,10 +89,10 @@ const char *const kDefaultBatchLable = "Batch_default"; inline bool IsDataOp(const std::string &node_type) { return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; } -inline bool IsCallDumpInputOp(const OpDescPtr &op_desc) { - bool skip_task_generate = false; - (void)ge::AttrUtils::GetBool(op_desc, ATTR_NO_TASK_AND_DUMP_NEEDED, skip_task_generate); - return skip_task_generate; +inline bool IsNoTaskAndDumpNeeded(const OpDescPtr &op_desc) { + bool save_dump_info = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NO_TASK_AND_DUMP_NEEDED, save_dump_info); + return save_dump_info; } } // namespace @@ -125,10 +125,10 @@ DavinciModel::DavinciModel(int32_t priority, const std::shared_ptrGetModelTaskDefPtr(); return SUCCESS; } +/// +/// @ingroup ge +/// @brief Reduce memory usage after task sink. +/// @return: void +/// +void DavinciModel::Shrink() { + ge_model_.reset(); // delete object. + + // Old dump need op list, clear when closed. + char *ge_dump_env = std::getenv("DUMP_OP"); + int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; + if (dump_op_switch == 0) { + op_list_.clear(); + } +} + Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { if (is_model_has_inited_) { - GELOGI("call InitModelMem more than once ."); + GELOGE(FAILED, "call InitModelMem more than once ."); return FAILED; } is_model_has_inited_ = true; - std::size_t data_size = TotalMemSize(); - ge::Buffer weights = ge_model_->GetWeight(); - uint8_t *weights_addr = weights.GetData(); + std::size_t data_size = TotalMemSize(); + const Buffer &weights = ge_model_->GetWeight(); std::size_t weights_size = weights.GetSize(); - GE_CHECK_LE(weights_size, ALLOC_MEMORY_MAX_SIZE); if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { @@ -257,7 +287,8 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p if (TotalMemSize() && mem_base_ == nullptr) { mem_base_ = MallocFeatureMapMem(data_size); if (mem_base_ == nullptr) { - return FAILED; + GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); + return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; } GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, mem_base_, data_size); @@ -274,17 +305,18 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p if (weight_ptr == nullptr) { weights_mem_base_ = MallocWeightsMem(weights_size); if (weights_mem_base_ == nullptr) { - return FAILED; + GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc weight memory failed. size: %zu", weights_size); + return GE_EXEC_ALLOC_WEIGHT_MEM_FAILED; } is_inner_weight_base_ = true; } GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[W] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, weights_mem_base_, weights_size); - GE_CHK_RT_RET(rtMemcpy(weights_mem_base_, weights_size, weights_addr, weights_size, RT_MEMCPY_HOST_TO_DEVICE)) + GE_CHK_RT_RET(rtMemcpy(weights_mem_base_, weights_size, weights.GetData(), weights_size, RT_MEMCPY_HOST_TO_DEVICE)); GELOGI("copy weights data to device"); } - GE_CHK_STATUS_RET(InitVariableMem(), "init variable mem failed."); + GE_CHK_STATUS_RET(InitVariableMem(), "Init variable memory failed."); runtime_param_.mem_base = mem_base_; runtime_param_.weight_base = weights_mem_base_; return SUCCESS; @@ -296,7 +328,7 @@ Status DavinciModel::InitVariableMem() { if (TotalVarMemSize() && var_mem_base_ == nullptr) { Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize()); if (ret != SUCCESS) { - GELOGE(ret, "Malloc Var Memory Fail."); + GELOGE(ret, "Malloc variable memory failed."); return ret; } var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); @@ -335,19 +367,15 @@ void DavinciModel::InitRuntimeParams() { session_id_ = runtime_param_.session_id; GELOGI( - "InitRuntimeParams(), memory_size:%lu, weight_size:%lu, session_id:%u, var_size:%lu, logic_var_base:%lu, " - "logic_mem_base:%lu.", - runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.session_id, runtime_param_.var_size, - runtime_param_.logic_var_base, runtime_param_.logic_mem_base); - - GELOGI("InitRuntimeParams(), stream_num:%lu, event_num:%u, label_num:%u", runtime_param_.stream_num, - runtime_param_.event_num, runtime_param_.label_num); + "InitRuntimeParams(), session_id:%u, stream_num:%lu, event_num:%u, label_num:%u, " + "logic_mem_base:0x%lx, logic_weight_base:0x%lx, logic_var_base:0x%lx, " + "memory_size:%lu, weight_size:%lu, var_size:%lu", + runtime_param_.session_id, runtime_param_.stream_num, runtime_param_.event_num, runtime_param_.label_num, + runtime_param_.logic_mem_base, runtime_param_.logic_weight_base, runtime_param_.logic_var_base, + runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.var_size); } void DavinciModel::CheckHasHcomOp() { - // definiteness queue schedule, all stream by TS. - GE_IF_BOOL_EXEC(!input_queue_ids_.empty() || !output_queue_ids_.empty(), return ); - Graph graph = ge_model_->GetGraph(); auto compute_graph = GraphUtils::GetComputeGraph(graph); if (compute_graph == nullptr) { @@ -363,11 +391,6 @@ void DavinciModel::CheckHasHcomOp() { (op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)), uint32_t stream_id = static_cast(op_desc->GetStreamId()); (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); - - bool is_aicpu_stream = false; - GE_IF_BOOL_EXEC(AttrUtils::GetBool(op_desc, "is_aicpu_stream", is_aicpu_stream) && is_aicpu_stream, - uint32_t stream_id = static_cast(op_desc->GetStreamId()); - (void)aicpu_streams_.emplace(stream_id); GELOGD("aicpu stream: %u.", stream_id); continue); } } @@ -378,20 +401,13 @@ void DavinciModel::CheckHasHcomOp() { /// Status DavinciModel::BindModelStream() { // Stream not in active_stream_indication_ is active stream. - if (!input_queue_ids_.empty() || !output_queue_ids_.empty()) { - // Asynchronous Queue, need add S0, deactive all model stream. + if ((!input_queue_ids_.empty() || !output_queue_ids_.empty()) || (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD)) { for (size_t i = 0; i < stream_list_.size(); ++i) { if (active_stream_indication_.count(i) == 0) { active_stream_list_.push_back(stream_list_[i]); active_stream_indication_.insert(i); // deactive all model stream. } } - } else { - for (size_t i = 0; i < stream_list_.size(); ++i) { - if (active_stream_indication_.count(i) == 0) { - active_stream_list_.push_back(stream_list_[i]); - } - } } for (size_t i = 0; i < stream_list_.size(); ++i) { @@ -409,23 +425,29 @@ Status DavinciModel::BindModelStream() { Status DavinciModel::DoTaskSink() { // task sink is supported as model_task_def is set - if (model_task_def_) { - GELOGI("do task_sink."); - GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); + const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); + if (model_task_def == nullptr) { + return SUCCESS; + } - if (known_node_) { - GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); - } + GE_CHK_RT_RET(rtGetAicpuDeploy(&deploy_type_)); + GELOGI("do task_sink. AiCpu deploy type is: %x.", deploy_type_); - GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo failed."); + GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); - GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); + if (known_node_) { + GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); + } - GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); + GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def.get()), "InitTaskInfo failed."); - GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); - } + GE_CHK_STATUS_RET(InitEntryTask(), "InitEntryTask failed."); + + GE_CHK_STATUS_RET(DistributeTask(), "Distribute failed."); + GE_CHK_RT_RET(rtModelLoadComplete(rt_model_handle_)); + + SetCopyOnlyOutput(); return SUCCESS; } @@ -438,17 +460,98 @@ Status DavinciModel::SetTSDevice() { rtError_t rt_ret = rtSetTSDevice(core_type); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "SetTSDevice failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; } +Status DavinciModel::OpDebugRegister() { + bool is_op_debug = false; + (void)ge::AttrUtils::GetBool(ge_model_, ATTR_OP_DEBUG_FLAG, is_op_debug); + GELOGD("The value of op_debug in ge_model_ is %d.", is_op_debug); + if (is_op_debug) { + debug_reg_mutex_.lock(); + rtError_t rt_ret = rtMalloc(&op_debug_addr_, kOpDebugMemorySize, RT_MEMORY_DDR); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + uint64_t debug_addrs_tmp = static_cast(reinterpret_cast(op_debug_addr_)); + + // For data dump, aicpu needs the pointer to pointer that save the real debug address. + rt_ret = rtMalloc(&p2p_debug_addr_, kDebugP2pSize, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + rt_ret = rtMemcpy(p2p_debug_addr_, sizeof(uint64_t), &debug_addrs_tmp, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMemcpy to p2p_addr error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + + uint32_t op_debug_mode = 0; + (void)ge::AttrUtils::GetInt(ge_model_, ATTR_OP_DEBUG_MODE, op_debug_mode); + GELOGD("The value of op_debug_mode in ge_model_ is %u.", op_debug_mode); + uint32_t debug_task_id = 0; + uint32_t debug_stream_id = 0; + rt_ret = rtDebugRegister(rt_model_handle_, op_debug_mode, op_debug_addr_, &debug_stream_id, &debug_task_id); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtDebugRegister error, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + GELOGI("debug_task_id:%d, debug_stream_id:%u", debug_task_id, debug_stream_id); + is_op_debug_reg_ = true; + + data_dumper_.SaveOpDebugId(debug_task_id, debug_stream_id, p2p_debug_addr_, is_op_debug); + } + + return SUCCESS; +} + +void DavinciModel::OpDebugUnRegister() { + GELOGI("OpDebugUnRegister, is_op_debug_reg_ = %d", is_op_debug_reg_); + if (is_op_debug_reg_) { + debug_reg_mutex_.unlock(); + rtError_t rt_ret = RT_ERROR_NONE; + if (rt_model_handle_ != nullptr) { + rt_ret = rtDebugUnRegister(rt_model_handle_); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("rtDebugUnRegister failed, ret: 0x%X", rt_ret); + } + } + + if (op_debug_addr_ != nullptr) { + rt_ret = rtFree(op_debug_addr_); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("rtFree failed, ret: 0x%X", rt_ret); + } + op_debug_addr_ = nullptr; + } + + if (p2p_debug_addr_ != nullptr) { + rt_ret = rtFree(p2p_debug_addr_); + if (rt_ret != RT_ERROR_NONE) { + GELOGW("rtFree failed, ret: 0x%X", rt_ret); + } + p2p_debug_addr_ = nullptr; + } + is_op_debug_reg_ = false; + } + return; +} + // initialize op sequence and call initialization function of each op respectively Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { // validating params GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(priority_ < 0 || priority_ > 7, return PARAM_INVALID, "Priority must between 0-7, now is %d", priority_); GE_CHK_BOOL_RET_STATUS(ge_model_ != nullptr, PARAM_INVALID, "GeModel is null."); + Graph graph = ge_model_->GetGraph(); + ComputeGraphPtr compute_graph = GraphUtils::GetComputeGraph(graph); + GE_CHK_BOOL_RET_STATUS(compute_graph != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); + // Initializing runtime_param_ InitRuntimeParams(); @@ -477,8 +580,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size if (hcom_streams_.find(i) != hcom_streams_.end()) { GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_FORCE_COPY)); - } else if (aicpu_streams_.find(i) != aicpu_streams_.end()) { - GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags | RT_STREAM_AICPU)); } else { GE_CHK_RT_RET(rtStreamCreateWithFlags(&stream, priority_, stream_flags)); } @@ -499,35 +600,36 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size // create model_handle to load model GE_CHK_RT_RET(rtModelCreate(&rt_model_handle_, 0)); GE_CHK_RT_RET(rtModelGetId(rt_model_handle_, &runtime_model_id_)); + // inference will use default graph_id 0; + runtime_param_.graph_id = compute_graph->GetGraphID(); - Graph graph = ge_model_->GetGraph(); - compute_graph_ = GraphUtils::GetComputeGraph(graph); - GE_CHK_BOOL_RET_STATUS(compute_graph_ != nullptr, INTERNAL_ERROR, "Get compute graph is nullptr."); - - runtime_param_.graph_id = compute_graph_->GetGraphID(); + // op debug register + GE_CHK_STATUS_RET(OpDebugRegister(), "OpDebugRegister failed"); GE_TIMESTAMP_START(TransAllVarData); - GE_CHK_STATUS_RET(TransAllVarData(compute_graph_, runtime_param_.graph_id), "TransAllVarData failed."); + GE_CHK_STATUS_RET(TransAllVarData(compute_graph, runtime_param_.graph_id), "TransAllVarData failed."); GE_TIMESTAMP_END(TransAllVarData, "GraphLoader::TransAllVarData"); - GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); + GE_CHK_STATUS_RET(TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_), "copy var data failed."); GE_TIMESTAMP_START(InitModelMem); - GELOGI("known_node is %d", known_node_); + GELOGI("Known node is %d", known_node_); if (!known_node_) { GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); data_inputer_ = new (std::nothrow) DataInputer(); - GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); + GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, MEMALLOC_FAILED, "data_inputer_ is nullptr."); } GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); - for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { - GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); - GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); + for (const ge::NodePtr &node : compute_graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, continue); + GetFixedAddrAttr(op_desc); + GE_IF_BOOL_EXEC(op_desc->GetType() != VARIABLE, continue); GE_IF_BOOL_EXEC(IsBroadCastOpData(node), - (void)ge::AttrUtils::SetStr(node->GetOpDesc(), VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); + (void)ge::AttrUtils::SetStr(op_desc, VAR_ATTR_VAR_IS_BROADCAST, "var_is_restore");); } // for profiling - op_name_map_ = compute_graph_->GetGraphOpName(); + op_name_map_ = compute_graph->GetGraphOpName(); vector op_name; GE_IF_BOOL_EXEC(ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_TASK_INDEX_OP_NAME, op_name), @@ -536,14 +638,12 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size for (size_t idx = 0; idx < op_name.size(); idx++) { op_name_map_[idx] = op_name[idx]; } - GELOGI("infer profiling: op_name_size(%zu)", op_name.size()); + GELOGI("Infer profiling: op_name_size(%zu)", op_name.size()); } - if (InitNodes(compute_graph_) != SUCCESS) { - return FAILED; - } + GE_CHK_STATUS_RET(InitNodes(compute_graph), "Init nodes failed"); - SetDataDumperArgs(); + SetDataDumperArgs(compute_graph); GE_TIMESTAMP_START(DoTaskSink); auto ret = DoTaskSink(); GE_TIMESTAMP_END(DoTaskSink, "GraphLoader::DoTaskSink"); @@ -551,22 +651,23 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size /// In zero copy model, if a aicpu operator is connected to the first or last layer, before model execution, /// the aicpu opertor needs to destroy history record, and update operator memory address. /// The model with specified aicpu operators is only marked here, and destruction is in ModelManager::ExecuteModel(). - if (MarkSpecifiedAicpuKernel() != SUCCESS) { - GELOGE(FAILED, "Mark model with specified aicpu operators failed."); - return FAILED; - } + need_destroy_aicpu_kernel_ = IsAicpuKernelConnectSpecifiedLayer(); + (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name_); // collect profiling for ge if (ProfilingManager::Instance().ProfilingOn()) { std::vector compute_graph_desc_info; - Status ret1 = GetComputeGraphInfo(compute_graph_desc_info); + Status ret1 = GetComputeGraphInfo(compute_graph, compute_graph_desc_info); if (ret1 != SUCCESS) { GELOGE(ret1, "GetComputeGraphInfo failed."); return ret1; } ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); + GE_CHK_STATUS(SinkModelProfile(), "Sink model profile failed."); } - GELOGI("davinci model init success."); + + Shrink(); + GELOGI("Davinci model init success."); return ret; } @@ -623,26 +724,14 @@ bool DavinciModel::IsAicpuKernelConnectSpecifiedLayer() { return false; } -/// -/// @ingroup ge -/// @brief mark ge model with specified aicpu operators . -/// @return Status -/// -Status DavinciModel::MarkSpecifiedAicpuKernel() { - bool result = IsAicpuKernelConnectSpecifiedLayer(); - if (!result) { - // No aicpu operator needing destroy. - GELOGD("No specified aicpu operator that connects to data or netoutput."); - return SUCCESS; - } - bool ret = ge::AttrUtils::SetBool(ge_model_, kNeedDestroySpecifiedAicpuKernel, result); - if (!ret) { - GELOGW("Add attr[%s] in ge model failed, and may lead to specified aicpu operators destruction failure.", - kNeedDestroySpecifiedAicpuKernel); +Status DavinciModel::UpdateSessionId(uint64_t session_id) { + GE_CHECK_NOTNULL(ge_model_); + if (!AttrUtils::SetInt(ge_model_, MODEL_ATTR_SESSION_ID, static_cast(session_id))) { + GELOGW("Set attr[%s] failed in updating session_id.", MODEL_ATTR_SESSION_ID.c_str()); } - GELOGI("Mark ge model success, the model has specified aicpu operators, ge model name: %s.", - ge_model_->GetName().c_str()); + + GELOGD("Update session id: %lu.", session_id); return SUCCESS; } @@ -689,12 +778,6 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } - if (IsCallDumpInputOp(op_desc)) { - GELOGI("node[%s] is no task op , call SaveDumpInput to save it's output node info", op_desc->GetName().c_str()); - data_dumper_.SaveDumpInput(node); - continue; - } - if (op_desc->GetType() == NETOUTPUT) { if (InitNetOutput(node) != SUCCESS) { GELOGE(PARAM_INVALID, "NetOutput init failed, Name: %s", op_desc->GetName().c_str()); @@ -712,6 +795,29 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { continue; } + if (IsNoTaskAndDumpNeeded(op_desc)) { + GELOGD("node[%s] without task, and save op_desc and addr for dump", op_desc->GetName().c_str()); + const RuntimeParam &rts_param = GetRuntimeParam(); + const vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); + const vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); + vector tensor_device_addrs; + tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + void *addr = nullptr; + auto size = kAddrLen * tensor_device_addrs.size(); + GE_CHK_RT_RET(rtMalloc(&addr, size, RT_MEMORY_HBM)); + + rtError_t rt_ret = rtMemcpy(addr, size, tensor_device_addrs.data(), size, RT_MEMCPY_HOST_TO_DEVICE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMemcpy error, ret: 0x%X", rt_ret); + GE_CHK_RT(rtFree(addr)); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } + saved_task_addrs_.emplace(op_desc, addr); + } + GE_TIMESTAMP_RESTART(InitTbeHandle); uint32_t run_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, run_mode) && @@ -724,9 +830,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { op_desc->GetName().c_str(), op_desc->GetType().c_str()); continue;); - if (InitTbeHandle(op_desc) != SUCCESS) { - GELOGE(PARAM_INVALID, "TBE init failed. %s", op_desc->GetName().c_str()); - return PARAM_INVALID; + Status status = InitTbeHandle(op_desc); + if (status != SUCCESS) { + GELOGE(status, "TBE init failed. %s", op_desc->GetName().c_str()); + return status; } } GE_TIMESTAMP_ADD(InitTbeHandle); @@ -741,7 +848,6 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { /// @brief Data Op Initialize. /// @param [in] NodePtr: Data Op. /// @param [in/out] data_op_index: NetOutput addr size info. -/// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. /// @return Status Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { // op_desc Checked by Init: Data, valid. @@ -757,31 +863,39 @@ Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { } data_op_list_.push_back(op_desc); - ConstGeTensorDescPtr input_desc = op_desc->GetInputDescPtr(kDataIndex); - if (input_desc != nullptr && input_desc->GetFormat() != FORMAT_FILTER_HWCK) { - data_op_input_tensor_desc_map_[op_desc->GetName()] = input_desc; - } - - ConstGeTensorDescPtr output_desc = op_desc->GetOutputDescPtr(kDataIndex); - if (output_desc != nullptr && output_desc->GetFormat() != FORMAT_FRACTAL_Z) { - data_op_output_tensor_desc_map_[op_desc->GetName()] = output_desc; - } // Make information for copy input data. const vector output_size_list = ModelUtils::GetOutputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc, false); - if (output_size_list.empty() || virtual_addr_list.empty() || (output_size_list.size() != virtual_addr_list.size())) { - GELOGE(PARAM_INVALID, "Data[%s] init failed: Output size is %zu, Output addr is %zu", op_desc->GetName().c_str(), - output_size_list.size(), virtual_addr_list.size()); + const vector virtual_addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, op_desc); + const vector output_offset_list = op_desc->GetOutputOffset(); + if (output_offset_list.size() != virtual_addr_list.size()) { + GELOGE(PARAM_INVALID, "virtual_addr size:%zu should be equal to offset size:%zu.", virtual_addr_list.size(), + output_offset_list.size()); return PARAM_INVALID; } - auto data_index = data_op_index; if (AttrUtils::GetInt(op_desc, ATTR_NAME_INDEX, data_index)) { GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); } - input_data_info_[data_index] = {output_size_list[kDataIndex], virtual_addr_list[kDataIndex]}; - SetInputOutsideAddr(virtual_addr_list); + bool fusion_flag = false; + ZeroCopyOffset zero_copy_offset; + Status ret = zero_copy_offset.InitInputDataInfo(output_size_list, virtual_addr_list, op_desc, fusion_flag); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "InitDataInfo of input_info %s failed.", op_desc->GetName().c_str()); + return PARAM_INVALID; + } + new_input_data_info_[data_index] = zero_copy_offset; + + for (size_t index = 0; index < virtual_addr_list.size(); ++index) { + void *addr = virtual_addr_list.at(index); + if (new_input_outside_addrs_.find(addr) != new_input_outside_addrs_.end()) { + continue; + } + zero_copy_offset.SetInputOutsideAddrs(output_offset_list, addr, index, fusion_flag, real_virtual_addrs_); + new_input_outside_addrs_[addr] = zero_copy_offset; + } + + GELOGI("SetInputOutsideAddr success."); data_op_index++; if (InitInputZeroCopyNodes(node) != SUCCESS) { GELOGE(PARAM_INVALID, "Input zero copy nodes init failed!"); @@ -830,6 +944,7 @@ Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) { Status DavinciModel::InitNetOutput(const NodePtr &node) { // node->GetOpDesc Checked by Init: NetOutput, valid. auto op_desc = node->GetOpDesc(); + // excludes the function op sub graph, e.g. case,if if (known_node_) { output_op_list_.push_back(op_desc); return SUCCESS; @@ -845,7 +960,12 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { output_op_list_.push_back(op_desc); // Make information for copy output data. const vector input_size_list = ModelUtils::GetInputSize(op_desc); - const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); + const vector virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc); + const vector input_offset_list = op_desc->GetInputOffset(); + if (input_offset_list.size() != virtual_addr_list.size()) { + GELOGE(PARAM_INVALID, "virtual_addr size should be equal to offset size."); + return PARAM_INVALID; + } if (input_size_list.empty() && virtual_addr_list.empty()) { GELOGI("NetOutput[%s] is empty.", op_desc->GetName().c_str()); return SUCCESS; @@ -856,12 +976,38 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { return PARAM_INVALID; } - size_t num = output_data_info_.size(); + size_t num = new_output_data_info_.size(); + bool fusion_flag = false; + for (size_t idx = 0; idx < input_size_list.size(); ++idx) { - output_data_info_[num + idx] = {input_size_list[idx], virtual_addr_list[idx]}; + ZeroCopyOffset zero_copy_offset; + Status ret = zero_copy_offset.InitOutputDataInfo(input_size_list, virtual_addr_list, op_desc, idx, fusion_flag); + if (ret != SUCCESS) { + GELOGE(PARAM_INVALID, "InitDataInfo of input_info %s failed.", op_desc->GetName().c_str()); + return PARAM_INVALID; + } + new_output_data_info_[num + idx] = zero_copy_offset; + void *addr = virtual_addr_list.at(idx); + int64_t input_offset = input_offset_list.at(idx); + if (new_output_outside_addrs_.find(addr) != new_output_outside_addrs_.end()) { + continue; + } + vector tensor_addrs; + zero_copy_offset.SetOutputOutsideAddrs(input_offset, fusion_flag, addr, tensor_addrs); + auto rslt = new_output_outside_addrs_.insert(std::pair(addr, zero_copy_offset)); + if (!rslt.second) { + GELOGI("same output_tensor_addr %p to different input_tensor of %s", addr, op_desc->GetName().c_str()); + DisableZeroCopy(addr); + } + + for (size_t i = 0; i < tensor_addrs.size(); ++i) { + void *real_addr = tensor_addrs.at(i); + DisableZeroCopy(real_addr); + real_virtual_addrs_.emplace_back(real_addr); + } + GELOGI("SetOutputOutsideAddr success."); } - SetOutputOutsideAddr(virtual_addr_list); if (InitOutputZeroCopyNodes(node) != SUCCESS) { GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); return PARAM_INVALID; @@ -968,8 +1114,8 @@ Status DavinciModel::InitVariable(const OpDescPtr &op_desc) { Status DavinciModel::SetQueIds(const std::vector &input_queue_ids, const std::vector &output_queue_ids) { if (input_queue_ids.empty() && output_queue_ids.empty()) { - GELOGE(PARAM_INVALID, "Para is empty"); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_QUEUE_ID_INVALID, "Param is empty"); + return GE_EXEC_MODEL_QUEUE_ID_INVALID; } input_queue_ids_ = input_queue_ids; @@ -989,32 +1135,28 @@ Status DavinciModel::LoadWithQueue() { return SUCCESS; } - if (input_queue_ids_.size() != input_data_info_.size()) { - GELOGE(PARAM_INVALID, "Input queue ids not match model: input_queue=%zu input_data=%zu", input_queue_ids_.size(), - input_data_info_.size()); - return PARAM_INVALID; + if (input_queue_ids_.size() != new_input_data_info_.size()) { + GELOGE(GE_EXEC_MODEL_QUEUE_ID_INVALID, "Input queue ids not match model: input_queue=%zu input_data=%zu", + input_queue_ids_.size(), new_input_data_info_.size()); + return GE_EXEC_MODEL_QUEUE_ID_INVALID; } - if (output_queue_ids_.size() != output_data_info_.size()) { - GELOGE(PARAM_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", - output_queue_ids_.size(), output_data_info_.size()); - return PARAM_INVALID; + if (output_queue_ids_.size() != new_output_data_info_.size()) { + GELOGE(GE_EXEC_MODEL_QUEUE_ID_INVALID, "Output queue ids not match model: output_queue=%zu output_data=%zu", + output_queue_ids_.size(), new_output_data_info_.size()); + return GE_EXEC_MODEL_QUEUE_ID_INVALID; } - // create stream instance which rt_model_handel is running on, this is S0. - GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_model_stream_, priority_, RT_STREAM_AICPU)); - is_inner_model_stream_ = true; - GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_model_stream_, RT_HEAD_STREAM)); - + GE_CHK_STATUS_RET(AddHeadStream(), "Add head stream failed."); // Binding input_queue and Data Op. GE_CHK_STATUS_RET(BindInputQueue(), "Launch bind input queue failed."); - GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(input_mbuf_list_, input_outside_addrs_), "Launch zero copy failed."); + GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(input_mbuf_list_, new_input_outside_addrs_), "Launch zero copy failed."); // Binding output_queue and NetOutput Op. GE_CHK_STATUS_RET(BindOutputQueue(), "Launch bind output queue failed."); - GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(output_mbuf_list_, output_outside_addrs_), "Launch zero copy failed."); + GE_CHK_STATUS_RET(CpuTaskModelZeroCopy(output_mbuf_list_, new_output_outside_addrs_), "Launch zero copy failed."); - GE_CHK_STATUS_RET(CpuActiveStream(active_stream_list_), "Launch active entry stream failed."); + GE_CHK_STATUS_RET(CpuActiveStream(), "Launch active entry stream failed."); GE_CHK_STATUS_RET(CpuWaitEndGraph(), "Launch wait end graph failed."); GE_CHK_STATUS_RET(BindEnqueue(), "Launch enqueue failed."); GE_CHK_STATUS_RET(CpuModelRepeat(), "Launch model repeat failed."); @@ -1028,20 +1170,26 @@ Status DavinciModel::LoadWithQueue() { Status DavinciModel::BindInputQueue() { // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() for (size_t i = 0; i < input_queue_ids_.size(); ++i) { - auto it = input_data_info_.find(i); - if (it == input_data_info_.end()) { - GELOGE(FAILED, "Input not match: tensor num=%zu, Queue id index=%zu", input_data_info_.size(), i); + auto it = new_input_data_info_.find(i); + if (it == new_input_data_info_.end()) { + GELOGE(FAILED, "Input not match: tensor num=%zu, Queue id index=%zu", new_input_data_info_.size(), i); return FAILED; } uint32_t queue_id = input_queue_ids_[i]; - uint32_t data_size = static_cast(it->second.first); - uintptr_t data_addr = reinterpret_cast(it->second.second); + if (it->second.GetDataInfo().empty()) { + GELOGE(INTERNAL_ERROR, "the %zu input_queue not set data_info.", i); + return INTERNAL_ERROR; + } + uint32_t data_size = static_cast(it->second.GetDataInfo().at(0).first); + uintptr_t data_addr = reinterpret_cast(it->second.GetDataInfo().at(0).second); GELOGI("BindInputToQueue: graph_%u index[%zu] queue id[%u] output addr[0x%lx] output size[%u]", runtime_param_.graph_id, i, queue_id, data_addr, data_size); - if (rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_INPUT_QUEUE) != RT_ERROR_NONE) { - return INTERNAL_ERROR; + rtError_t rt_ret = rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_INPUT_QUEUE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtModelBindQueue failed, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } if (CpuModelDequeue(queue_id) != SUCCESS) { @@ -1058,16 +1206,17 @@ Status DavinciModel::BindInputQueue() { /// @return: 0 for success / others for failed Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { GELOGI("Set CpuKernel model dequeue task enter."); - std::shared_ptr dequeue_task = MakeShared(rt_model_stream_); + std::shared_ptr dequeue_task = MakeShared(rt_entry_stream_); if (dequeue_task == nullptr) { - GELOGE(FAILED, "Make CpuTaskModelDequeue task failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make CpuTaskModelDequeue task failed."); + return MEMALLOC_FAILED; } // Get DataOp Output address and bind to queue. uintptr_t in_mbuf = 0; - if (dequeue_task->Init(queue_id, in_mbuf) != SUCCESS) { - return FAILED; + Status status = dequeue_task->Init(queue_id, in_mbuf); + if (status != SUCCESS) { + return status; } cpu_task_list_.push_back(dequeue_task); @@ -1077,16 +1226,18 @@ Status DavinciModel::CpuModelDequeue(uint32_t queue_id) { } Status DavinciModel::CpuTaskModelZeroCopy(std::vector &mbuf_list, - std::map> &outside_addrs) { + std::map &outside_addrs) { GELOGI("Set CpuKernel model zero_copy task enter."); - std::shared_ptr zero_copy = MakeShared(rt_model_stream_); + std::shared_ptr zero_copy = MakeShared(rt_entry_stream_); if (zero_copy == nullptr) { - GELOGE(FAILED, "Make CpuTaskZeroCopy task failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make CpuTaskZeroCopy task failed."); + return MEMALLOC_FAILED; } - if (zero_copy->Init(mbuf_list, outside_addrs) != SUCCESS) { - return FAILED; + // mdc zero_copy not support l2 fusion + Status status = zero_copy->Init(mbuf_list, outside_addrs); + if (status != SUCCESS) { + return status; } cpu_task_list_.push_back(zero_copy); GELOGI("Set CpuKernel model zero_copy task success."); @@ -1099,23 +1250,31 @@ Status DavinciModel::CpuTaskModelZeroCopy(std::vector &mbuf_list, Status DavinciModel::BindOutputQueue() { // Caller checked: input_queue_ids_.size() == input_size_list_.size() != input_addr_list_.size() for (size_t i = 0; i < output_queue_ids_.size(); ++i) { - auto it = output_data_info_.find(i); - if (it == output_data_info_.end()) { - GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", output_data_info_.size(), i); + auto it = new_output_data_info_.find(i); + if (it == new_output_data_info_.end()) { + GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", new_output_data_info_.size(), i); return FAILED; } uint32_t queue_id = output_queue_ids_[i]; - uint32_t data_size = static_cast(it->second.first); - uintptr_t data_addr = reinterpret_cast(it->second.second); + if (it->second.GetDataInfo().empty()) { + GELOGE(INTERNAL_ERROR, "the %zu output_queue not set data_info.", i); + return INTERNAL_ERROR; + } + uint32_t data_size = static_cast(it->second.GetDataInfo().at(0).first); + uintptr_t data_addr = reinterpret_cast(it->second.GetDataInfo().at(0).second); GELOGI("BindOutputToQueue: graph_%u index[%zu] queue id[%u] input addr[0x%lx] input size[%u]", runtime_param_.graph_id, i, queue_id, data_addr, data_size); - if (rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_OUTPUT_QUEUE) != RT_ERROR_NONE) { - return INTERNAL_ERROR; + rtError_t rt_ret = rtModelBindQueue(rt_model_handle_, queue_id, RT_MODEL_OUTPUT_QUEUE); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtModelBindQueue failed, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } - if (CpuModelPrepareOutput(data_addr, data_size) != SUCCESS) { - return INTERNAL_ERROR; + + Status status = CpuModelPrepareOutput(data_addr, data_size); + if (status != SUCCESS) { + return status; } } @@ -1124,7 +1283,6 @@ Status DavinciModel::BindOutputQueue() { /// @ingroup ge /// @brief definiteness queue schedule, bind output queue to task. -/// @param [in] queue_id: output queue id from user. /// @param [in] addr: NetOutput Op input tensor address. /// @param [in] size: NetOutput Op input tensor size. /// @return: 0 for success / others for failed @@ -1135,10 +1293,10 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { return FAILED; } - std::shared_ptr prepare_output = MakeShared(rt_model_stream_); + std::shared_ptr prepare_output = MakeShared(rt_entry_stream_); if (prepare_output == nullptr) { - GELOGE(FAILED, "Make CpuTaskPrepareOutput task failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make CpuTaskPrepareOutput task failed."); + return MEMALLOC_FAILED; } uintptr_t out_mbuf = 0; @@ -1155,25 +1313,22 @@ Status DavinciModel::CpuModelPrepareOutput(uintptr_t addr, uint32_t size) { /// /// @ingroup ge /// @brief definiteness queue schedule, active original model stream. -/// @param [in] streams: streams will active by S0. /// @return: 0 for success / others for failed /// -Status DavinciModel::CpuActiveStream(const std::vector &stream_list) { - GELOGI("Set CpuKernel active stream task:%zu enter.", stream_list.size()); - for (auto s : stream_list) { - std::shared_ptr active_entry = MakeShared(rt_model_stream_); - if (active_entry == nullptr) { - GELOGE(FAILED, "Make CpuTaskActiveEntry task failed."); - return FAILED; - } - - if (active_entry->Init(s) != SUCCESS) { - return FAILED; - } +Status DavinciModel::CpuActiveStream() { + GELOGI("Set CpuKernel active stream task enter."); + std::shared_ptr active_entry = MakeShared(rt_entry_stream_); + if (active_entry == nullptr) { + GELOGE(MEMALLOC_FAILED, "Make CpuTaskActiveEntry task failed."); + return MEMALLOC_FAILED; + } - cpu_task_list_.push_back(active_entry); + Status status = active_entry->Init(rt_head_stream_); + if (status != SUCCESS) { + return status; } + cpu_task_list_.push_back(active_entry); GELOGI("Set CpuKernel active stream task success."); return SUCCESS; } @@ -1183,14 +1338,15 @@ Status DavinciModel::CpuActiveStream(const std::vector &stream_list) /// @return: 0 for success / others for failed Status DavinciModel::CpuWaitEndGraph() { GELOGI("Set CpuKernel wait end graph task enter."); - std::shared_ptr wait_endgraph = MakeShared(rt_model_stream_); + std::shared_ptr wait_endgraph = MakeShared(rt_entry_stream_); if (wait_endgraph == nullptr) { - GELOGE(FAILED, "Make CpuTaskWaitEndGraph task failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make CpuTaskWaitEndGraph task failed."); + return MEMALLOC_FAILED; } - if (wait_endgraph->Init(runtime_model_id_) != SUCCESS) { - return FAILED; + Status status = wait_endgraph->Init(runtime_model_id_); + if (status != SUCCESS) { + return status; } cpu_task_list_.push_back(wait_endgraph); @@ -1200,9 +1356,9 @@ Status DavinciModel::CpuWaitEndGraph() { Status DavinciModel::BindEnqueue() { for (size_t i = 0; i < output_queue_ids_.size(); ++i) { - auto it = output_data_info_.find(i); - if (it == output_data_info_.end()) { - GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", output_data_info_.size(), i); + auto it = new_output_data_info_.find(i); + if (it == new_output_data_info_.end()) { + GELOGE(FAILED, "Output not match: tensor num=%zu, Queue id index=%zu", new_output_data_info_.size(), i); return FAILED; } @@ -1216,14 +1372,15 @@ Status DavinciModel::BindEnqueue() { Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { GELOGI("Set CpuKernel model enqueue task enter."); - std::shared_ptr model_enqueue = MakeShared(rt_model_stream_); + std::shared_ptr model_enqueue = MakeShared(rt_entry_stream_); if (model_enqueue == nullptr) { - GELOGE(FAILED, "Make CpuTaskModelEnqueue task failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make CpuTaskModelEnqueue task failed."); + return MEMALLOC_FAILED; } - if (model_enqueue->Init(queue_id, out_mbuf) != SUCCESS) { - return FAILED; + Status status = model_enqueue->Init(queue_id, out_mbuf); + if (status != SUCCESS) { + return status; } cpu_task_list_.push_back(model_enqueue); GELOGI("Set CpuKernel model enqueue task enter."); @@ -1235,14 +1392,15 @@ Status DavinciModel::CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf) { /// @return: 0 for success / others for failed Status DavinciModel::CpuModelRepeat() { GELOGI("Set CpuKernel repeat task enter."); - std::shared_ptr model_repeat = MakeShared(rt_model_stream_); + std::shared_ptr model_repeat = MakeShared(rt_entry_stream_); if (model_repeat == nullptr) { - GELOGE(FAILED, "Make CpuTaskModelRepeat task failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make CpuTaskModelRepeat task failed."); + return MEMALLOC_FAILED; } - if (model_repeat->Init(runtime_model_id_) != SUCCESS) { - return FAILED; + Status status = model_repeat->Init(runtime_model_id_); + if (status != SUCCESS) { + return status; } cpu_task_list_.push_back(model_repeat); @@ -1285,41 +1443,27 @@ Status DavinciModel::GetInputOutputDescInfo(vector &input_d /// @ingroup ge /// @brief Get dynamic batch_info /// @param [out] batch_info +/// @param [out] dynamic_type /// @return execute result /// -Status DavinciModel::GetDynamicBatchInfo(std::vector> &batch_info) { - for (auto &iter : op_list_) { - OpDescPtr op_desc = iter.second; - if (op_desc == nullptr) { - GELOGE(FAILED, "op_desc is null, index=%u.", iter.first); - return FAILED; - } - - if (op_desc->GetType() != STREAMSWITCHN) { - continue; - } +Status DavinciModel::GetDynamicBatchInfo(std::vector> &batch_info, int32_t &dynamic_type) const { + dynamic_type = dynamic_type_; + batch_info = batch_info_; - batch_info.clear(); - uint32_t batch_num = 0; - if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); - return FAILED; - } - std::vector batch_shape; - for (uint32_t i = 0; i < batch_num; i++) { - batch_shape.clear(); - const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); - if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { - GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); - return FAILED; - } - batch_info.emplace_back(batch_shape); - } - break; - } return SUCCESS; } +/// +/// @ingroup ge +/// @brief Get combined dynamic dims info +/// @param [out] batch_info +/// @return None +/// +void DavinciModel::GetCombinedDynamicDims(std::vector> &batch_info) const { + batch_info.clear(); + batch_info = combined_batch_info_; +} + /// /// @ingroup ge /// @brief Get AIPP input info @@ -1355,7 +1499,7 @@ Status DavinciModel::GetAIPPInfo(uint32_t index, AippConfigInfo &aipp_info) { return SUCCESS; } -void DavinciModel::SetDynamicSize(const std::vector &batch_num) { +void DavinciModel::SetDynamicSize(const std::vector &batch_num, int32_t dynamic_type) { batch_size_.clear(); if (batch_num.empty()) { GELOGD("User has not set dynammic data"); @@ -1363,9 +1507,11 @@ void DavinciModel::SetDynamicSize(const std::vector &batch_num) { for (size_t i = 0; i < batch_num.size(); i++) { batch_size_.emplace_back(batch_num[i]); } + + dynamic_type_ = dynamic_type; } -void DavinciModel::GetCurShape(std::vector &batch_info) { +void DavinciModel::GetCurShape(std::vector &batch_info, int32_t &dynamic_type) { if (batch_size_.empty()) { GELOGD("User does not set dynamic size"); } @@ -1373,6 +1519,8 @@ void DavinciModel::GetCurShape(std::vector &batch_info) { GELOGI("Start to get current shape"); batch_info.emplace_back(batch_size_[i]); } + + dynamic_type = dynamic_type_; } void DavinciModel::GetModelAttr(std::vector &dynamic_output_shape_info) { @@ -1529,7 +1677,7 @@ void DavinciModel::CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputD int64_t tensor_size = 0; (void)TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); // no need to check value - output.size = static_cast(tensor_size); + output.size = static_cast(tensor_size); output.data_type = op_desc->GetInputDescPtr(index)->GetDataType(); } @@ -1538,9 +1686,6 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, for (size_t i = 0; i < output_op_list_.size(); i++) { auto &op_desc = output_op_list_[i]; uint32_t out_size = static_cast(op_desc->GetInputsSize()); - // get real out nodes from model - vector out_node_name; - (void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name); for (uint32_t index = 0; index < out_size; index++) { string output_name; InputOutputDescInfo output; @@ -1552,11 +1697,11 @@ Status DavinciModel::GetOutputDescInfo(vector &output_desc, GE_CHK_BOOL_RET_STATUS(src_name.size() > index && src_index.size() > index, INTERNAL_ERROR, "construct output_name failed."); // forward compatbility, if old om has no out_node_name, need to return output follow origin way - if (out_size == out_node_name.size()) { + if (out_size == out_node_name_.size()) { // neweast plan, the index will add to name during generate model. - bool contains_colon = out_node_name[index].find(":") != std::string::npos; + bool contains_colon = out_node_name_[index].find(":") != std::string::npos; output_name = - contains_colon ? out_node_name[index] : out_node_name[index] + ":" + std::to_string(src_index[index]); + contains_colon ? out_node_name_[index] : out_node_name_[index] + ":" + std::to_string(src_index[index]); } else { output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); @@ -1581,27 +1726,28 @@ ge::Format DavinciModel::GetFormat() { Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data) { rtMemcpyKind_t kind = device_data ? RT_MEMCPY_DEVICE_TO_DEVICE : RT_MEMCPY_HOST_TO_DEVICE; const std::vector &blobs = input_data.blobs; - for (const auto &data : input_data_info_) { + for (const auto &data : new_input_data_info_) { if (data.first >= blobs.size()) { GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), - input_data_info_.size(), data.first, data.second.first); + new_input_data_info_.size(), data.first, data.second.GetDataInfo().at(0).first); return FAILED; } const DataBuffer &data_buf = blobs[data.first]; - void *mem_addr = data.second.second; - uint32_t mem_size = static_cast(data.second.first); - GE_CHK_BOOL_RET_STATUS(mem_size >= data_buf.length, PARAM_INVALID, - "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, - mem_size); - - GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, data.first, mem_addr, data_buf.data, mem_size, data_buf.length); if (data_buf.length == 0) { GELOGW("No data need to memcpy!"); return SUCCESS; } - GE_CHK_RT_RET(rtMemcpy(mem_addr, mem_size, data_buf.data, data_buf.length, kind)); + uint64_t data_size = data.second.GetDataSize(); + GE_CHK_BOOL_RET_STATUS(data_size >= data_buf.length, PARAM_INVALID, + "input data size(%lu) does not match model required size(%lu), ret failed.", data_buf.length, + data_size); + void *mem_addr = data.second.GetBasicAddr(); + void *data_buf_addr = reinterpret_cast(reinterpret_cast(data_buf.data)); + uint64_t data_buf_length = data_buf.length; + GELOGI("[IMAS]CopyPlainData memcpy graph_%lu type[F] input[%lu] dst[%p] src[%p] mem_size[%lu] datasize[%lu]", + runtime_param_.graph_id, data.first, mem_addr, data_buf_addr, data_size, data_buf_length); + GE_CHK_RT_RET(rtMemcpy(mem_addr, data_size, data_buf_addr, data_buf_length, kind)); } return SUCCESS; @@ -1643,15 +1789,9 @@ inline int64_t SumSize(const vector &size_list) { } Status DavinciModel::SinkModelProfile() { - // not support non-sink model - GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); - // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter == nullptr) { - GELOGI("Profiling report is nullptr!"); - return SUCCESS; - } + GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); GELOGI("Start collect model load profiling data."); @@ -1663,15 +1803,19 @@ Status DavinciModel::SinkModelProfile() { return FAILED, "Sink model tag memcpy error."); // Model Header - string name = this->Name(); - int32_t name_len = name.size(); + string name; + if (!om_name_.empty()) { + name = om_name_; + } else { + name = name_; + } + size_t name_len = name.size(); // phy device id uint32_t phy_device_id = 0; rtError_t rt_ret = rtGetDevicePhyIdByIndex(device_id_, &phy_device_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%d", phy_device_id); - return FAILED; - } + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, + GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%u", phy_device_id); + return FAILED); reporter_data.deviceId = phy_device_id; reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); @@ -1708,7 +1852,6 @@ Status DavinciModel::SinkModelProfile() { for (int32_t i = 0; i < task_num; i++) { auto task = task_list_[i]; auto fusion_op_info = task->GetFusionOpInfo(); - // when type is RT_MODEL_TASK_KERNEL, ctx is not null if (fusion_op_info != nullptr) { uint32_t op_num = fusion_op_info->original_op_names.size(); @@ -1827,15 +1970,9 @@ Status DavinciModel::SinkModelProfile() { } Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { - // not support non-sink model - GE_CHK_BOOL_EXEC(this->model_task_def_ != nullptr, return SUCCESS); - // profiling plugin must be registered Msprof::Engine::Reporter *reporter = PluginImpl::GetPluginReporter(); - if (reporter == nullptr) { - GELOGI("Profiling report is nullptr!"); - return SUCCESS; - } + GE_IF_BOOL_EXEC(reporter == nullptr, GELOGI("Profiling report is nullptr!"); return SUCCESS); Msprof::Engine::ReporterData reporter_data{}; // report model data tag name @@ -1850,15 +1987,19 @@ Status DavinciModel::SinkTimeProfile(const InputData ¤t_data) { // device id uint32_t phy_device_id = 0; rtError_t rt_ret = rtGetDevicePhyIdByIndex(device_id_, &phy_device_id); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%d", phy_device_id); - return FAILED; - } + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, + GELOGE(rt_ret, "runtime get phy_device_id failed, current phy_device_id:%u", phy_device_id); + return FAILED); reporter_data.deviceId = phy_device_id; // Model Header - string name = this->Name(); - int32_t name_len = name.size(); + string name; + if (!om_name_.empty()) { + name = om_name_; + } else { + name = name_; + } + size_t name_len = name.size(); reporter_data.data = (unsigned char *)&name_len; reporter_data.dataLen = sizeof(int32_t); GE_CHK_BOOL_EXEC(reporter->Report(&reporter_data) == SUCCESS, return FAILED, "Reporter data fail, model id:%u.", @@ -1936,79 +2077,62 @@ void DavinciModel::SetProfileTime(ModelProcStage stage, int64_t endTime) { /// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again -/// @param [in] sink_op Sink Op +/// @param [in] data_id: the index of output_data +/// @param [in/out] output_data: real user output_data +/// @param [in] kind: the kind of rtMemcpy /// @return Status result /// @author /// -Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { - Status ret = SUCCESS; +Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind) { if (output_op_list_.empty()) { - ret = SyncVarData(); - } else { - output_data.index = data_id; - output_data.model_id = model_id_; - GE_CHK_BOOL_RET_STATUS(output_data.blobs.size() == output_data_info_.size(), INTERNAL_ERROR, - "output buffer size[%zu] not equal output_size_list[%zu] size!", output_data.blobs.size(), - output_data_info_.size()); - - // index of data in output_data - uint32_t output_data_index = 0; - for (auto &op_desc : output_op_list_) { - ret = CopyOutputDataToUser(op_desc, output_data.blobs, output_data_index); - GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy output data to model ret failed, index:%u, model id:%u", - output_data.index, output_data.model_id); - } + Status ret = SyncVarData(); + DumpOpInputOutput(); + return ret; } - (void)DumpOpInputOutput(); // dump, not care result. - return ret; -} - -Status DavinciModel::CopyOutputDataToUser(OpDescPtr &op_desc, std::vector &blobs, uint32_t &data_index) { - Output model_output(op_desc, this); - - GE_CHK_BOOL_RET_STATUS(model_output.Init() == SUCCESS, PARAM_INVALID, "make shared model_output failed"); - - vector v_output_size; - vector v_output_data_addr; - model_output.GetOutputData(v_output_data_addr, v_output_size); - - // for all output tensor, copy output data from op to designated position - for (size_t i = 0; i < v_output_size.size(); ++i) { - GE_CHK_BOOL_RET_STATUS(data_index < blobs.size(), PARAM_INVALID, - "The blobs size:%zu, data_op size:%zu, curr output size:%zu", blobs.size(), - data_op_list_.size(), v_output_size.size()); + output_data.index = data_id; + output_data.model_id = model_id_; + if (output_data.blobs.size() != new_output_data_info_.size()) { + GELOGE(FAILED, "Output data buffer num=%zu not equal model data num=%zu", output_data.blobs.size(), + new_output_data_info_.size()); + return FAILED; + } - DataBuffer &data_buf = blobs[data_index]; - data_index++; + std::vector &blobs = output_data.blobs; + for (const auto &output : new_output_data_info_) { + if (output.first >= blobs.size()) { + GELOGE(FAILED, "Blobs not match: blobs=%zu, tensor=%zu, index=%u, size=%ld", blobs.size(), + new_input_data_info_.size(), output.first, output.second.GetDataInfo().at(0).first); + return FAILED; + } - uint32_t size = data_buf.length; - GE_CHK_BOOL_RET_STATUS(size <= v_output_size[i], PARAM_INVALID, - "Model output data size(%u) does not match required size(%u).", v_output_size[i], - data_buf.length); + DataBuffer &buffer = blobs[output.first]; + uint64_t mem_size = static_cast(output.second.GetDataSize()); + if ((buffer.length == 0) || (mem_size == 0)) { + GELOGI("Length of data is zero, No need copy. output tensor index=%u", output.first); + continue; + } + if (buffer.length < mem_size) { + GELOGE(FAILED, "Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); + return FAILED; + } else if (buffer.length > mem_size) { + GELOGW("Tensor data size=%lu, buffer size=%u", mem_size, buffer.length); + } - if (copy_only_addrs_.count(v_output_data_addr[i]) == 0) { - GELOGI("[ZCPY] This addr[%p] has already feed by zero copy.", v_output_data_addr[i]); + if ((kind == RT_MEMCPY_DEVICE_TO_DEVICE) && (copy_only_addrs_.count(output.second.GetBasicAddr()) == 0)) { continue; // Skip: Feed by zero copy. } - GELOGI( - "CopyOutputDataToUser memcpy graph_%u type[F] name[%s] output[%lu] dst[%p] src[%p] mem_size[%u] datasize[%u]", - runtime_param_.graph_id, op_desc->GetName().c_str(), i, data_buf.data, v_output_data_addr[i], data_buf.length, - v_output_size[i]); - GE_CHK_RT_RET(rtMemcpy(data_buf.data, size, v_output_data_addr[i], size, RT_MEMCPY_DEVICE_TO_DEVICE)); - } + uint64_t data_size = output.second.GetDataSize(); + uint64_t buffer_length = buffer.length; + void *buffer_addr = reinterpret_cast(reinterpret_cast(buffer.data)); - return SUCCESS; -} - -Status DavinciModel::SyncDataAndDump() { - Status ret = SUCCESS; - if (output_op_list_.empty()) { - ret = SyncVarData(); + GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] output[%u] memaddr[%p] mem_size[%lu] datasize[%u]", + runtime_param_.graph_id, output.first, output.second.GetBasicAddr(), data_size, buffer_length); + GE_CHK_RT_RET(rtMemcpy(buffer_addr, buffer_length, output.second.GetBasicAddr(), data_size, kind)); } - (void)DumpOpInputOutput(); // dump, not care result. - return ret; + DumpOpInputOutput(); + return SUCCESS; } Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, @@ -2042,13 +2166,13 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data GELOGE(GE_GRAPH_MALLOC_FAILED, "Malloc buffer failed."); return GE_GRAPH_MALLOC_FAILED; } - output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); + output_data->blobs.push_back({data_buf.get(), static_cast(out_buffer_size_vec[i]), false}); ge::OutputTensorInfo output; output.dims = shape_info_vec[i]; output.data = std::move(data_buf); output.length = out_buffer_size_vec[i]; outputs.emplace_back(std::move(output)); - GELOGI("Output index:%zu, data_length:%u.", i, output.length); + GELOGI("Output index:%zu, data_length:%lu.", i, output.length); } return SUCCESS; } @@ -2057,7 +2181,10 @@ Status DavinciModel::GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data /// @ingroup ge /// @brief send Output Op result to upper layer /// @already malloced in ModelLoad, no need to malloc again -/// @param [in] sink_op Sink Op +/// @param [in] data_id: the index of output_data +/// @param [in] rslt_flg: result flag +/// @param [in] seq_end_flag: sequence end flag +/// @param [out] output_data: real user output_data /// @return Status result /// @author /// @@ -2088,20 +2215,17 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b // copy output data from op to designated position for (auto &op_desc : output_op_list_) { - Output model_output(op_desc, this); - if (model_output.Init() != SUCCESS || GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { + if (GenOutputTensorInfo(op_desc, data_index, output_data, outputs) != SUCCESS) { return INTERNAL_ERROR; } + data_index += op_desc->GetInputsSize(); + } - Status ret = model_output.CopyResult(*output_data, data_index, data_index, false); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "CopyResult failed, op name: %s", op_desc->GetName().c_str()); - GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); - return INTERNAL_ERROR; - } + if (CopyOutputData(data_id, *output_data, RT_MEMCPY_DEVICE_TO_HOST) != SUCCESS) { + GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed"); + return INTERNAL_ERROR; } - GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); if (seq_end_flag) { GELOGW("End of sequence, model id: %u", model_id_); GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnCompute Done failed."); @@ -2114,6 +2238,7 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b /// /// @ingroup ge /// @brief return not output to upper layer for cloud case +/// @param [in] data_id /// @return Status result /// Status DavinciModel::ReturnNoOutput(uint32_t data_id) { @@ -2125,7 +2250,7 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { op_desc->GetName().c_str()); } - GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); + DumpOpInputOutput(); GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null!"); std::vector outputs; GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed."); @@ -2135,41 +2260,40 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { /// /// @ingroup ge /// @brief dump all op input and output information -/// @param [in] op_list model_id -/// @return Status result +/// @return void /// -Status DavinciModel::DumpOpInputOutput() { +void DavinciModel::DumpOpInputOutput() { + char *ge_dump_env = std::getenv("DUMP_OP"); + int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; + if (dump_op_switch == 0) { + GELOGI("need to set DUMP_OP for dump op input and output"); + return; + } + if (op_list_.empty()) { - GELOGW("op_list is empty."); - return FAILED; + GELOGW("op list is empty"); + return; } - char *ge_dump_env = getenv("DUMP_OP"); - int dump_op_switch = - (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; // 10 for decimal number - if (dump_op_switch != 0) { - int64_t cnt = 1; - for (auto it : op_list_) { - if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { - GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld.", maxDumpOpNum_); - return SUCCESS; - } - Status ret = DumpSingleOpInputOutput(it.second); - cnt++; - if (ret != SUCCESS) { - GELOGE(FAILED, "dump single op failed, model_id: %u", model_id_); - return FAILED; - } + + int64_t cnt = 1; + for (auto it : op_list_) { + if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { + GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld", maxDumpOpNum_); + return; + } + + cnt++; + if (DumpSingleOpInputOutput(it.second) != SUCCESS) { + GELOGW("dump single op failed, model_id: %u", model_id_); + return; } - } else { - GELOGW("need to set DUMP_OP for dump op input and output."); } - return SUCCESS; } /// /// @ingroup ge /// @brief dump single op input and output information -/// @param [in] dump_op model_id +/// @param [in] op_def: the op_desc which will be dump /// @return Status result /// Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { @@ -2185,7 +2309,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { } } const vector input_size_vec = ModelUtils::GetInputSize(op_def); - const vector input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def, false); + const vector input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def); vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); GELOGD("DumpSingleOp[%s], input size[%zu], input memory type size[%zu]", op_def->GetName().c_str(), @@ -2208,7 +2332,7 @@ Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { } const vector output_size_vec = ModelUtils::GetOutputSize(op_def); - const vector output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def, false); + const vector output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def); v_memory_type.clear(); has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); GELOGD("DumpSingleOp[%s], output size[%zu], output memory type size[%zu]", op_def->GetName().c_str(), @@ -2278,7 +2402,7 @@ void *DavinciModel::Run(DavinciModel *model) { ret != SUCCESS, (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); continue, "Copy input data to model failed."); // [No need to check value] - GE_TIMESTAMP_END(Model_SyncVarData, "Model Run SyncVarData"); + GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(Model_SyncVarData, "Model Run SyncVarData")); GELOGI("Copy input data, model id:%u", model_id); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_PRE_PROC_START)); @@ -2324,7 +2448,7 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue); GELOGI("rtModelExecute end"); - GE_TIMESTAMP_END(rtModelExecute, "GraphExcute::rtModelExecute"); + GE_IF_BOOL_EXEC(model->is_first_execute_, GE_TIMESTAMP_EVENT_END(rtModelExecute, "GraphExcute::rtModelExecute")); GE_TIMESTAMP_START(rtStreamSynchronize); GELOGI("rtStreamSynchronize start."); @@ -2339,7 +2463,8 @@ void *DavinciModel::Run(DavinciModel *model) { CsaInteract::GetInstance().StoreInternalErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue); GELOGI("rtStreamSynchronize end."); - GE_TIMESTAMP_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize"); + GE_IF_BOOL_EXEC(model->is_first_execute_, + GE_TIMESTAMP_EVENT_END(rtStreamSynchronize, "GraphExcute::Wait for rtStreamSynchronize")); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_INFER_END)); } @@ -2350,11 +2475,13 @@ void *DavinciModel::Run(DavinciModel *model) { (void)model->ReturnResult(current_data.index, rslt_flg, false, data_wrapper->GetOutput())) // copy output data from device to host for variable graph GE_IF_BOOL_EXEC(model->output_op_list_.empty(), (void)model->ReturnNoOutput(current_data.index)); - GE_TIMESTAMP_END(ReturnResult3, "GraphExcute::CopyDataFromDeviceToHost"); + GE_IF_BOOL_EXEC(model->is_first_execute_, + GE_TIMESTAMP_EVENT_END(ReturnResult3, "GraphExcute::CopyDataFromDeviceToHost")); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), model->SetProfileTime(MODEL_AFTER_PROC_END)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), (void)model->SinkTimeProfile(current_data)); model->iterator_count_++; + model->is_first_execute_ = false; GELOGI("run iterator count is %lu", model->iterator_count_); } @@ -2407,7 +2534,7 @@ Status DavinciModel::ModelRunStart() { is_inner_model_stream_ = true; string opt = "0"; - (void)ge::GetContext().GetOption("ge.maxDumpOpNum", opt); // option may not be set up, no need to check value + (void)ge::GetContext().GetOption(OPTION_GE_MAX_DUMP_OP_NUM, opt); // option may not be set up, no need to check value int64_t maxDumpOpNum = std::strtol(opt.c_str(), nullptr, kDecimal); maxDumpOpNum_ = maxDumpOpNum; @@ -2450,27 +2577,41 @@ void DavinciModel::UnbindTaskSinkStream() { // destroy stream that is bound with rt_model GE_LOGW_IF(rtStreamDestroy(rt_model_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed.") } - return; + + if (is_pure_head_stream_ && rt_head_stream_ != nullptr) { + GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, rt_head_stream_) != RT_ERROR_NONE, "Unbind stream failed!"); + GE_LOGW_IF(rtStreamDestroy(rt_head_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed."); + rt_head_stream_ = nullptr; + } + + if (rt_entry_stream_ != nullptr) { + GE_LOGW_IF(rtModelUnbindStream(rt_model_handle_, rt_entry_stream_) != RT_ERROR_NONE, "Unbind stream failed!"); + GE_LOGW_IF(rtStreamDestroy(rt_entry_stream_) != RT_ERROR_NONE, "Destroy stream for rt_model failed."); + rt_entry_stream_ = nullptr; + } } Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs) { GELOGI("DavinciModel::CreateKnownZeroCopyMap in."); - if (inputs.size() != data_op_list_.size()) { - GELOGE(FAILED, "input data addr %u is not equal to input op number %u.", inputs.size(), data_op_list_.size()); + if (inputs.size() > data_op_list_.size()) { + GELOGE(FAILED, "input data addr %u should less than input op number %u.", inputs.size(), data_op_list_.size()); return FAILED; } - for (size_t i = 0; i < data_op_list_.size(); ++i) { + // remove zero copy addr in last iteration + knonw_input_data_info_.clear(); + knonw_output_data_info_.clear(); + for (size_t i = 0; i < inputs.size(); ++i) { const vector addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, data_op_list_[i]); knonw_input_data_info_[addr_list[kDataIndex]] = inputs[i]; GELOGI("DavinciModel::CreateKnownZeroCopyMap input %d,v addr %p,p addr %p .", i, addr_list[kDataIndex], inputs[i]); } - if (output_op_list_.size() != kOutputNum) { - GELOGE(FAILED, "output op num is %u, not equal %u.", outputs.size(), kOutputNum); - return FAILED; + if (output_op_list_.size() < kOutputNum) { + GELOGW("output op num in graph is %u.", output_op_list_.size()); + return SUCCESS; } const vector addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, output_op_list_[kDataIndex]); - if (outputs.size() != addr_list.size()) { - GELOGE(FAILED, "output data addr %u is not equal to output op number %u.", outputs.size(), addr_list.size()); + if (outputs.size() > addr_list.size()) { + GELOGE(FAILED, "output data addr %u should less than output op number %u.", outputs.size(), addr_list.size()); return FAILED; } for (size_t i = 0; i < addr_list.size(); ++i) { @@ -2481,30 +2622,20 @@ Status DavinciModel::CreateKnownZeroCopyMap(const vector &inputs, const return SUCCESS; } -Status DavinciModel::UpdateKnownZeroCopyAddr(vector &io_addrs, uint32_t args_offset) { - for (size_t i = 0; i < io_addrs.size(); ++i) { - auto it_in = knonw_input_data_info_.find(io_addrs[i]); +Status DavinciModel::UpdateKnownZeroCopyAddr() { + for (size_t i = 0; i < total_io_addrs_.size(); ++i) { + auto it_in = knonw_input_data_info_.find(total_io_addrs_[i]); if (it_in != knonw_input_data_info_.end()) { - GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %d,v addr %p,p addr %p .", i, io_addrs[i], - knonw_input_data_info_.at(io_addrs[i])); - io_addrs[i] = knonw_input_data_info_.at(io_addrs[i]); + GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %d,v addr %p,p addr %p .", i, total_io_addrs_[i], + knonw_input_data_info_.at(total_io_addrs_[i])); + total_io_addrs_[i] = knonw_input_data_info_.at(total_io_addrs_[i]); } - auto it_out = knonw_output_data_info_.find(io_addrs[i]); + auto it_out = knonw_output_data_info_.find(total_io_addrs_[i]); if (it_out != knonw_output_data_info_.end()) { - GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %d,v addr %p,p addr %p .", i, io_addrs[i], - knonw_output_data_info_.at(io_addrs[i])); - io_addrs[i] = knonw_output_data_info_.at(io_addrs[i]); - } - } - // may args_size is equal to src_args_size? - uint32_t src_args_size = io_addrs.size() * sizeof(uint64_t); - GELOGI("DavinciModel::UpdateKnownZeroCopyAddr args host %p, src_args_size %u, args_offset %u", args_host_, - src_args_size, args_offset); - errno_t sec_ret = - memcpy_s(static_cast(args_host_) + args_offset, src_args_size, io_addrs.data(), src_args_size); - if (sec_ret != EOK) { - GELOGE(FAILED, "Call memcpy_s failed, ret: %d", sec_ret); - return FAILED; + GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %d,v addr %p,p addr %p .", i, total_io_addrs_[i], + knonw_output_data_info_.at(total_io_addrs_[i])); + total_io_addrs_[i] = knonw_output_data_info_.at(total_io_addrs_[i]); + } } GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success."); return SUCCESS; @@ -2514,20 +2645,31 @@ Status DavinciModel::UpdateKnownNodeArgs(const vector &inputs, const vec GELOGI("DavinciModel::UpdateKnownNodeArgs in"); GE_CHK_STATUS_RET(CreateKnownZeroCopyMap(inputs, outputs), "DavinciModel::UpdateKnownNodeArgs create map for input/output zero copy."); - for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { - auto &task = task_list_[task_index]; - if (task != nullptr) { - Status ret = task->UpdateArgs(); - if (ret != SUCCESS) { - GELOGE(FAILED, "task %d created by davinci model is nullptr.", task_index); - return FAILED; + if (!base_addr_not_changed_) { + total_io_addrs_.clear(); + orig_total_io_addrs_.clear(); + for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { + auto &task = task_list_[task_index]; + if (task != nullptr) { + Status ret = task->UpdateArgs(); + if (ret != SUCCESS) { + GELOGE(FAILED, "task %d created by davinci model is nullptr.", task_index); + return FAILED; + } } } + // cache latest iterator io addr + orig_total_io_addrs_ = total_io_addrs_; + } else { + total_io_addrs_ = orig_total_io_addrs_; } - GELOGI("DavinciModel::UpdateKnownNodeArgs device args %p, size %u, host args %p, size %u", args_, total_args_size_, - args_host_, total_args_size_); - // copy continuous args from host to device - Status rt_ret = rtMemcpy(args_, total_args_size_, args_host_, total_args_size_, RT_MEMCPY_HOST_TO_DEVICE); + GE_CHK_STATUS_RET(UpdateKnownZeroCopyAddr(), "DavinciModel::UpdateKnownZeroCopyAddr failed."); + + uint32_t total_addr_size = total_io_addrs_.size() * sizeof(uint64_t); + GELOGI("DavinciModel::UpdateKnownNodeArgs device args %p, dst size %u, src size %u", args_, total_args_size_, + total_addr_size); + + Status rt_ret = rtMemcpy(args_, total_args_size_, total_io_addrs_.data(), total_addr_size, RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) GELOGI("DavinciModel::UpdateKnownNodeArgs success"); @@ -2535,12 +2677,14 @@ Status DavinciModel::UpdateKnownNodeArgs(const vector &inputs, const vec } Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { - GELOGI("InitTaskInfo in,task size %d", model_task_def.task().size()); + GELOGI("InitTaskInfo in, task size %zu", model_task_def.task().size()); task_list_.resize(model_task_def.task_size()); for (int i = 0; i < model_task_def.task_size(); ++i) { // dynamic shape will create task_list_ before const domi::TaskDef &task = model_task_def.task(i); - task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(task.type())); + if (this->task_list_[i] == nullptr) { + task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(task.type())); + } GE_CHECK_NOTNULL(task_list_[i]); Status ret = task_list_[i]->Init(task, this); if (ret != SUCCESS) { @@ -2554,13 +2698,14 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { Status DavinciModel::MallocKnownArgs() { GELOGI("DavinciModel::MallocKnownArgs in"); - if (model_task_def_->task_size() == 0) { + const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); + if (model_task_def->task_size() == 0) { GELOGW("DavinciModel::MallocKnownArgs davincimodel has no task info."); return SUCCESS; } - task_list_.resize(model_task_def_->task_size()); - for (int32_t i = 0; i < model_task_def_->task_size(); ++i) { - const domi::TaskDef &taskdef = model_task_def_->task(i); + task_list_.resize(model_task_def->task_size()); + for (int32_t i = 0; i < model_task_def->task_size(); ++i) { + const domi::TaskDef &taskdef = model_task_def->task(i); task_list_[i] = TaskInfoFactory::Instance().Create(static_cast(taskdef.type())); GE_CHECK_NOTNULL(task_list_[i]); Status ret = task_list_[i]->CalculateArgs(taskdef, this); @@ -2573,15 +2718,21 @@ Status DavinciModel::MallocKnownArgs() { rtError_t rt_ret = rtMalloc(&args_, total_args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } - // malloc args host memory - rt_ret = rtMallocHost(&args_host_, total_args_size_); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rtMallocHost failed, ret: 0x%X", rt_ret); - return RT_FAILED; + + // malloc fixed addr memory, eg: rts op + if (total_fixed_addr_size_ != 0) { + GELOGI("Begin to allocate fixed addr."); + rt_ret = rtMalloc(&fixed_addrs_, total_fixed_addr_size_, RT_MEMORY_HBM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); + } } - GELOGI("DavinciModel::MallocKnownArgs success, total args size %u.", total_args_size_); + + GELOGI("DavinciModel::MallocKnownArgs success, total args size %u. total fixed addr size %ld", total_args_size_, + total_fixed_addr_size_); return SUCCESS; } @@ -2597,26 +2748,28 @@ Status DavinciModel::DistributeTask() { task_desc_info_.clear(); bool flag = GetL1FusionEnableOption(); - char *skt_enable_env = getenv("SKT_ENABLE"); - int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; + char *skt_enable_env = std::getenv("SKT_ENABLE"); + int64_t env_flag = (skt_enable_env != nullptr) ? std::strtol(skt_enable_env, nullptr, kDecimal) : 0; if (env_flag != 0) { flag = true; } + const auto &model_task_def = ge_model_->GetModelTaskDefPtr(); for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { auto &task = task_list_.at(task_index); GE_CHK_STATUS_RET(task->Distribute(), "Task[%zu] distribute fail", task_index); // for data dump if (reinterpret_cast(task->GetDumpArgs()) != nullptr) { - auto op_index = std::max(model_task_def_->task(task_index).kernel().context().op_index(), - model_task_def_->task(task_index).kernel_ex().op_index()); + auto op_index = std::max(model_task_def->task(task_index).kernel().context().op_index(), + model_task_def->task(task_index).kernel_ex().op_index()); OpDescPtr op = GetOpByIndex(op_index); if (op == nullptr) { GELOGE(PARAM_INVALID, "Op index %u is null, op list size %zu.", op_index, op_list_.size()); return PARAM_INVALID; } - if (PropertiesManager::Instance().IsLayerNeedDump(name_, om_name_, op->GetName())) { + bool call_dump = GetDumpProperties().IsLayerNeedDump(name_, om_name_, op->GetName()) && task->CallSaveDumpInfo(); + if (call_dump) { SaveDumpTask(task->GetTaskID(), task->GetStreamId(), op, task->GetDumpArgs()); } } @@ -2631,8 +2784,13 @@ Status DavinciModel::DistributeTask() { // else task index is found in op_name_map_ TaskDescInfo task_desc_info; string op_name = op_name_map_[task_index]; + if (!om_name_.empty()) { + task_desc_info.model_name = om_name_; + } else { + task_desc_info.model_name = name_; + } task_desc_info.op_name = op_name; - task_desc_info.block_dim = model_task_def_->task(task_index).kernel().block_dim(); + task_desc_info.block_dim = model_task_def->task(task_index).kernel().block_dim(); task_desc_info.task_id = task->GetTaskID(); task_desc_info.stream_id = task->GetStreamId(); task_desc_info_.emplace_back(task_desc_info); @@ -2653,7 +2811,7 @@ Status DavinciModel::DistributeTask() { } void DavinciModel::SetEndGraphId(uint32_t task_id, uint32_t stream_id) { - auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); + auto all_dump_model = GetDumpProperties().GetAllDumpModel(); bool findByOmName = all_dump_model.find(om_name_) != all_dump_model.end(); bool findByModelName = all_dump_model.find(name_) != all_dump_model.end(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || findByOmName || findByModelName) { @@ -2664,35 +2822,23 @@ void DavinciModel::SetEndGraphId(uint32_t task_id, uint32_t stream_id) { /// /// @ingroup ge -/// @brief Save Data address info for ZeroCopy. -/// @param [in] const std::vector &outside_addrs +/// @brief Set copy only for No task feed NetOutput address. /// @return None. /// -void DavinciModel::SetInputOutsideAddr(const std::vector &outside_addrs) { - for (auto addr : outside_addrs) { - if (input_outside_addrs_.find(addr) != input_outside_addrs_.end()) { - continue; - } - - (void)input_outside_addrs_.emplace(std::pair>(addr, {})); - GELOGI("SetInputOutsideAddr success."); - } -} - -/// -/// @ingroup ge -/// @brief Save NetOutput address info for ZeroCopy. -/// @param [in] const std::vector &outside_addrs -/// @return None. -/// -void DavinciModel::SetOutputOutsideAddr(const std::vector &outside_addrs) { - for (auto addr : outside_addrs) { - if (output_outside_addrs_.find(addr) != output_outside_addrs_.end()) { - continue; +void DavinciModel::SetCopyOnlyOutput() { + for (const auto &output_outside_addrs : new_output_outside_addrs_) { + ZeroCopyOffset output_outside = output_outside_addrs.second; + for (uint32_t out_count = 0; out_count < output_outside.GetAddrCount(); ++out_count) { + auto &addrs_mapping_list = output_outside.GetOutsideAddrs(); + std::map> virtual_args_addrs = addrs_mapping_list[out_count]; + for (const auto &virtual_args_addr : virtual_args_addrs) { + const auto &args_addrs = virtual_args_addr.second; + if (args_addrs.empty()) { // No task feed Output addr, Need copy directly. + GELOGI("[ZCPY] just copy %p to netoutput.", virtual_args_addr.first); + copy_only_addrs_.insert(virtual_args_addr.first); + } + } } - DisableZeroCopy(addr); // Data to NetOutput directly. - (void)output_outside_addrs_.emplace(std::pair>(addr, {})); - GELOGI("SetOutputOutsideAddr success."); } } @@ -2703,13 +2849,13 @@ void DavinciModel::SetOutputOutsideAddr(const std::vector &outside_addrs /// @return None. /// void DavinciModel::DisableZeroCopy(const void *addr) { - auto it = input_outside_addrs_.find(addr); - if (it == input_outside_addrs_.end()) { + if (find(real_virtual_addrs_.begin(), real_virtual_addrs_.end(), addr) == real_virtual_addrs_.end()) { return; } // Data link to RTS Op directly. std::lock_guard lock(outside_addrs_mutex_); + GELOGI("[ZCPY] disable zero copy of %p.", addr); copy_only_addrs_.insert(addr); } @@ -2718,35 +2864,37 @@ void DavinciModel::DisableZeroCopy(const void *addr) { /// @brief Save outside address used info for ZeroCopy. /// @param [in] const OpDescPtr &op_desc: current op desc /// @param [in] const std::vector &outside_addrs: address of task -/// @param [in] const char *args_offset: arguments address save the address. +/// @param [in] const void *info: task args +/// @param [in] const char *args: task args +/// @param [in] size_t size: size of task args +/// @param [in] size_t offset: offset of task args /// @return None. /// void DavinciModel::SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, void *args, size_t size, size_t offset) { // Internal call has ensured that op_desc is not nullptr + GELOGI("[ZCPY] SetZeroCopyAddr for %s.", op_desc->GetName().c_str()); size_t nums = outside_addrs.size(); ZeroCopyTask zero_copy_task(op_desc->GetName(), static_cast(args), size); for (size_t i = 0; i < nums; ++i) { std::lock_guard lock(outside_addrs_mutex_); - const uintptr_t addr_val = reinterpret_cast(outside_addrs[i]); - void *args_val = static_cast(args) + offset + i * kAddrLen; - auto it = input_outside_addrs_.find(outside_addrs[i]); - if (it != input_outside_addrs_.end()) { - GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset + i * kAddrLen), "Input args invalid."); - it->second.push_back(args_val); - SetBatchLabelAddr(op_desc, reinterpret_cast(args_val)); - GELOGI("[ZCPY] %s set copy input: %zu, addr: 0x%lx, args: %p, size: %zu, offset: %zu.", - op_desc->GetName().c_str(), i, addr_val, args, size, offset + i * kAddrLen); - continue; + + for (auto &input_outside_addrs : new_input_outside_addrs_) { + ZeroCopyOffset &input_outside = input_outside_addrs.second; + bool ret = input_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen); + if (ret) { + void *args_val = static_cast(args) + offset + i * kAddrLen; + SetBatchLabelAddr(op_desc, reinterpret_cast(args_val)); + } } - it = output_outside_addrs_.find(outside_addrs[i]); - if (it != output_outside_addrs_.end()) { - GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset + i * kAddrLen), "Output args invalid."); - it->second.push_back(args_val); - SetBatchLabelAddr(op_desc, reinterpret_cast(args_val)); - GELOGI("[ZCPY] %s set copy output: %zu, args: %p, addr: 0x%lx.", op_desc->GetName().c_str(), i, args, addr_val); - continue; + for (auto &output_outside_addrs : new_output_outside_addrs_) { + ZeroCopyOffset &output_outside = output_outside_addrs.second; + bool ret = output_outside.SetOutsideAddrsValue(zero_copy_task, outside_addrs[i], args, offset + i * kAddrLen); + if (ret) { + void *args_val = static_cast(args) + offset + i * kAddrLen; + SetBatchLabelAddr(op_desc, reinterpret_cast(args_val)); + } } } @@ -2794,7 +2942,7 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 if (input_size > op_size) { GELOGW( - "Input size [%u] is bigger than om size need [%u]," + "Input size [%u] is bigger than om size need [%u], " "MAY cause inference result ERROR, please check model input", input_size, op_size); } @@ -2834,12 +2982,13 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 /// @return SUCCESS handle successfully / PARAM_INVALID for failed /// Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &output_data, bool is_dynamic) { - if (UpdateIoTaskArgs(input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { + if (UpdateIoTaskArgs(new_input_data_info_, true, input_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { GELOGE(PARAM_INVALID, "[ZCPY] Update input data to model failed."); return PARAM_INVALID; } - if (UpdateIoTaskArgs(output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != SUCCESS) { + if (UpdateIoTaskArgs(new_output_data_info_, false, output_data.blobs, is_dynamic, input_data.batch_label) != + SUCCESS) { GELOGE(PARAM_INVALID, "[ZCPY] Update output data to model failed."); return PARAM_INVALID; } @@ -2863,7 +3012,7 @@ Status DavinciModel::CopyModelData(const InputData &input_data, OutputData &outp /// @param [in] batch_label: batch label for multi-batch scenes /// @return SUCCESS handle successfully / others handle failed /// -Status DavinciModel::UpdateIoTaskArgs(const map> &data_info, bool is_input, +Status DavinciModel::UpdateIoTaskArgs(const std::map &data_info, bool is_input, const vector &blobs, bool is_dynamic, const string &batch_label) { string input_or_output = "input"; is_input ? input_or_output = "input" : input_or_output = "output"; @@ -2879,8 +3028,6 @@ Status DavinciModel::UpdateIoTaskArgs(const map> input_or_output.c_str(), data.first, blobs.size()); return FAILED; } - int64_t size = data.second.first; // size of tensor. - void *addr = data.second.second; // addr of tensor. const DataBuffer &buffer = blobs[data.first]; // index of data. if (buffer.data == nullptr) { @@ -2888,30 +3035,38 @@ Status DavinciModel::UpdateIoTaskArgs(const map> return FAILED; } - GELOGI("[ZCPY] Copy Blobs: %u, addr: %p, size: %ld, data: %p, length: %u.", data.first, data.second.second, - data.second.first, buffer.data, buffer.length); - if (!CheckInputAndModelSize(buffer.length, size, is_dynamic)) { + if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { GELOGE(FAILED, "Check input size and model size failed"); return FAILED; } - // For input data, just copy for rts task. - if (copy_only_addrs_.count(addr) > 0) { + void *basic_addr = data.second.GetBasicAddr(); + uint64_t data_size = data.second.GetDataSize(); + if (copy_only_addrs_.count(basic_addr) > 0) { if (is_input) { - GELOGI("[IMAS] Find addr %p need direct copy from user malloc input %p.", addr, buffer.data); - if (rtMemcpy(addr, size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { + GELOGI("[IMAS] Find addr %p need direct copy from user malloc input %p", basic_addr, buffer.data); + if (rtMemcpy(basic_addr, data_size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { GELOGE(FAILED, "Non-zero copy data node copy failed"); return FAILED; } } - GELOGI("No need to exeucte zero copy task because this addr %p need direct copy.", addr); + GELOGI("No need to exeucte zero copy task because this addr %p need direct copy.", basic_addr); continue; } - for (ZeroCopyTask &task : zero_copy_tasks_) { - uintptr_t addr_val = reinterpret_cast(addr); - if (task.UpdateTaskParam(addr_val, buffer, zero_copy_batch_label_addrs_, batch_label) != SUCCESS) { - return FAILED; + for (size_t count = 0; count < data.second.GetDataCount(); ++count) { + int64_t size = data.second.GetDataInfo().at(count).first; + void *addr = data.second.GetDataInfo().at(count).second; + void *buffer_addr = + reinterpret_cast(reinterpret_cast(buffer.data) + data.second.GetRelativeOffset().at(count)); + GELOGI("[ZCPY] Copy blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p", data.first, addr, size, + buffer_addr); + // For input data, just copy for rts task. + for (ZeroCopyTask &task : zero_copy_tasks_) { + uintptr_t addr_val = reinterpret_cast(addr); + if (task.UpdateTaskParam(addr_val, buffer_addr, zero_copy_batch_label_addrs_, batch_label) != SUCCESS) { + return FAILED; + } } } } @@ -3160,6 +3315,32 @@ Status DavinciModel::InitStreamSwitchN(const OpDescPtr &op_desc) { GELOGI("StreamSwitchNOp node:%s, active_stream_id=%u.", op_desc->GetName().c_str(), active_stream_list[j]); } + (void)AttrUtils::GetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_); + + batch_info_.clear(); + combined_batch_info_.clear(); + uint32_t batch_num = 0; + if (!AttrUtils::GetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { + GELOGE(FAILED, "Failed to get attr ATTR_NAME_BATCH_NUM, StreamSwitchN: %s.", op_desc->GetName().c_str()); + return FAILED; + } + + for (uint32_t i = 0; i < batch_num; i++) { + std::vector batch_shape; + const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + if (!AttrUtils::GetListInt(op_desc, attr_name, batch_shape)) { + GELOGE(FAILED, "Failed to get attr ATTR_NAME_PRED_VALUE, StreamSwitchN: %s.", op_desc->GetName().c_str()); + batch_info_.clear(); + return FAILED; + } + batch_info_.emplace_back(batch_shape); + batch_shape.clear(); + const string attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); + if (AttrUtils::GetListInt(op_desc, attr_combined_batch, batch_shape)) { + combined_batch_info_.emplace_back(batch_shape); + } + } + return SUCCESS; } @@ -3178,20 +3359,6 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } -void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { - if (!is_dynamic_batch) { - zero_copy_batch_label_addrs_.clear(); - } - - for (const auto &addrs : output_outside_addrs_) { - const auto &used_list = addrs.second; - if (used_list.empty()) { - output_zero_copy = false; - break; - } - } -} - /// /// @ingroup ge /// @brief Init model stream for NN model. @@ -3239,14 +3406,14 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); - bool input_use_zero_copy = true; - bool output_use_zero_copy = true; - bool is_dynamic_batch = input_data.is_dynamic_batch; - InitZeroCopyUtil(is_dynamic_batch, input_use_zero_copy, output_use_zero_copy); + if (!input_data.is_dynamic_batch) { + zero_copy_batch_label_addrs_.clear(); + } GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_PRE_PROC_START)); - Status ret = CopyModelData(input_data, output_data, is_dynamic_batch); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy input data to model failed."); + Status ret = CopyModelData(input_data, output_data, input_data.is_dynamic_batch); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy input data to model failed. model id: %u", + model_id_); GELOGI("current_data.index=%u", input_data.index); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_PRE_PROC_END)); @@ -3255,15 +3422,15 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa GELOGD("rtModelExecute do"); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_INFER_START)); rtError_t rt_ret = rtModelExecute(rt_model_handle_, rt_model_stream_, 0); - GE_CHK_RT_EXEC(rt_ret, return INTERNAL_ERROR); + GE_CHK_RT_EXEC(rt_ret, return RT_ERROR_TO_GE_STATUS(rt_ret)); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_INFER_END)); GELOGI("rtModelExecute end"); } if (!is_async_mode_) { GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_START)); - ret = CopyOutputData(input_data.index, output_data); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return INTERNAL_ERROR, "Copy Output data to user failed."); + ret = CopyOutputData(input_data.index, output_data, RT_MEMCPY_DEVICE_TO_DEVICE); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Copy Output data to user failed."); GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingOn(), SetProfileTime(MODEL_AFTER_PROC_END)); } @@ -3273,11 +3440,61 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa return SUCCESS; } +// Add active entry stream for special env. +Status DavinciModel::AddHeadStream() { + if (active_stream_list_.empty()) { + GELOGE(INTERNAL_ERROR, "Active stream is empty, stream list size: %zu, stream indication size: %zu.", + stream_list_.size(), active_stream_indication_.size()); + return INTERNAL_ERROR; + } + + if (active_stream_list_.size() == 1) { + GELOGI("Just one active stream, take as head stream."); + rt_head_stream_ = active_stream_list_[0]; + is_pure_head_stream_ = false; + } else { + // Create stream which rt_model_handel running on, this is S0, TS stream. + GELOGI("Multiple active stream: %zu, create head stream.", active_stream_list_.size()); + GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_head_stream_, priority_, RT_STREAM_PERSISTENT)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_head_stream_, RT_INVALID_FLAG)); // Not active. + is_pure_head_stream_ = true; + + for (auto s : active_stream_list_) { + std::shared_ptr active_entry = MakeShared(rt_head_stream_); + if (active_entry == nullptr) { + GELOGE(MEMALLOC_FAILED, "Make CpuTaskActiveEntry task failed."); + return MEMALLOC_FAILED; + } + + Status status = active_entry->Init(s); + if (status != SUCCESS) { + return status; + } + + cpu_task_list_.emplace_back(active_entry); + } + } + + // Create entry stream active head stream. AICPU stream. + GE_CHK_RT_RET(rtStreamCreateWithFlags(&rt_entry_stream_, priority_, RT_STREAM_AICPU)); + GE_CHK_RT_RET(rtModelBindStream(rt_model_handle_, rt_entry_stream_, RT_HEAD_STREAM)); + return SUCCESS; +} + +Status DavinciModel::InitEntryTask() { + if (deploy_type_ == AICPU_DEPLOY_CROSS_THREAD) { + GE_CHK_STATUS_RET(AddHeadStream(), "Add head stream failed."); + return CpuActiveStream(); + } else { + return LoadWithQueue(); + } +} + uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { uint8_t *mem_base = nullptr; const string purpose("feature map,used for op input and output."); if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { - data_size = static_cast(VarManager::Instance(0)->GetGraphMemoryMaxSize()); + data_size = static_cast(VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()); string memory_key = std::to_string(0) + "_f"; mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); } else { @@ -3341,7 +3558,7 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) rtError_t rt_ret = rtCtxGetCurrent(&ctx); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Failed to get current context, error_code is: 0x%X.", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } std::vector variable_node_list; @@ -3362,12 +3579,14 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) return SUCCESS; } -void DavinciModel::SetDataDumperArgs() { +void DavinciModel::SetDataDumperArgs(const ComputeGraphPtr &compute_graph) { GELOGI("set data dumper args, name: %s, id: %u.", name_.c_str(), model_id_); data_dumper_.SetModelName(name_); data_dumper_.SetModelId(model_id_); data_dumper_.SetMemory(runtime_param_); data_dumper_.SetOmName(om_name_); + data_dumper_.SetComputeGraph(compute_graph); + data_dumper_.SetRefInfo(saved_task_addrs_); int32_t device_id = 0; rtError_t rt_ret = rtGetDevice(&device_id); @@ -3423,18 +3642,9 @@ void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { } } -Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { - return TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_); -} - -Status DavinciModel::GetComputeGraphInfo(std::vector &compute_graph_desc_info) { +Status DavinciModel::GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info) { GELOGI("GetComputeGraphInfo start."); - if (compute_graph_ == nullptr) { - GELOGE(FAILED, "compute_graph_ is nullptr"); - return FAILED; - } - - for (auto &node : compute_graph_->GetAllNodes()) { + for (auto &node : graph->GetAllNodes()) { ComputeGraphDescInfo compute_graph_info; auto op_desc = node->GetOpDesc(); if (op_desc == nullptr) { @@ -3445,6 +3655,11 @@ Status DavinciModel::GetComputeGraphInfo(std::vector &comp auto op_mode = static_cast(domi::ImplyType::INVALID); if (AttrUtils::GetInt(op_desc, ATTR_NAME_IMPLY_TYPE, op_mode) && op_mode == static_cast(domi::ImplyType::TVM)) { + if (!om_name_.empty()) { + compute_graph_info.model_name = om_name_; + } else { + compute_graph_info.model_name = name_; + } compute_graph_info.op_name = op_desc->GetName(); compute_graph_info.op_type = op_desc->GetType(); @@ -3462,12 +3677,18 @@ Status DavinciModel::GetComputeGraphInfo(std::vector &comp compute_graph_info.output_data_type.emplace_back(output_desc.GetDataType()); } - compute_graph_desc_info.emplace_back(compute_graph_info); + graph_desc_info.emplace_back(compute_graph_info); } } GELOGI("GetComputeGraphInfo end."); return SUCCESS; } +void DavinciModel::SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size) { + if (tensor_name_to_fixed_addr_size_.find(tensor_name) == tensor_name_to_fixed_addr_size_.end()) { + tensor_name_to_fixed_addr_size_[tensor_name] = total_fixed_addr_size_; + total_fixed_addr_size_ += fix_addr_size; + } +} Status DavinciModel::GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info) { GE_CHK_BOOL_RET_STATUS(index < data_op_list_.size(), PARAM_INVALID, "Index %u is invalid.", index); @@ -3534,7 +3755,8 @@ Status DavinciModel::GetAllAippInputOutputDims(uint32_t index, std::vectorGetInputDescPtr(kDataIndex)), data_input_size); GELOGD( - "GetAllAippInputOutputDims related Data[%d]: tensor_name is %s, dim_num is %u, tensor_size: %zu, format: %s, " + "GetAllAippInputOutputDims related Data[%d]: tensor_name is %s, dim_num is %u, tensor_size: %zu, format: " + "%s, " "data_type: %s, shape: %s .", index, data_op->GetName().c_str(), data_input_desc->GetShape().GetDimNum(), data_input_size, TypeUtils::FormatToSerialString(data_input_desc->GetFormat()).c_str(), @@ -3556,4 +3778,23 @@ Status DavinciModel::GetAllAippInputOutputDims(uint32_t index, std::vectorHasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR) && op_desc->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX)) { + string tensor_name; + (void)AttrUtils::GetStr(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); + int64_t index = -1; + (void)AttrUtils::GetInt(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, index); + if (index >= 0) { + tensor_name_to_peer_output_index_[tensor_name] = index; + } + } +} } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/davinci_model.h b/src/ge/graph/load/new_model_manager/davinci_model.h index 8123b0b8..cb7e4528 100644 --- a/src/ge/graph/load/new_model_manager/davinci_model.h +++ b/src/ge/graph/load/new_model_manager/davinci_model.h @@ -28,13 +28,15 @@ #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" #include "common/opskernel/ge_task_info.h" +#include "common/properties_manager.h" #include "common/types.h" #include "framework/common/util.h" #include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/aipp_utils.h" #include "graph/load/new_model_manager/data_dumper.h" #include "graph/load/new_model_manager/data_inputer.h" #include "graph/load/new_model_manager/model_utils.h" -#include "graph/load/new_model_manager/aipp_utils.h" +#include "graph/load/new_model_manager/zero_copy_offset.h" #include "graph/load/new_model_manager/zero_copy_task.h" #include "graph/model.h" #include "graph/node.h" @@ -47,6 +49,10 @@ #include "task_info/task_info.h" namespace ge { +// op debug need 2048 bits buffer +const size_t kOpDebugMemorySize = 2048UL; +const size_t kDebugP2pSize = 8UL; + typedef enum tagModelProcStage { MODEL_LOAD_START = 1, MODEL_LOAD_END, @@ -171,13 +177,6 @@ class DavinciModel { // get session id uint64_t SessionId() const { return runtime_param_.session_id; } - vector GetOpDesc() { - vector opDescVector; - GE_IF_BOOL_EXEC(AttrUtils::GetListOpDesc(GetGeModel(), MODEL_ATTR_FUSION_MODEL_DEF, opDescVector), - GELOGI("get opDesc of opDescVector")); - return opDescVector; - } - // get model priority int32_t Priority() const { return priority_; } @@ -248,15 +247,9 @@ class DavinciModel { /// Format GetFormat(); - rtModel_t GetRtModelHandle() { - rtModel_t res = rt_model_handle_; - return res; - } + rtModel_t GetRtModelHandle() const { return rt_model_handle_; } - rtStream_t GetRtModelStream() { - rtModel_t res = rt_model_stream_; - return res; - } + rtStream_t GetRtModelStream() const { return rt_model_stream_; } uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; } @@ -293,11 +286,20 @@ class DavinciModel { /// @ingroup ge /// @brief Get dynamic batch_info /// @param [out] batch_info + /// @param [out] dynamic_type /// @return execute result /// - Status GetDynamicBatchInfo(std::vector> &batch_info); + Status GetDynamicBatchInfo(std::vector> &batch_info, int32_t &dynamic_type) const; + + /// + /// @ingroup ge + /// @brief Get combined dynamic dims info + /// @param [out] batch_info + /// @return None + /// + void GetCombinedDynamicDims(std::vector> &batch_info) const; - void GetCurShape(std::vector &batch_info); + void GetCurShape(std::vector &batch_info, int32_t &dynamic_type); void GetModelAttr(std::vector &dynamic_output_shape_info); @@ -344,10 +346,9 @@ class DavinciModel { /// /// @ingroup ge /// @brief dump all op input and output information - /// @param [in] op_list model_id - /// @return Status + /// @return void /// - Status DumpOpInputOutput(); + void DumpOpInputOutput(); /// /// @ingroup ge @@ -403,7 +404,9 @@ class DavinciModel { /// uint32_t GetDeviceId() const { return device_id_; } - GeModelPtr GetGeModel() { return ge_model_; } + bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; } + + Status UpdateSessionId(uint64_t session_id); const RuntimeParam &GetRuntimeParam() { return runtime_param_; } @@ -423,7 +426,7 @@ class DavinciModel { void SetZeroCopyAddr(const OpDescPtr &op_desc, const std::vector &outside_addrs, const void *info, void *args, size_t size, size_t offset); - void SetDynamicSize(const std::vector &batch_num); + void SetDynamicSize(const std::vector &batch_num, int32_t dynamic_type); bool GetL1FusionEnableOption() { return is_l1_fusion_enable_; } @@ -463,12 +466,29 @@ class DavinciModel { void *cur_args = static_cast(args_) + offset; return cur_args; } + void SetTotalIOAddrs(vector &io_addrs) { + total_io_addrs_.insert(total_io_addrs_.end(), io_addrs.begin(), io_addrs.end()); + } + void SetTotalFixedAddrsSize(string tensor_name, int64_t fix_addr_size); + int64_t GetFixedAddrsSize(string tensor_name); + void *GetCurrentFixedAddr(int64_t offset) const { + void *cur_addr = static_cast(fixed_addrs_) + offset; + return cur_addr; + } + + uint32_t GetFixedAddrOutputIndex(string tensor_name) { + if (tensor_name_to_peer_output_index_.find(tensor_name) != tensor_name_to_peer_output_index_.end()) { + return tensor_name_to_peer_output_index_[tensor_name]; + } + return UINT32_MAX; + } void SetKnownNode(bool known_node) { known_node_ = known_node; } bool IsKnownNode() { return known_node_; } Status MallocKnownArgs(); Status UpdateKnownNodeArgs(const vector &inputs, const vector &outputs); Status CreateKnownZeroCopyMap(const vector &inputs, const vector &outputs); - Status UpdateKnownZeroCopyAddr(vector &io_addrs, uint32_t args_offset); + Status UpdateKnownZeroCopyAddr(); + void SetKnownNodeAddrNotChanged(bool base_addr_not_changed) { base_addr_not_changed_ = base_addr_not_changed; } Status GetOrigInputInfo(uint32_t index, OriginInputInfo &orig_input_info); Status GetAllAippInputOutputDims(uint32_t index, std::vector &input_dims, @@ -477,6 +497,9 @@ class DavinciModel { // om file name void SetOmName(string om_name) { om_name_ = om_name; } + void SetDumpProperties(const DumpProperties &dump_properties) { data_dumper_.SetDumpProperties(dump_properties); } + const DumpProperties &GetDumpProperties() const { return data_dumper_.GetDumpProperties(); } + private: // memory address of weights uint8_t *weights_mem_base_; @@ -493,8 +516,6 @@ class DavinciModel { struct timeInfo time_info_; int32_t dataInputTid; - void InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy); - /// /// @ingroup ge /// @brief Save Batch label Info. @@ -504,22 +525,6 @@ class DavinciModel { /// void SetBatchLabelAddr(const OpDescPtr &op_desc, uintptr_t addr); - /// - /// @ingroup ge - /// @brief Save Data address info for ZeroCopy. - /// @param [in] const std::vector &outside_addrs - /// @return None. - /// - void SetInputOutsideAddr(const std::vector &outside_addrs); - - /// - /// @ingroup ge - /// @brief Save NetOutput address info for ZeroCopy. - /// @param [in] const std::vector &outside_addrs - /// @return None. - /// - void SetOutputOutsideAddr(const std::vector &outside_addrs); - /// /// @ingroup ge /// @brief Copy Check input size and model op size. @@ -530,6 +535,13 @@ class DavinciModel { /// bool CheckInputAndModelSize(const int64_t &input_size, const int64_t &op_size, bool is_dynamic); + /// + /// @ingroup ge + /// @brief Set copy only for No task feed NetOutput address. + /// @return None. + /// + void SetCopyOnlyOutput(); + /// /// @ingroup ge /// @brief Copy Input/Output to model for direct use. @@ -550,19 +562,15 @@ class DavinciModel { /// @param [in] batch_label: batch label for multi-batch scenes /// @return SUCCESS handle successfully / others handle failed /// - Status UpdateIoTaskArgs(const map> &data_info, bool is_input, + Status UpdateIoTaskArgs(const std::map &data_info, bool is_input, const vector &blobs, bool is_dynamic, const string &batch_label); Status CopyInputData(const InputData &input_data, bool device_data = false); - Status CopyOutputData(uint32_t data_id, OutputData &output_data); - - Status CopyOutputDataToUser(OpDescPtr &op_desc, std::vector &blobs, uint32_t &data_index); + Status CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind); Status SyncVarData(); - Status SyncDataAndDump(); - Status InitModelMem(void *dev_ptr, size_t memsize, void *weight_ptr, size_t weightsize); void CreateInputDimsInfo(const OpDescPtr &op_desc, Format format, InputOutputDescInfo &input); @@ -589,7 +597,12 @@ class DavinciModel { bool IsAicpuKernelConnectSpecifiedLayer(); - Status MarkSpecifiedAicpuKernel(); + /// + /// @ingroup ge + /// @brief Reduce memory usage after task sink. + /// @return: void + /// + void Shrink(); /// /// @ingroup ge @@ -691,8 +704,7 @@ class DavinciModel { /// Status BindInputQueue(); - Status CpuTaskModelZeroCopy(std::vector &mbuf_list, - std::map> &outside_addrs); + Status CpuTaskModelZeroCopy(std::vector &mbuf_list, std::map &outside_addrs); /// /// @ingroup ge @@ -725,10 +737,9 @@ class DavinciModel { /// /// @ingroup ge /// @brief definiteness queue schedule, active original model stream. - /// @param [in] streams: streams will active by S0. /// @return: 0 for success / others for fail /// - Status CpuActiveStream(const std::vector &stream_list); + Status CpuActiveStream(); /// /// @ingroup ge @@ -746,6 +757,9 @@ class DavinciModel { /// Status CpuModelRepeat(); + Status InitEntryTask(); + Status AddHeadStream(); + /// /// @ingroup ge /// @brief set ts device. @@ -753,6 +767,10 @@ class DavinciModel { /// Status SetTSDevice(); + Status OpDebugRegister(); + + void OpDebugUnRegister(); + void CheckHasHcomOp(); Status DoTaskSink(); @@ -760,17 +778,17 @@ class DavinciModel { void CreateOutput(uint32_t index, OpDescPtr &op_desc, InputOutputDescInfo &output, uint32_t &format_result); Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); - Status CopyVarData(ComputeGraphPtr &graph); // get desc info of graph for profiling - Status GetComputeGraphInfo(vector &compute_graph_desc_info); + Status GetComputeGraphInfo(const ComputeGraphPtr &graph, vector &graph_desc_info); - void SetDataDumperArgs(); + void SetDataDumperArgs(const ComputeGraphPtr &compute_graph); Status GenOutputTensorInfo(const OpDescPtr &op_desc, uint32_t data_index, OutputData *output_data, std::vector &outputs); void ParseAIPPInfo(std::string in_out_info, InputOutputDims &dims_info); + void GetFixedAddrAttr(const OpDescPtr &op_desc); bool is_model_has_inited_; uint32_t model_id_; @@ -783,6 +801,9 @@ class DavinciModel { uint32_t version_; GeModelPtr ge_model_; + bool need_destroy_aicpu_kernel_{false}; + vector out_node_name_; + map op_list_; // data op_desc @@ -792,8 +813,12 @@ class DavinciModel { vector variable_op_list_; - std::map> input_data_info_; // Virtual address from Data output. - std::map> output_data_info_; // Virtual address from NetOutput input. + std::map new_input_data_info_; + std::map new_output_data_info_; + std::map new_input_outside_addrs_; + std::map new_output_outside_addrs_; + + std::vector real_virtual_addrs_; // output op: save cce op actual needed memory size vector output_memory_size_list_; @@ -825,9 +850,7 @@ class DavinciModel { std::mutex outside_addrs_mutex_; std::vector zero_copy_tasks_; // Task used Data or NetOutput addr. std::set copy_only_addrs_; // Address need copy to original place. - // {node_addr, {addr_in_task_args}} - std::map> input_outside_addrs_; // Key is virtual address from Data. - std::map> output_outside_addrs_; // Key is virtual address from NetOutput. + // {op_id, batch_label} std::map zero_copy_op_id_batch_label_; // {batch_label, addrs} @@ -843,6 +866,11 @@ class DavinciModel { bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. + bool is_pure_head_stream_{false}; + rtStream_t rt_head_stream_{nullptr}; + rtStream_t rt_entry_stream_{nullptr}; + rtAicpuDeployType_t deploy_type_{AICPU_DEPLOY_RESERVED}; + // ACL queue schedule, save queue ids for Init. std::vector cpu_task_list_; std::vector input_queue_ids_; // input queue ids created by caller. @@ -850,10 +878,6 @@ class DavinciModel { std::vector input_mbuf_list_; // input mbuf created by dequeue task. std::vector output_mbuf_list_; // output mbuf created by dequeue task. - // save input/output tensor descriptor in maps - std::map data_op_input_tensor_desc_map_; - std::map data_op_output_tensor_desc_map_; - uint64_t session_id_; uint32_t device_id_; @@ -864,8 +888,6 @@ class DavinciModel { std::vector active_stream_list_; std::set active_stream_indication_; - std::shared_ptr model_task_def_; - std::set aicpu_streams_; std::set hcom_streams_; RuntimeParam runtime_param_; @@ -877,22 +899,44 @@ class DavinciModel { // for profiling task and graph info std::map op_name_map_; std::vector task_desc_info_; - ComputeGraphPtr compute_graph_; int64_t maxDumpOpNum_; // for data dump DataDumper data_dumper_; uint64_t iterator_count_; bool is_l1_fusion_enable_; + std::map saved_task_addrs_; bool known_node_ = false; uint32_t total_args_size_ = 0; void *args_ = nullptr; void *args_host_ = nullptr; + void *fixed_addrs_ = nullptr; + int64_t total_fixed_addr_size_ = 0; std::map knonw_input_data_info_; std::map knonw_output_data_info_; + vector total_io_addrs_; + vector orig_total_io_addrs_; + bool base_addr_not_changed_ = false; + + vector> batch_info_; + std::vector> combined_batch_info_; + int32_t dynamic_type_ = 0; vector batch_size_; + // key: input tensor name, generally rts op; + // value: the fixed addr of input anchor, same as the peer output anchor addr of the peer op + std::map tensor_name_to_fixed_addr_size_; + + // key: input tensor name, generally rts op; value: the peer output anchor of the peer op + std::map tensor_name_to_peer_output_index_; + // if model is first execute + bool is_first_execute_; + // for op debug + std::mutex debug_reg_mutex_; + bool is_op_debug_reg_ = false; + void *op_debug_addr_ = nullptr; + void *p2p_debug_addr_ = nullptr; bool is_new_model_desc_{false}; }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.cc b/src/ge/graph/load/new_model_manager/model_manager.cc index d98ad8de..51b5b028 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.cc +++ b/src/ge/graph/load/new_model_manager/model_manager.cc @@ -22,8 +22,9 @@ #include "common/profiling/profiling_manager.h" #include "common/properties_manager.h" #include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" #include "framework/common/util.h" +#include "graph/common/ge_call_wrapper.h" +#include "graph/debug/ge_attr_define.h" #include "graph/load/new_model_manager/davinci_model.h" #include "graph/load/new_model_manager/davinci_model_parser.h" #include "model/ge_root_model.h" @@ -33,9 +34,10 @@ thread_local uint32_t device_count = 0; namespace { const int kCmdParSize = 2; const int kDumpCmdPairSize = 2; -const char *const kNeedDestroySpecifiedAicpuKernel = "need_destroy_specified_aicpu_kernel"; } // namespace +DumpProperties ModelManager::dump_properties_; + std::shared_ptr ModelManager::GetInstance() { static const std::shared_ptr instance_ptr = shared_ptr(new (std::nothrow) ModelManager(), ModelManager::FinalizeForPtr); @@ -68,11 +70,11 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u auto kernel_size = sizeof(uint64_t) * (v_aicpu_kernel.size()); rtError_t rt_ret = rtMalloc(&aicpu_kernel_addr, kernel_size, RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); - return RT_FAILED;) + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtMemcpy(aicpu_kernel_addr, kernel_size, v_aicpu_kernel.data(), kernel_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(aicpu_kernel_addr)); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); + GE_CHK_RT(rtFree(aicpu_kernel_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);) uint64_t kernel_id_addr = static_cast(reinterpret_cast(aicpu_kernel_addr)); param_base.fwkKernelBase.fwk_kernel.kernelID = kernel_id_addr; // In the scene of loading once and running many times, the kernel needs to be destroyed many times, @@ -82,64 +84,64 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u rtError_t rt_ret = rtMalloc(&(devicebase), sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "malloc device memory failed."); + GELOGE(RT_FAILED, "malloc device memory failed. ret: 0x%X", rt_ret); GE_IF_BOOL_EXEC(aicpu_kernel_addr != nullptr, GE_CHK_RT(rtFree(aicpu_kernel_addr))); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtMemcpy(devicebase, sizeof(STR_FWK_OP_KERNEL), ¶m_base, sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "memory copy to device failed."); + GELOGE(RT_FAILED, "memory copy to device failed. ret: 0x%X", rt_ret); GE_IF_BOOL_EXEC(aicpu_kernel_addr != nullptr, GE_CHK_RT(rtFree(aicpu_kernel_addr))); GE_CHK_RT(rtFree(devicebase)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rtStream_t stream = nullptr; rt_ret = rtStreamCreate(&stream, 0); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "create stream failed."); + GELOGE(RT_FAILED, "create stream failed. ret: 0x%X", rt_ret); GE_IF_BOOL_EXEC(aicpu_kernel_addr != nullptr, GE_CHK_RT(rtFree(aicpu_kernel_addr))); GE_CHK_RT(rtFree(devicebase)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtKernelLaunchEx(devicebase, sizeof(STR_FWK_OP_KERNEL), 0, stream); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtKernelLaunchEx failed."); + GELOGE(RT_FAILED, "rtKernelLaunchEx failed. ret: 0x%X", rt_ret); GE_IF_BOOL_EXEC(aicpu_kernel_addr != nullptr, GE_CHK_RT(rtFree(aicpu_kernel_addr))); GE_CHK_RT(rtFree(devicebase)); GE_CHK_RT(rtStreamDestroy(stream)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtStreamSynchronize(stream); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtStreamSynchronize failed."); + GELOGE(RT_FAILED, "rtStreamSynchronize failed. ret: 0x%X", rt_ret); GE_IF_BOOL_EXEC(aicpu_kernel_addr != nullptr, GE_CHK_RT(rtFree(aicpu_kernel_addr))); GE_CHK_RT(rtFree(devicebase)); GE_CHK_RT(rtStreamDestroy(stream)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } if (aicpu_kernel_addr != nullptr) { rt_ret = rtFree(aicpu_kernel_addr); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "free memory failed."); + GELOGE(RT_FAILED, "free memory failed. ret: 0x%X", rt_ret); GE_CHK_RT(rtFree(devicebase)); GE_CHK_RT(rtStreamDestroy(stream)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } rt_ret = rtFree(devicebase); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "free memory failed."); + GELOGE(RT_FAILED, "free memory failed. ret: 0x%X", rt_ret); GE_CHK_RT(rtStreamDestroy(stream)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtStreamDestroy(stream); if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtStreamDestroy failed."); - return FAILED; + GELOGE(RT_FAILED, "rtStreamDestroy failed. ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; } @@ -166,8 +168,8 @@ ge::Status ModelManager::DestroyAicpuSessionForInfer(uint32_t model_id) { std::lock_guard lock(map_mutex_); auto it = model_map_.find(model_id); if (it == model_map_.end()) { - GELOGE(PARAM_INVALID, "model id %u does not exists.", model_id); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", model_id); + return GE_EXEC_MODEL_ID_INVALID; } uint64_t session_id = it->second->GetSessionId(); GELOGI("Destroy aicpu session for infer, session id is %u.", session_id); @@ -221,10 +223,11 @@ Status ModelManager::SetDevice(int32_t deviceId) const { return SUCCESS; } -ge::Status ModelManager::SetDynamicSize(uint32_t model_id, const std::vector &batch_num) { +ge::Status ModelManager::SetDynamicSize(uint32_t model_id, const std::vector &batch_num, + int32_t dynamic_type) { std::shared_ptr davinci_model = GetModel(model_id); GE_CHECK_NOTNULL(davinci_model); - davinci_model->SetDynamicSize(batch_num); + davinci_model->SetDynamicSize(batch_num, dynamic_type); return SUCCESS; } @@ -272,6 +275,10 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetId(model_id); davinci_model->SetDeviceId(GetContext().DeviceId()); + const DumpProperties &dump_properties = PropertiesManager::Instance().GetDumpProperties(GetContext().SessionId()); + davinci_model->SetDumpProperties(dump_properties); + dump_properties_ = dump_properties; + auto root_graph = ge_root_model->GetRootGraph(); GE_CHECK_NOTNULL(root_graph); string root_model_name = root_graph->GetName(); @@ -296,9 +303,6 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptrSetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); - if (davinci_model->SinkModelProfile() != SUCCESS) { - GELOGW("Sink model profile failed."); - } } } while (0); @@ -325,12 +329,18 @@ Status ModelManager::DeleteModel(uint32_t id) { auto it = model_map_.find(id); auto hybrid_model_it = hybrid_model_map_.find(id); if (it != model_map_.end()) { + uint64_t session_id = it->second->GetSessionId(); + std::string model_key = std::to_string(session_id) + "_" + std::to_string(id); + auto iter_aicpu_kernel = model_aicpu_kernel_.find(model_key); + if (iter_aicpu_kernel != model_aicpu_kernel_.end()) { + (void)model_aicpu_kernel_.erase(iter_aicpu_kernel); + } (void)model_map_.erase(it); } else if (hybrid_model_it != hybrid_model_map_.end()) { (void)hybrid_model_map_.erase(hybrid_model_it); } else { - GELOGE(PARAM_INVALID, "model id %u does not exists.", id); - return PARAM_INVALID; + GELOGE(GE_EXEC_MODEL_ID_INVALID, "model id %u does not exists.", id); + return GE_EXEC_MODEL_ID_INVALID; } return SUCCESS; @@ -383,7 +393,7 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d std::shared_ptr model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(model != nullptr, PARAM_INVALID, "Invalid Model ID %u in InputData! ", model_id); + GE_CHK_BOOL_RET_STATUS(model != nullptr, PARAM_INVALID, "Invalid model id %u in InputData! ", model_id); GE_IF_BOOL_EXEC(model->GetDataInputTid() == 0, model->SetDataInputTid(mmGetTid())); @@ -419,7 +429,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector(inputs[i].length); + data.length = inputs[i].length; input_data.blobs.push_back(data); } @@ -439,7 +449,7 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vectorGetDataInputer(); GE_CHECK_NOTNULL(inputer); @@ -469,7 +479,7 @@ Status ModelManager::Start(uint32_t model_id) { std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u to start! ", model_id); Status status = davinci_model->ModelRunStart(); if (status == SUCCESS) { @@ -496,7 +506,7 @@ Status ModelManager::Stop(uint32_t model_id) { } std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to stop!", model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u to stop!", model_id); Status status = davinci_model->ModelRunStop(); if (status == SUCCESS) { @@ -611,10 +621,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { GELOGE(PARAM_INVALID, "parser dump model failed"); return FAILED; } - GELOGI("dump status = %s.", dump_model.c_str()); + GELOGI("dump model = %s.", dump_model.c_str()); if (dump_status == "off" || dump_status == "OFF") { - PropertiesManager::Instance().DeleteDumpPropertyValue(dump_model); + dump_properties_.DeletePropertyValue(dump_model); return SUCCESS; } @@ -631,9 +641,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { return FAILED; } if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path = dump_path + "/" + CurrentTimeInStr() + "/"; + dump_path = dump_path + "/"; } - GELOGI("dump status = %s.", dump_path.c_str()); + dump_path = dump_path + CurrentTimeInStr() + "/"; + GELOGI("dump path = %s.", dump_path.c_str()); ret = ParserPara(command, DUMP_MODE, dump_mode); if (ret != SUCCESS) { @@ -642,20 +653,10 @@ Status ModelManager::HandleDumpCommand(const Command &command) { } GELOGI("dump mode = %s", dump_mode.c_str()); - auto iter_dump_mode = std::find(command.cmd_params.begin(), command.cmd_params.end(), DUMP_MODE); - if (iter_dump_mode != command.cmd_params.end()) { - ++iter_dump_mode; - if (iter_dump_mode == command.cmd_params.end()) { - GELOGE(PARAM_INVALID, "Invalid access."); - return PARAM_INVALID; - } - dump_mode = *iter_dump_mode; - GELOGI("dump mode = %s", dump_mode.c_str()); - } + dump_properties_.AddPropertyValue(dump_model, dump_layers); + dump_properties_.SetDumpPath(dump_path); + dump_properties_.SetDumpMode(dump_mode); - PropertiesManager::Instance().AddDumpPropertyValue(dump_model, dump_layers); - PropertiesManager::Instance().SetDumpOutputPath(dump_path); - PropertiesManager::Instance().SetDumpMode(dump_mode); return SUCCESS; } @@ -667,7 +668,7 @@ Status ModelManager::GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_siz } std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetMaxUsedMemory Failed, Invalid Model ID %u !", + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetMaxUsedMemory Failed, Invalid model id %u!", model_id); max_size = davinci_model->TotalMemSize(); @@ -677,8 +678,8 @@ Status ModelManager::GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_siz Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector &input_desc, vector &output_desc) { std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, - "GetInputOutputDescInfo Failed, Invalid Model ID %u !", model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetInputOutputDescInfo Failed, Invalid model id %u!", + model_id); return davinci_model->GetInputOutputDescInfo(input_desc, output_desc); } @@ -688,8 +689,8 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector &inputFormats, std::vector &outputFormats, bool new_model_desc) { std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, - "GetInputOutputDescInfo Failed, Invalid Model ID %u !", model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, GE_EXEC_MODEL_ID_INVALID, + "GetInputOutputDescInfo Failed, Invalid model id %u!", model_id); davinci_model->SetModelDescVersion(new_model_desc); @@ -703,18 +704,35 @@ Status ModelManager::GetInputOutputDescInfo(const uint32_t model_id, vector> &batch_info) { +Status ModelManager::GetDynamicBatchInfo(const uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type) { + std::shared_ptr davinci_model = GetModel(model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, GE_EXEC_MODEL_ID_INVALID, + "GetDynamicBatchInfo failed, Invalid model id %u!", model_id); + + return davinci_model->GetDynamicBatchInfo(batch_info, dynamic_type); +} + +/// +/// @ingroup ge +/// @brief Get combined dynamic dims info +/// @param [in] model_id +/// @param [out] batch_info +/// @return execute result +/// +Status ModelManager::GetCombinedDynamicDims(const uint32_t model_id, vector> &batch_info) { std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetDynamicBatchInfo Failed, Invalid Model ID %u !", + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetCombinedDynamicDims Failed, Invalid Model ID %u!", model_id); - return davinci_model->GetDynamicBatchInfo(batch_info); + davinci_model->GetCombinedDynamicDims(batch_info); + return SUCCESS; } -Status ModelManager::GetCurShape(const uint32_t model_id, std::vector &batch_info) { +Status ModelManager::GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type) { std::shared_ptr davinci_model = GetModel(model_id); GE_CHECK_NOTNULL(davinci_model); - davinci_model->GetCurShape(batch_info); + davinci_model->GetCurShape(batch_info, dynamic_type); return SUCCESS; } @@ -730,8 +748,8 @@ Status ModelManager::GetInputOutputDescInfoForZeroCopy(const uint32_t model_id, std::vector &inputFormats, std::vector &outputFormats) { std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, - "GetInputOutputDescInfo Failed, Invalid Model ID %u !", model_id); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetInputOutputDescInfo Failed, Invalid model id %u!", + model_id); return davinci_model->GetInputOutputDescInfoForZeroCopy(input_desc, output_desc, inputFormats, outputFormats); } @@ -771,21 +789,10 @@ Status ModelManager::GenSessionId(uint64_t &session_id) { return SUCCESS; } -Status ModelManager::UpdateSessionId(std::shared_ptr &davinci_model, uint64_t session_id) { - GeModelPtr ge_model_current = davinci_model->GetGeModel(); - GE_CHECK_NOTNULL(ge_model_current); - if (!ge::AttrUtils::SetInt(ge_model_current, ge::MODEL_ATTR_SESSION_ID, static_cast(session_id))) { - GELOGW("Set attr[%s] failed in updating session_id.", MODEL_ATTR_SESSION_ID.c_str()); - } - - GELOGD("Update session id: %lu.", session_id); - return SUCCESS; -} - Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model, shared_ptr listener, void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { - GE_CHK_BOOL_RET_STATUS(model.key.empty() || access(model.key.c_str(), F_OK) == 0, PARAM_INVALID, - "input key file path is not valid, %s", strerror(errno)); + GE_CHK_BOOL_RET_STATUS(model.key.empty() || access(model.key.c_str(), F_OK) == 0, GE_EXEC_MODEL_KEY_PATH_INVALID, + "input key file path %s is invalid, %s", model.key.c_str(), strerror(errno)); GenModelId(&model_id); shared_ptr davinci_model = nullptr; @@ -803,11 +810,11 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model try { davinci_model = std::make_shared(model.priority, listener); } catch (std::bad_alloc &) { - GELOGE(FAILED, "Make shared failed"); - return FAILED; + GELOGE(MEMALLOC_FAILED, "Make shared failed"); + return MEMALLOC_FAILED; } catch (...) { - GELOGE(FAILED, "Make shared failed since other exception raise"); - return FAILED; + GELOGE(INTERNAL_ERROR, "Make shared failed since other exception raise"); + return INTERNAL_ERROR; } ret = davinci_model->Assign(ge_model); if (ret != SUCCESS) { @@ -820,10 +827,11 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model rtError_t rt_ret = rtGetDevice(&device_id); if (rt_ret != RT_ERROR_NONE || device_id < 0) { GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } davinci_model->SetDeviceId(device_id); davinci_model->SetOmName(model.om_name); + davinci_model->SetDumpProperties(dump_properties_); /// In multi-threaded inference, using the same session_id among multiple threads may cause some threads to fail. /// These session_ids come from the same model, so the values of session_id are the same. @@ -831,7 +839,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model uint64_t new_session_id; ret = GenSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "Generate session_id for infer failed."); - ret = UpdateSessionId(davinci_model, new_session_id); + ret = davinci_model->UpdateSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "Update session_id for infer failed."); ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); @@ -846,9 +854,6 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * 1000 * 1000 * 1000 + timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond davinci_model->SetProfileTime(MODEL_LOAD_END); - if (davinci_model->SinkModelProfile() != SUCCESS) { - GELOGW("Sink model profile failed."); - } } GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); @@ -870,8 +875,9 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_data, const std::vector &input_queue_ids, const std::vector &output_queue_ids) { - GE_CHK_BOOL_RET_STATUS(model_data.key.empty() || access(model_data.key.c_str(), F_OK) == 0, PARAM_INVALID, - "input key file path is not valid, %s", strerror(errno)); + GE_CHK_BOOL_RET_STATUS(model_data.key.empty() || access(model_data.key.c_str(), F_OK) == 0, + GE_EXEC_MODEL_KEY_PATH_INVALID, "input key file path %s is not valid, %s", + model_data.key.c_str(), strerror(errno)); ModelHelper model_helper; Status ret = model_helper.LoadModel(model_data); @@ -882,8 +888,8 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d shared_ptr davinci_model = MakeShared(model_data.priority, nullptr); if (davinci_model == nullptr) { - GELOGE(FAILED, "create model failed."); - return FAILED; + GELOGE(MEMALLOC_FAILED, "create model failed."); + return MEMALLOC_FAILED; } ret = davinci_model->Assign(model_helper.GetGeModel()); @@ -898,7 +904,7 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d uint64_t new_session_id; ret = GenSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Generate session_id for infer failed."); - ret = UpdateSessionId(davinci_model, new_session_id); + ret = davinci_model->UpdateSessionId(new_session_id); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Update session_id for infer failed."); GenModelId(&model_id); @@ -909,6 +915,8 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d return ret; } + davinci_model->SetDumpProperties(dump_properties_); + ret = davinci_model->Init(); if (ret != SUCCESS) { GELOGE(ret, "init model failed."); @@ -933,14 +941,10 @@ Status ModelManager::LoadModelWithQ(uint32_t &model_id, const ModelData &model_d Status ModelManager::ExecuteModel(uint32_t model_id, rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data) { std::shared_ptr davinci_model = GetModel(model_id); - GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); - - GeModelPtr ge_model_current = davinci_model->GetGeModel(); - bool need_destroy_aicpu_kernel = false; - bool result = ge::AttrUtils::GetBool(ge_model_current, kNeedDestroySpecifiedAicpuKernel, need_destroy_aicpu_kernel); - if (result && need_destroy_aicpu_kernel) { - GELOGI("Get attr %s successfully, start to destroy specified aicpu kernel.", kNeedDestroySpecifiedAicpuKernel); + GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid model id %u.", model_id); + if (davinci_model->NeedDestroyAicpuKernel()) { + GELOGI("Start to destroy specified aicpu kernel."); // Zero copy is enabled by default, no need to judge. uint64_t session_id_davinci = davinci_model->GetSessionId(); uint32_t model_id_davinci = davinci_model->GetModelId(); @@ -994,29 +998,30 @@ Status ModelManager::GetModelMemAndWeightSize(const ModelData &model, size_t &me auto partition_table = reinterpret_cast(model_data); if (partition_table->num == 1) { - GELOGE(FAILED, "om model is error,please use executable om model"); - return FAILED; + GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "om model is error,please use executable om model"); + return GE_EXEC_MODEL_PARTITION_NUM_INVALID; } ModelPartition task_partition; if (om_file_helper.GetModelPartition(ModelPartitionType::TASK_INFO, task_partition) != SUCCESS) { - GELOGE(FAILED, "get task model partition failed."); - return FAILED; + GELOGE(GE_EXEC_LOAD_TASK_PARTITION_FAILED, "get task model partition failed."); + return GE_EXEC_LOAD_TASK_PARTITION_FAILED; } std::shared_ptr model_task_def = MakeShared(); if (model_task_def == nullptr) { - return FAILED; + return MEMALLOC_FAILED; } if (task_partition.size != 0) { if (!ReadProtoFromArray(task_partition.data, static_cast(task_partition.size), model_task_def.get())) { - GELOGE(FAILED, "ReadProtoFromArray failed."); - return FAILED; + GELOGE(GE_EXEC_LOAD_TASK_PARTITION_FAILED, "ReadProtoFromArray failed."); + return GE_EXEC_LOAD_TASK_PARTITION_FAILED; } } ModelPartition partition_weight; ret = om_file_helper.GetModelPartition(ModelPartitionType::WEIGHTS_DATA, partition_weight); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return ret, "Get weight partition failed. ret = %u", ret); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, return GE_EXEC_LOAD_WEIGHT_PARTITION_FAILED, + "Get weight partition failed. ret = %u", ret); mem_size = model_task_def->memory_size(); weight_size = partition_weight.size; @@ -1050,4 +1055,19 @@ Status ModelManager::GetAllAippInputOutputDims(uint32_t model_id, uint32_t index return davinci_model->GetAllAippInputOutputDims(index, input_dims, output_dims); } +bool ModelManager::IsDynamicShape(uint32_t model_id) { + auto model = GetHybridModel(model_id); + return model != nullptr; +} + +ge::Status ModelManager::SyncExecuteModel(uint32_t model_id, const vector &inputs, + vector &outputs) { + auto model = GetHybridModel(model_id); + if (model == nullptr) { + GELOGE(FAILED, "Hybrid model not found. model id = %u.", model_id); + return FAILED; + } + + return model->Execute(inputs, outputs); +} } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_manager.h b/src/ge/graph/load/new_model_manager/model_manager.h index 8e2424bf..153d324d 100644 --- a/src/ge/graph/load/new_model_manager/model_manager.h +++ b/src/ge/graph/load/new_model_manager/model_manager.h @@ -31,6 +31,7 @@ #include "common/ge_types.h" #include "common/helper/model_helper.h" #include "common/helper/om_file_helper.h" +#include "common/properties_manager.h" #include "common/types.h" #include "ge/ge_api_types.h" #include "graph/ge_context.h" @@ -141,6 +142,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status ExecuteModel(uint32_t model_id, rtStream_t stream, bool async_mode, const InputData &input_data, OutputData &output_data); + ge::Status SyncExecuteModel(uint32_t model_id, const std::vector &inputs, std::vector &outputs); + /// /// @ingroup domi_ome /// @brief model stop @@ -184,9 +187,19 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { /// @brief Get dynamic batch_info /// @param [in] model_id /// @param [out] batch_info + /// @param [out] dynamic_type /// @return execute result /// - ge::Status GetDynamicBatchInfo(const uint32_t model_id, std::vector> &batch_info); + ge::Status GetDynamicBatchInfo(const uint32_t model_id, std::vector> &batch_info, + int32_t &dynamic_type); + /// + /// @ingroup ge + /// @brief Get combined dynamic dims info + /// @param [in] model_id + /// @param [out] batch_info + /// @return execute result + /// + ge::Status GetCombinedDynamicDims(const uint32_t model_id, std::vector> &batch_info); /// /// @ingroup ge @@ -212,13 +225,13 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::vector &inputFormats, std::vector &outputFormats); - ge::Status GetCurShape(const uint32_t model_id, std::vector &batch_info); + ge::Status GetCurShape(const uint32_t model_id, std::vector &batch_info, int32_t &dynamic_type); ge::Status GetModelAttr(uint32_t model_id, std::vector &dynamic_output_shape_info); ge::Status SetDevice(int32_t deviceId) const; - ge::Status SetDynamicSize(uint32_t model_id, const std::vector &batch_num); + ge::Status SetDynamicSize(uint32_t model_id, const std::vector &batch_num, int32_t dynamic_type); /// /// @ingroup domi_ome @@ -249,6 +262,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status GetAllAippInputOutputDims(uint32_t model_id, uint32_t index, std::vector &input_dims, std::vector &output_dims); + bool IsDynamicShape(uint32_t model_id); + private: /// /// @ingroup domi_ome @@ -276,7 +291,6 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { ge::Status DeleteModel(uint32_t id); void GenModelId(uint32_t *id); - ge::Status UpdateSessionId(std::shared_ptr &davinci_model, uint64_t session_id); std::map> model_map_; std::map> hybrid_model_map_; @@ -287,6 +301,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { std::mutex session_id_create_mutex_; uint64_t session_id_bias_; std::set sess_ids_; + + static DumpProperties dump_properties_; }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.cc b/src/ge/graph/load/new_model_manager/model_utils.cc index a807f2a3..8a92e1e6 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.cc +++ b/src/ge/graph/load/new_model_manager/model_utils.cc @@ -31,9 +31,9 @@ namespace ge { /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get input size. -/// @return vector +/// @return vector /// vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { vector v_input_size; @@ -43,22 +43,29 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + + int64_t tensor_size = 0; if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights size to input - GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); - int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + GE_CHK_STATUS(TensorUtils::GetSize(*tensor_desc, tensor_size)); if (tensor_size) { v_input_size.push_back(tensor_size); } + GELOGI("[IMAS]GetInputSize op: %s, index: %lu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); continue; } - int64_t tensor_size = 0; GE_IF_BOOL_EXEC( - TensorUtils::GetSize(op_desc->GetInputDesc(i), tensor_size) != GRAPH_SUCCESS, + TensorUtils::GetSize(*tensor_desc, tensor_size) != GRAPH_SUCCESS, GELOGI("Get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); - continue;); + continue); + + GELOGI("[IMAS]GetInputSize op: %s, index: %lu, size:%ld", op_desc->GetName().c_str(), i, tensor_size); v_input_size.push_back(tensor_size); } @@ -67,9 +74,9 @@ vector ModelUtils::GetInputSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get output size. -/// @return vector +/// @return vector /// vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { vector v_output_size; @@ -82,11 +89,17 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { return v_output_size;); for (size_t i = 0; i < outputs_size; ++i) { + const GeTensorDescPtr tensor_desc = op_desc->MutableOutputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + int64_t tensor_size = 0; GE_IF_BOOL_EXEC( - TensorUtils::GetSize(op_desc->GetOutputDesc(i), tensor_size) != GRAPH_SUCCESS, + TensorUtils::GetSize(*tensor_desc, tensor_size) != GRAPH_SUCCESS, GELOGI("Get size from TensorDesc failed, op : %s, output index : %zu", op_desc->GetName().c_str(), i); - continue;); + continue); v_output_size.push_back(tensor_size); } @@ -95,9 +108,9 @@ vector ModelUtils::GetOutputSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get workspace size. -/// @return vector +/// @return vector /// vector ModelUtils::GetWorkspaceSize(ConstOpDescPtr op_desc) { vector v_workspace_size; @@ -118,9 +131,9 @@ vector ModelUtils::GetWorkspaceSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get weight size. -/// @return vector +/// @return vector /// vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { vector v_weight_size; @@ -142,8 +155,14 @@ vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i]) { + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + int64_t tensor_size = 0; - (void)TensorUtils::GetSize(op_desc->GetInputDesc(i), tensor_size); + (void)TensorUtils::GetSize(*tensor_desc, tensor_size); v_weight_size.push_back(tensor_size); } } @@ -152,7 +171,7 @@ vector ModelUtils::GetWeightSize(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get weights. /// @return vector /// @@ -176,9 +195,14 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { const vector v_is_input_const = op_desc->GetIsInputConst(); for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i]) { + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + ConstGeTensorPtr weight = nullptr; - GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); - if (AttrUtils::GetTensor(tensor_desc, ATTR_NAME_WEIGHTS, weight)) { + if (AttrUtils::GetTensor(*tensor_desc, ATTR_NAME_WEIGHTS, weight)) { v_weights.push_back(weight); } } @@ -188,7 +212,7 @@ vector ModelUtils::GetWeights(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get AiCpuOp Input descriptor. /// @return vector<::tagCcAICPUTensor> /// @@ -205,20 +229,25 @@ vector<::tagCcAICPUTensor> ModelUtils::GetInputDescs(ConstOpDescPtr op_desc) { continue; } + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + uint32_t dim_cnt = 0; - const auto &descriptor = op_desc->GetInputDesc(i); - GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(descriptor, dim_cnt) == GRAPH_SUCCESS, continue, + GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(*tensor_desc, dim_cnt) == GRAPH_SUCCESS, continue, "Get dim_cnt failed"); opTensor_t tmp; - uint32_t tmp_fmt = descriptor.GetFormat(); + uint32_t tmp_fmt = tensor_desc->GetFormat(); tmp.format = tagOpTensorFormat(tmp_fmt); tmp.dim_cnt = static_cast(dim_cnt); - uint32_t tmp_type = descriptor.GetDataType(); + uint32_t tmp_type = tensor_desc->GetDataType(); tmp.data_type = tagOpDataType(tmp_type); for (int32_t j = 0; j < 4; j++) { // 4 dims - tmp.dim[j] = (j < tmp.dim_cnt ? descriptor.GetShape().GetDim(j) : 1); + tmp.dim[j] = (j < tmp.dim_cnt ? tensor_desc->GetShape().GetDim(j) : 1); } v_input_descs.push_back(tmp); @@ -228,7 +257,7 @@ vector<::tagCcAICPUTensor> ModelUtils::GetInputDescs(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get AiCpuOp Output descriptor. /// @return vector<::tagCcAICPUTensor> /// @@ -240,20 +269,25 @@ vector<::tagCcAICPUTensor> ModelUtils::GetOutputDescs(ConstOpDescPtr op_desc) { // init op output opTensor_t struct const size_t output_num = op_desc->GetOutputsSize(); for (size_t i = 0; i < output_num; ++i) { + const GeTensorDescPtr tensor_desc = op_desc->MutableOutputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + uint32_t dim_cnt = 0; - const auto &descriptor = op_desc->GetOutputDesc(i); - GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(descriptor, dim_cnt) == GRAPH_SUCCESS, continue, + GE_CHK_BOOL_EXEC_WARN(TensorUtils::GetRealDimCnt(*tensor_desc, dim_cnt) == GRAPH_SUCCESS, continue, "Get dim_cnt failed"); opTensor_t tmp; - uint32_t tmp_fmt = descriptor.GetFormat(); + uint32_t tmp_fmt = tensor_desc->GetFormat(); tmp.format = tagOpTensorFormat(tmp_fmt); tmp.dim_cnt = static_cast(dim_cnt); - uint32_t tmp_type = descriptor.GetDataType(); + uint32_t tmp_type = tensor_desc->GetDataType(); tmp.data_type = tagOpDataType(tmp_type); for (int32_t j = 0; j < 4; j++) { // 4 dims - tmp.dim[j] = (j < tmp.dim_cnt ? descriptor.GetShape().GetDim(j) : 1); + tmp.dim[j] = (j < tmp.dim_cnt ? tensor_desc->GetShape().GetDim(j) : 1); } v_output_descs.push_back(tmp); @@ -263,44 +297,14 @@ vector<::tagCcAICPUTensor> ModelUtils::GetOutputDescs(ConstOpDescPtr op_desc) { } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get input data address. /// @return vector /// -vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert) { +vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { vector v_input_data_addr; // init as:buf_base + op_def_->input(i)); GE_CHECK_NOTNULL_EXEC(op_desc, return v_input_data_addr); uint64_t session_id = model_param.session_id; - uint8_t *mem_base = model_param.mem_base; - uint8_t *var_base = model_param.var_base; - uint8_t *weight_base = model_param.weight_base; - const uint64_t logic_mem_base = 0; - uint64_t logic_weight_base = 0; - uint64_t logic_var_base = model_param.logic_var_base; - uint64_t mem_size = model_param.mem_size; - uint64_t weight_size = model_param.weight_size; - uint64_t var_size = model_param.var_size; - - if (need_convert) { - Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); - return v_input_data_addr; - } - - status = ConvertVirtualAddressToPhysical(weight_base, weight_size, weight_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for weight_base failed."); - return v_input_data_addr; - } - - status = ConvertVirtualAddressToPhysical(var_base, var_size, var_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for var_base failed."); - return v_input_data_addr; - } - } const size_t inputs_size = op_desc->GetInputsSize(); const vector v_input_offset = op_desc->GetInputOffset(); @@ -319,13 +323,18 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co for (size_t i = 0; i < inputs_size; ++i) { if ((i < v_is_input_const.size()) && v_is_input_const[i] && (op_type != NETOUTPUT)) { // TBE: add weights address to input - GeTensorDesc tensor_desc = op_desc->GetInputDesc(i); + const GeTensorDescPtr tensor_desc = op_desc->MutableInputDesc(i); + if (tensor_desc == nullptr) { + GELOGW("Op: %s, Index: %zu, Tensor Desc is null", op_desc->GetName().c_str(), i); + continue; + } + int64_t tensor_size = 0; - GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + GE_CHK_STATUS(TensorUtils::GetSize(*tensor_desc, tensor_size)); if (tensor_size) { int64_t data_offset = 0; - GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); - uint8_t *weight_addr = static_cast(weight_base + data_offset - logic_weight_base); + GE_CHK_STATUS(TensorUtils::GetDataOffset(*tensor_desc, data_offset)); + uint8_t *weight_addr = model_param.weight_base + data_offset; v_input_data_addr.push_back(weight_addr); GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, weight_addr); @@ -340,17 +349,13 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co int64_t input_offset = v_input_offset[non_const_index]; non_const_index++; - GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), - uint8_t *variable_addr = var_base + input_offset - logic_var_base; + GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), + uint8_t *variable_addr = model_param.var_base + input_offset - model_param.logic_var_base; v_input_data_addr.push_back(variable_addr); GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); - continue;); + continue); - bool input_tensor = false; - GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, - GELOGW("get size from TensorDesc failed, op: %s, input index: %zu", op_desc->GetName().c_str(), i); - continue;); // feature maps uint8_t *mem_addr = nullptr; // fusion @@ -358,7 +363,7 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co mem_addr = reinterpret_cast(reinterpret_cast(input_offset)); v_input_data_addr.push_back(mem_addr); } else { - mem_addr = static_cast(mem_base + input_offset - logic_mem_base); + mem_addr = model_param.mem_base + input_offset; v_input_data_addr.push_back(mem_addr); } GELOGI("[IMAS]GetInputDataAddrs graph_%u type[F] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, @@ -369,41 +374,20 @@ vector ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get output data address. /// @return vector /// -vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert) { +vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { vector v_output_data_addr; // init as:buf_base + op_def_->output(i) GE_CHECK_NOTNULL_EXEC(op_desc, return v_output_data_addr); uint64_t session_id = model_param.session_id; - uint8_t *mem_base = model_param.mem_base; - uint8_t *var_base = model_param.var_base; - const uint64_t logic_mem_base = 0; - uint64_t logic_var_base = model_param.logic_var_base; - uint64_t mem_size = model_param.mem_size; - uint64_t var_size = model_param.var_size; - - if (need_convert) { - Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); - return v_output_data_addr; - } - - status = ConvertVirtualAddressToPhysical(var_base, var_size, var_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for var_base failed."); - return v_output_data_addr; - } - } const size_t outputs_size = op_desc->GetOutputsSize(); const vector v_output_offset = op_desc->GetOutputOffset(); GE_IF_BOOL_EXEC(v_output_offset.size() != outputs_size, GELOGW("Output param invalid: output_offset=%zu, outputs=%zu.", v_output_offset.size(), outputs_size); - return v_output_data_addr;); + return v_output_data_addr); vector v_memory_type; bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); if (has_mem_type_attr && (v_memory_type.size() != outputs_size)) { @@ -413,12 +397,12 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C return v_output_data_addr; } for (size_t i = 0; i < outputs_size; ++i) { - GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]), - uint8_t *variable_addr = static_cast(var_base + v_output_offset[i] - logic_var_base); + GE_IF_BOOL_EXEC(model_param.var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(v_output_offset[i]), + uint8_t *variable_addr = model_param.var_base + v_output_offset[i] - model_param.logic_var_base; v_output_data_addr.push_back(variable_addr); GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[V] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); - continue;); + continue); // feature maps uint8_t *mem_addr = nullptr; // fusion @@ -426,7 +410,7 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C mem_addr = reinterpret_cast(reinterpret_cast(v_output_offset[i])); v_output_data_addr.push_back(mem_addr); } else { - mem_addr = static_cast(mem_base + v_output_offset[i] - logic_mem_base); + mem_addr = static_cast(model_param.mem_base + v_output_offset[i]); v_output_data_addr.push_back(mem_addr); } GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[F] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, @@ -436,24 +420,13 @@ vector ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C } /// -/// @ingroup domi_ome +/// @ingroup ge /// @brief Get workspace data address. /// @return vector /// -vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert) { +vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc) { vector v_workspace_data_addr; GE_CHECK_NOTNULL_EXEC(op_desc, return v_workspace_data_addr); - uint8_t *mem_base = model_param.mem_base; - uint64_t mem_size = model_param.mem_size; - - if (need_convert) { - Status status = ConvertVirtualAddressToPhysical(mem_base, mem_size, mem_base); - if (status != SUCCESS) { - GELOGE(RT_FAILED, "Convert virtual address to physical for mem_base failed."); - return v_workspace_data_addr; - } - } const vector v_workspace_offset = op_desc->GetWorkspace(); const vector v_workspace_bytes = op_desc->GetWorkspaceBytes(); @@ -466,13 +439,13 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, v_memory_type); for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { - v_workspace_data_addr.push_back(reinterpret_cast(v_workspace_offset[i])); + v_workspace_data_addr.push_back(reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); GELOGI("Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, reinterpret_cast(reinterpret_cast(v_workspace_offset[i]))); } else { int64_t workspace_offset = v_workspace_offset[i]; int64_t workspace_bytes = v_workspace_bytes[i]; - uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : mem_base + workspace_offset; + uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : model_param.mem_base + workspace_offset; v_workspace_data_addr.push_back(mem_addr); GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, workspace_offset, workspace_bytes, mem_addr); @@ -482,21 +455,32 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param return v_workspace_data_addr; } -Status ModelUtils::ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uint64_t size, - uint8_t *&physical_address) { - // Indicates whether use physical address. - const char *use_physical_address = std::getenv("GE_USE_PHYSICAL_ADDRESS"); - if (use_physical_address == nullptr || virtual_address == 0 || size == 0) { - return SUCCESS; - } - - rtError_t ret = rtKernelConfigTransArg(virtual_address, size, 0, reinterpret_cast(&physical_address)); - if (ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rtKernelConfigTransArg failed, ret: 0x%X", ret); - return RT_FAILED; +/// +/// @ingroup ge +/// @brief Get runtime memory address. +/// @return Status +/// +Status ModelUtils::GetRtAddress(const RuntimeParam ¶m, uintptr_t logic_addr, uint8_t *&mem_addr) { + uint8_t *runtime_base_addr = nullptr; + if ((param.logic_mem_base <= logic_addr) && (logic_addr < param.logic_mem_base + param.mem_size)) { + runtime_base_addr = param.mem_base - param.logic_mem_base; + GELOGI("The logic addr:0x%lx is data address, base:0x%lx, size:%lu", logic_addr, param.logic_mem_base, + param.mem_size); + } else if ((param.logic_weight_base <= logic_addr) && (logic_addr < param.logic_weight_base + param.weight_size)) { + runtime_base_addr = param.weight_base - param.logic_weight_base; + GELOGI("The logic addr:0x%lx is weight address, base:0x%lx, size:%lu", logic_addr, param.logic_weight_base, + param.weight_size); + } else if ((param.logic_var_base <= logic_addr) && (logic_addr < param.logic_var_base + param.var_size)) { + runtime_base_addr = param.var_base - param.logic_var_base; + GELOGI("The logic addr:0x%lx is variable address, base:0x%lx, size:%lu", logic_addr, param.logic_var_base, + param.var_size); + } else if (logic_addr != 0) { + mem_addr = nullptr; + GELOGE(PARAM_INVALID, "The logic addr:0x%lx is abnormal", logic_addr); + return PARAM_INVALID; } - GELOGD("virtual_address=%p, physical_address=%p", virtual_address, physical_address); + mem_addr = runtime_base_addr + logic_addr; return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/model_utils.h b/src/ge/graph/load/new_model_manager/model_utils.h index d6afd5c8..8474a987 100644 --- a/src/ge/graph/load/new_model_manager/model_utils.h +++ b/src/ge/graph/load/new_model_manager/model_utils.h @@ -34,78 +34,79 @@ class ModelUtils { ~ModelUtils() = default; /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get input size. /// @return vector /// static vector GetInputSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get output size. /// @return vector /// static vector GetOutputSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get workspace size. /// @return vector /// static vector GetWorkspaceSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get weight size. /// @return vector /// static vector GetWeightSize(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get weights. /// @return vector /// static vector GetWeights(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get AiCpuOp Input descriptor. /// @return vector<::tagCcAICPUTensor> /// static vector<::tagCcAICPUTensor> GetInputDescs(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get AiCpuOp Output descriptor. /// @return vector<::tagCcAICPUTensor> /// static vector<::tagCcAICPUTensor> GetOutputDescs(ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get input data address. /// @return vector /// - static vector GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert = true); + static vector GetInputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get output data address. /// @return vector /// - static vector GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert = true); + static vector GetOutputDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); /// - /// @ingroup domi_ome + /// @ingroup ge /// @brief Get workspace data address. /// @return vector /// - static vector GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc, - bool need_convert = true); + static vector GetWorkspaceDataAddrs(const RuntimeParam &model_param, ConstOpDescPtr op_desc); - static ge::Status ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uint64_t size, - uint8_t *&physical_address); + /// + /// @ingroup ge + /// @brief Get memory runtime base. + /// @return Status + /// + static Status GetRtAddress(const RuntimeParam &model_param, uintptr_t logic_addr, uint8_t *&mem_addr); }; } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc index 077ae827..39f0591d 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.cc @@ -34,7 +34,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); if (ret != SUCCESS) { GELOGE(ret, "SetStream fail, stream_id:%u", task_def.stream_id()); - return FAILED; + return ret; } model_ = davinci_model->GetRtModelHandle(); @@ -45,7 +45,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin Status EndGraphTaskInfo::Distribute() { GELOGI("EndGraphTaskInfo Distribute Start."); GE_CHECK_NOTNULL(davinci_model_); - auto all_dump_model = PropertiesManager::Instance().GetAllDumpModel(); + auto all_dump_model = davinci_model_->GetDumpProperties().GetAllDumpModel(); if (all_dump_model.find(ge::DUMP_ALL_MODEL) != all_dump_model.end() || all_dump_model.find(davinci_model_->Name()) != all_dump_model.end() || all_dump_model.find(davinci_model_->OmName()) != all_dump_model.end()) { @@ -53,14 +53,14 @@ Status EndGraphTaskInfo::Distribute() { rtError_t rt_ret = rtEndGraphEx(model_, stream_, kDumpFlag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtEndGraphEx failed, ret: 0x%x", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } else { GELOGI("Start to call rtEndGraph"); rtError_t rt_ret = rtEndGraph(model_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rtEndGraph failed, ret: 0x%x", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -69,7 +69,7 @@ Status EndGraphTaskInfo::Distribute() { rtError_t rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } task_id_ = task_id; stream_id_ = stream_id; @@ -80,5 +80,4 @@ Status EndGraphTaskInfo::Distribute() { } REGISTER_TASK_INFO(RT_MODEL_TASK_MODEL_END_GRAPH, EndGraphTaskInfo); - } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h index 49bef082..82e228e6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/end_graph_task_info.h @@ -22,7 +22,7 @@ namespace ge { class EndGraphTaskInfo : public TaskInfo { public: - EndGraphTaskInfo() : model_(0) {} + EndGraphTaskInfo() {} ~EndGraphTaskInfo() override { model_ = nullptr; } @@ -35,10 +35,10 @@ class EndGraphTaskInfo : public TaskInfo { uint32_t GetStreamId() override { return stream_id_; } private: - rtModel_t model_; - DavinciModel *davinci_model_; - uint32_t task_id_; - uint32_t stream_id_; + rtModel_t model_{nullptr}; + DavinciModel *davinci_model_{nullptr}; + uint32_t task_id_{0}; + uint32_t stream_id_{0}; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_END_GRAPH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc index edfd8d17..f742118c 100644 --- a/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/event_record_task_info.cc @@ -49,7 +49,7 @@ Status EventRecordTaskInfo::Distribute() { rtError_t rt_ret = rtEventRecord(event_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; diff --git a/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc index a8db158d..e8f96b35 100644 --- a/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/event_wait_task_info.cc @@ -51,13 +51,13 @@ Status EventWaitTaskInfo::Distribute() { rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtEventReset(event_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; diff --git a/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc index f3fa7959..9b1ea04a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/fusion_start_task_info.cc @@ -40,7 +40,7 @@ Status FusionStartTaskInfo::Distribute() { rtError_t rt_ret = rtKernelFusionStart(stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("FusionStartTaskInfo Distribute Success."); diff --git a/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc index 128fb325..7acbb5b3 100644 --- a/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/fusion_stop_task_info.cc @@ -40,7 +40,7 @@ Status FusionStopTaskInfo::Distribute() { rtError_t rt_ret = rtKernelFusionEnd(stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("FusionStopTaskInfo Distribute Success."); diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc index 0ee9727a..cb8cfed6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.cc @@ -42,6 +42,7 @@ HcclTaskInfo::~HcclTaskInfo() { davinci_model_ = nullptr; ops_kernel_store_ = nullptr; max_node_of_hccl_stream_ = 0; + args_ = nullptr; } Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("HcclTaskInfo Init Start."); @@ -60,52 +61,59 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); // Get HCCL op - OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); - GE_CHECK_NOTNULL(op_desc); + op_desc_ = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc_); // Create the kernel hccl infos - CreateKernelHcclInfo(op_desc); + CreateKernelHcclInfo(op_desc_); // Initialize the hccl_type of all kernel hccl info HcomOmeUtil::GetHcclType(task_def, kernel_hccl_infos_); // Only in Horovod scenario should get the inputName and GeShape - ret = HcomOmeUtil::GetHorovodInputs(op_desc, kernel_hccl_infos_); + ret = HcomOmeUtil::GetHorovodInputs(op_desc_, kernel_hccl_infos_); if (ret != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHorovodInputs fail! domi error: %u", ret); - return FAILED; + GELOGE(ret, "davinci_model: GetHorovodInputs fail! domi error: %u", ret); + return ret; } - Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc, kernel_hccl_infos_); + Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc_, kernel_hccl_infos_); if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); - return FAILED; + GELOGE(dmrt, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); + return dmrt; } - dmrt = HcomOmeUtil::GetHcclCount(op_desc, kernel_hccl_infos_); + dmrt = HcomOmeUtil::GetHcclCount(op_desc_, kernel_hccl_infos_); if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); - return FAILED; + GELOGE(dmrt, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); + return dmrt; } // Only HCOMBROADCAST and HVDCALLBACKBROADCAST need to get the rootId - dmrt = HcomOmeUtil::GetAllRootId(op_desc, kernel_hccl_infos_); + dmrt = HcomOmeUtil::GetAllRootId(op_desc_, kernel_hccl_infos_); if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: Get rootId fail! domi error: %u", dmrt); - return FAILED; + GELOGE(dmrt, "davinci_model: Get rootId fail! domi error: %u", dmrt); + return dmrt; } - ret = SetAddrs(op_desc, kernel_hccl_infos_); + + // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl + ret = SetFollowStream(op_desc_, davinci_model); if (ret != SUCCESS) { - GELOGE(ret, "Setaddrs Fail."); + GELOGE(ret, "SetStream Fail."); return ret; } - // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace - ret = SetWorkspace(op_desc, kernel_hccl_infos_); + + if (davinci_model_->IsKnownNode()) { + args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); + GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); + } + + ret = SetAddrs(op_desc_, kernel_hccl_infos_); if (ret != SUCCESS) { - GELOGE(ret, "SetWorkspace Fail."); + GELOGE(ret, "Setaddrs Fail."); return ret; } - // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl - ret = SetFollowStream(op_desc, davinci_model); + // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace + ret = SetWorkspace(op_desc_, kernel_hccl_infos_); if (ret != SUCCESS) { - GELOGE(ret, "SetStream Fail."); + GELOGE(ret, "SetWorkspace Fail."); return ret; } @@ -130,8 +138,8 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM uint32_t max_task_count; ret = rtGetMaxStreamAndTask(RT_NORMAL_STREAM, &max_stream_count, &max_task_count); if (ret != RT_ERROR_NONE) { - GELOGE(FAILED, "Get max stream and task count by rts failed."); - return FAILED; + GELOGE(RT_FAILED, "Get max stream and task count by rts failed."); + return RT_ERROR_TO_GE_STATUS(ret); } max_node_of_hccl_stream_ = max_task_count / kMaxTaskOfStream; } @@ -145,8 +153,8 @@ Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciM ReuseStream(created_stream_num, davinci_model); ret = CreateStream(hccl_stream_num - created_stream_num, davinci_model); if (ret != SUCCESS) { - GELOGE(FAILED, "Create hccl stream failed."); - return FAILED; + GELOGE(RT_FAILED, "Create hccl stream failed."); + return RT_ERROR_TO_GE_STATUS(ret); } } GELOGI("Initialize hccl slave stream success, hcclStreamNum =%ld", hccl_stream_num); @@ -171,14 +179,14 @@ Status HcclTaskInfo::CreateStream(int64_t stream_num, DavinciModel *davinci_mode rtStreamCreateWithFlags(&stream, davinci_model->Priority(), RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } // Create slave stream, inactive by default, activated by hccl rt_ret = rtModelBindStream(davinci_model->GetRtModelHandle(), stream, RT_MODEL_WAIT_ACTIVE_STREAM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); (void)rtStreamDestroy(stream); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGD("hccl_stream addr is=%p", stream); int64_t remain_cap = max_node_of_hccl_stream_ - 1; @@ -209,40 +217,82 @@ Status HcclTaskInfo::Distribute() { GELOGI("HcclTaskInfo Distribute Success."); return SUCCESS; } + +Status HcclTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto hccl_def = task_def.kernel_hccl(); + uint32_t op_index = hccl_def.op_index(); + GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); + // Get HCCL op + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + // Only need the number of addr to allocate args memory + auto input_size = op_desc->GetInputsSize(); + auto output_size = op_desc->GetOutputsSize(); + auto workspace_size = op_desc->GetWorkspaceBytes().size(); + uint32_t args_size = sizeof(void *) * (input_size + output_size + workspace_size); + args_offset_ = davinci_model->GetTotalArgsSize(); + davinci_model->SetTotalArgsSize(args_size); + GELOGI("Calculate hccl task args , args_size %u, args_offset %u", args_size, args_offset_); + return SUCCESS; +} + +Status HcclTaskInfo::UpdateArgs() { + GELOGI("HcclTaskInfo::UpdateArgs in."); + const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); + input_data_addrs_ = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); + output_data_addrs_ = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); + workspace_data_addrs_ = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); + + vector io_addrs; + io_addrs.insert(io_addrs.end(), input_data_addrs_.begin(), input_data_addrs_.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs_.begin(), output_data_addrs_.end()); + io_addrs.insert(io_addrs.end(), workspace_data_addrs_.begin(), workspace_data_addrs_.end()); + + davinci_model_->SetTotalIOAddrs(io_addrs); + + GELOGI("HcclTaskInfo::UpdateArgs success."); + return SUCCESS; +} + Status HcclTaskInfo::SetAddrs(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); - if (HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { - GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); - return PARAM_INVALID; - } + GE_CHK_STATUS_RET(HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos), + "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); GELOGI("Set hccl task input output address, node[%s}, type[%s] kernel_hccl_infos.size[%zu].", op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); if (op_desc->GetType() == HVDWAIT) { return SUCCESS; } - domi::Status dmrt; + hcclRedOp_t op_type = HCCL_REP_OP_SUM; GE_CHECK_NOTNULL(davinci_model_); GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); - auto input_data_addr_list = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - - auto output_data_addr_list = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + if (!davinci_model_->IsKnownNode()) { + input_data_addrs_ = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + output_data_addrs_ = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + } + void *input_data_addr = nullptr; + void *output_data_addr = nullptr; // initialize every kernel_hccl_info inputDataAddr for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { std::string hccl_type = kernel_hccl_infos[i].hccl_type; - void *input_data_addr = input_data_addr_list.empty() ? nullptr : input_data_addr_list[i]; + if (davinci_model_->IsKnownNode()) { + input_data_addr = reinterpret_cast(reinterpret_cast(args_) + i); + output_data_addr = reinterpret_cast(reinterpret_cast(args_) + op_desc->GetInputsSize() + i); + GELOGI("Hccl task info known input addr %p, output addr %p.", input_data_addr, output_data_addr); + } else { + input_data_addr = input_data_addrs_.empty() ? nullptr : input_data_addrs_[i]; + output_data_addr = output_data_addrs_.empty() ? nullptr : output_data_addrs_[i]; + } kernel_hccl_infos[i].inputDataAddr = input_data_addr; - - void *output_data_addr = output_data_addr_list.empty() ? nullptr : output_data_addr_list[i]; if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { kernel_hccl_infos[i].outputDataAddr = output_data_addr; } else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { - dmrt = HcomOmeUtil::GetHcclOperationType(op_desc, op_type); - if (dmrt != SUCCESS) { - GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); - return FAILED; - } + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), + "davinci_model: GetHcomOperationType fail!"); kernel_hccl_infos[i].outputDataAddr = output_data_addr; kernel_hccl_infos[i].opType = op_type; } @@ -310,6 +360,7 @@ void HcclTaskInfo::CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc) { Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(davinci_model_); GELOGI("SetWorkspace Node[%s] opType[%s] set workspace.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); uint64_t workspace_mem_size = 0; void *workspace_addr = nullptr; @@ -319,11 +370,12 @@ Status HcclTaskInfo::SetWorkspace(const std::shared_ptr &op_desc, GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); if (workspace_mem_size_tmp != 0) { workspace_mem_size = workspace_mem_size_tmp; - vector workspace_data_addrs = - ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); - if (!workspace_data_addrs.empty()) { - GELOGI("Get workSpaceAddr"); - workspace_addr = workspace_data_addrs[0]; + if (davinci_model_->IsKnownNode()) { + workspace_addr = reinterpret_cast(reinterpret_cast(args_) + op_desc->GetInputsSize() + + op_desc->GetOutputsSize()); + } else { + workspace_data_addrs_ = ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); + workspace_addr = workspace_data_addrs_.empty() ? nullptr : workspace_data_addrs_[0]; } } } diff --git a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h index bb0a88de..cc3109f4 100644 --- a/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/hccl_task_info.h @@ -34,7 +34,10 @@ class HcclTaskInfo : public TaskInfo { hccl_stream_list_(), ops_kernel_store_(nullptr), private_def_(nullptr), - private_def_len_(0) {} + private_def_len_(0), + op_desc_(nullptr), + args_(nullptr), + args_offset_(0) {} ~HcclTaskInfo() override; @@ -44,6 +47,10 @@ class HcclTaskInfo : public TaskInfo { uint32_t GetTaskID() override { return id_; } + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + Status UpdateArgs() override; + private: ge::Status SetAddrs(const std::string &hccl_type, const std::shared_ptr &op); @@ -72,6 +79,12 @@ class HcclTaskInfo : public TaskInfo { static std::mutex hccl_follow_stream_mutex_; static uint32_t max_node_of_hccl_stream_; vector kernel_hccl_infos_; + vector input_data_addrs_; + vector output_data_addrs_; + vector workspace_data_addrs_; + OpDescPtr op_desc_; + void *args_; + uint32_t args_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc index 79971529..4f72ec36 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.cc @@ -72,13 +72,16 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin auto rt_ret = rtMalloc(&ext_info_addr_, ext_info.size(), RT_MEMORY_HBM); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); - return FAILED;) + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtMemcpy(ext_info_addr_, ext_info.size(), ext_info.c_str(), ext_info.size(), RT_MEMCPY_HOST_TO_DEVICE); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); - return FAILED;) + return RT_ERROR_TO_GE_STATUS(rt_ret);) } + GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, ext_info_addr_=%p", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str(), ext_info.size(), ext_info_addr_); + // 2.1 get loop cond variable for tensor array write uint64_t step_id_addr = 0; OpDescPtr step_id_node = davinci_model_->GetVariableOp(NODE_NAME_GLOBAL_STEP); @@ -97,6 +100,11 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, GELOGE(FAILED, "CreateAicpuKernel error."); return FAILED;) + // 2.3 Create session + GE_CHECK_NOTNULL(ModelManager::GetInstance()); + GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, + GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); + return FAILED;) kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); if (davinci_model_->IsKnownNode()) { @@ -105,7 +113,8 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin static_cast(reinterpret_cast(input_output_addr)); void *workspace_base_addr = nullptr; rtError_t rt_ret = rtMalloc(&workspace_base_addr, kernel_ex_def.task_info_size(), RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error, ret: Ox%X", rt_ret); return FAILED;); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc error, ret: Ox%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);); rt_ret = rtMemcpy(workspace_base_addr, kernel_ex_def.task_info_size(), kernel_ex_def.task_info().data(), kernel_ex_def.task_info_size(), RT_MEMCPY_HOST_TO_DEVICE); fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = @@ -115,20 +124,23 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_); rt_ret = rtMalloc(&kernel_buf_, kernel_buf_size_, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtMemcpy(kernel_buf_, kernel_buf_size_, static_cast(&fwk_op_kernel), kernel_buf_size_, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) GELOGI("KernelExTaskInfo knonw node Init Success."); return SUCCESS; } // 3. Set workspaceaddr, inputOutputDataAddr - if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { - GELOGE(FAILED, "copy task info to workspace failed."); - return FAILED; + Status ge_ret = CopyTaskInfo(kernel_ex_def, rts_param, op_desc); + if (ge_ret != SUCCESS) { + GELOGE(ge_ret, "copy task info to workspace failed."); + return ge_ret; } const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); @@ -147,14 +159,15 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin auto addrs_size = sizeof(uint64_t) * (io_addrs.size()); if (addrs_size > 0) { rtError_t rt_ret = rtMalloc(&input_output_addr_, addrs_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); return RT_FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc error, ret: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtMemcpy(input_output_addr_, addrs_size, io_addrs.data(), addrs_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); - return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy to input_output_addr_ error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = input_output_addr_; } @@ -167,25 +180,17 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info.size(); fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_); - // 4. Create session - GE_CHECK_NOTNULL(ModelManager::GetInstance()); - GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, - GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); - return FAILED;) - // 5. Return result + // 4. Return result rtError_t rt_ret = rtMalloc(&kernel_buf_, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast(&fwk_op_kernel), sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy error, ret: Ox%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); + davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); GELOGI("KernelExTaskInfo Init Success. session id: %lu", session_id); return SUCCESS; @@ -207,22 +212,56 @@ Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciMod uint32_t mem_size = sizeof(uint64_t) * mem_length; davinci_model->SetTotalArgsSize(mem_size); GELOGI("kernel task name %s, args_size %u, args_offset %u", op_desc->GetName().c_str(), mem_size, args_offset_); + + // alloc fixed addr + string peer_input_name; + if (AttrUtils::GetStr(op_desc, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { + uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > outputs_size) { + GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", outputs_size, output_index); + return FAILED; + } + fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); + auto tensor_desc = op_desc->GetOutputDesc(output_index); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); + GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr offset %ld", tensor_size, + fixed_addr_offset_); + } return SUCCESS; } Status KernelExTaskInfo::UpdateArgs() { GELOGI("KernelExTaskInfo::UpdateArgs in."); const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); - vector io_addrs; vector input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc_); vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc_); - - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - - GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), - "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); - + vector io_addrs; + if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + } else { + string peer_input_name; + if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { + uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > output_data_addrs.size()) { + GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", + output_data_addrs.size(), output_index); + return FAILED; + } + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + for (size_t i = 0; i < output_data_addrs.size(); ++i) { + if (i == output_index) { + void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); + io_addrs.emplace_back(fixed_addr); + continue; + } + io_addrs.emplace_back(output_data_addrs[i]); + } + } + } + davinci_model_->SetTotalIOAddrs(io_addrs); GELOGI("KernelExTaskInfo::UpdateArgs success."); return SUCCESS; } @@ -231,7 +270,7 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const const OpDescPtr &op_desc) { // Userspace copy need virtual address. const vector workspace_data_sizes = ModelUtils::GetWorkspaceSize(op_desc); - const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc, false); + const vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); if (workspace_data_addrs.empty() || workspace_data_sizes.empty()) { GELOGE(FAILED, "Node:%s invalid workspace, addrs is %zu, size is %zu.", op_desc->GetName().c_str(), workspace_data_addrs.size(), workspace_data_sizes.size()); @@ -252,8 +291,8 @@ Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const rtError_t rt_ret = rtMemcpy(workspace_data_addrs[0], kernel_def.task_info_size(), kernel_def.task_info().data(), kernel_def.task_info_size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { - GELOGE(FAILED, "rtMemcpy error: 0x%X", rt_ret); - return FAILED; + GELOGE(RT_FAILED, "rtMemcpy error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; @@ -264,7 +303,7 @@ Status KernelExTaskInfo::Distribute() { rtError_t rt_ret = rtKernelLaunchEx(kernel_buf_, kernel_buf_size_, dump_flag_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } if (davinci_model_ == nullptr) { @@ -277,7 +316,7 @@ Status KernelExTaskInfo::Distribute() { rt_ret = rtModelGetTaskId(davinci_model_->GetRtModelHandle(), &task_id, &stream_id); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } task_id_ = task_id; stream_id_ = stream_id; @@ -292,7 +331,7 @@ Status KernelExTaskInfo::Release() { rtError_t rt_ret = rtFree(kernel_buf_); if (rt_ret != RT_ERROR_NONE) { GELOGW("rtFree error, ret: 0x%X", rt_ret); - ret = FAILED; + ret = RT_ERROR_TO_GE_STATUS(rt_ret); } else { kernel_buf_ = nullptr; } @@ -301,7 +340,7 @@ Status KernelExTaskInfo::Release() { rtError_t rt_ret = rtFree(input_output_addr_); if (rt_ret != RT_ERROR_NONE) { GELOGW("rtFree error, ret: 0x%X", rt_ret); - ret = FAILED; + ret = RT_ERROR_TO_GE_STATUS(rt_ret); } else { input_output_addr_ = nullptr; } @@ -310,7 +349,7 @@ Status KernelExTaskInfo::Release() { rtError_t rt_ret = rtFree(ext_info_addr_); if (rt_ret != RT_ERROR_NONE) { GELOGW("rtFree ext_info_addr[%p] error, ret: 0x%X", ext_info_addr_, rt_ret); - ret = FAILED; + ret = RT_ERROR_TO_GE_STATUS(rt_ret); } else { ext_info_addr_ = nullptr; } diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h index ff8f3119..b26a95ac 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_ex_task_info.h @@ -54,6 +54,7 @@ class KernelExTaskInfo : public TaskInfo { auto ret = reinterpret_cast(dump_args_); return ret; } + bool CallSaveDumpInfo() override { return true; }; private: Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); @@ -69,6 +70,7 @@ class KernelExTaskInfo : public TaskInfo { void *dump_args_; OpDescPtr op_desc_ = nullptr; uint32_t args_offset_ = 0; + int64_t fixed_addr_offset_ = 0; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc index 7ef65555..da6d05ca 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.cc @@ -47,16 +47,16 @@ const uint32_t kAddrLen = sizeof(void *); namespace ge { KernelTaskInfo::SuperKernelTaskInfo KernelTaskInfo::skt_info_ = { - 0, 0, 0, 0, nullptr, nullptr, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; + 0, 0, 0, 0, nullptr, nullptr, {}, {}, {}, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); + GELOGE(PARAM_INVALID, "davinci model is null!"); return PARAM_INVALID; } davinci_model_ = davinci_model; is_l1_fusion_enable_ = davinci_model_->GetL1FusionEnableOption(); - GELOGD("KernelTaskInfo Init Start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); + GELOGD("KernelTaskInfo init start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); Status ret = SetStream(task_def.stream_id(), davinci_model_->GetStreamList()); if (ret != SUCCESS) { @@ -73,7 +73,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci // get opdesc op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); if (op_desc_ == nullptr) { - GELOGE(INTERNAL_ERROR, "Get op_desc failed, index is out of range!"); + GELOGE(INTERNAL_ERROR, "Get op desc failed, index is out of range!"); return INTERNAL_ERROR; } (void)AttrUtils::GetBool(*op_desc_, ATTR_N_BATCH_SPILT, is_n_batch_spilt_); @@ -99,13 +99,13 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci rt_ret = rtGetFunctionByName(const_cast(kernel_def.stub_func().c_str()), &stub_func_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", kernel_def.stub_func().c_str()); - return RT_FAILED;); + return RT_ERROR_TO_GE_STATUS(rt_ret);); } else if (kernel_type_ != cce::ccKernelType::AI_CPU) { rtError_t rt_ret; rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. bin_file_key: %s", bin_file_key); - return RT_FAILED;); + return RT_ERROR_TO_GE_STATUS(rt_ret);); } if (context.origin_op_index_size() > CC_FUSION_OP_MAX) { @@ -138,14 +138,21 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci ret = InitCceTask(kernel_def); } - GELOGD("KernelTaskInfo Init finish, result=%u.", ret); + GELOGD("KernelTaskInfo init finish, result=%u.", ret); return ret; } Status KernelTaskInfo::SaveSKTDumpInfo() { GE_CHECK_NOTNULL(davinci_model_); - davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_stream_id, skt_info_.last_op, - skt_info_.last_dump_args); + if (skt_dump_flag_ == RT_KERNEL_DEFAULT) { + GELOGD("no need save skt dump info"); + return SUCCESS; + } + // all op in super kernel share one taskid and streamid + for (size_t i = 0; i < skt_info_.op_desc_list.size(); i++) { + davinci_model_->SaveDumpTask(skt_info_.last_task_id, skt_info_.last_stream_id, skt_info_.op_desc_list[i], + skt_info_.dump_args_list[i]); + } return SUCCESS; } @@ -187,6 +194,9 @@ Status KernelTaskInfo::SKTFinalize() { GELOGI("SuperKernel Distribute [skt_id:%u]", skt_id_); skt_info_.kernel_list.clear(); skt_info_.arg_list.clear(); + skt_info_.dump_flag_list.clear(); + skt_info_.op_desc_list.clear(); + skt_info_.dump_args_list.clear(); skt_info_.last_stream = nullptr; skt_info_.last_block_dim = 0; skt_info_.last_sm_desc = sm_desc_; @@ -197,6 +207,15 @@ Status KernelTaskInfo::SKTFinalize() { return SUCCESS; } +uint32_t KernelTaskInfo::GetDumpFlag() { + for (auto flag : skt_info_.dump_flag_list) { + if (flag == RT_KERNEL_DUMPFLAG) { + return RT_KERNEL_DUMPFLAG; + } + } + return RT_KERNEL_DEFAULT; +} + Status KernelTaskInfo::SuperKernelLaunch() { if (skt_info_.kernel_list.empty()) { GELOGI("SuperKernelLaunch: Skt_kernel_list has no task, just return"); @@ -206,38 +225,46 @@ Status KernelTaskInfo::SuperKernelLaunch() { auto &skt_kernel_list = skt_info_.kernel_list; auto &skt_arg_list = skt_info_.arg_list; GELOGI("SuperKernelLaunch: Skt_kernel_list size[%d] skt_arg_list[%d]", skt_kernel_list.size(), skt_arg_list.size()); - if (skt_kernel_list.size() == kSKTSingleSize) { + if (skt_kernel_list.size() == kSKTSingleSize && skt_arg_list.size() == kSKTSingleSize) { rt_ret = rtKernelLaunchWithFlag(skt_info_.kernel_list[0], static_cast(skt_info_.last_block_dim), skt_info_.arg_list[0], skt_info_.last_args_size, static_cast(skt_info_.last_sm_desc), skt_info_.last_stream, skt_info_.last_dump_flag); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "SuperKernelLaunch: Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } + call_save_dump_ = true; GE_CHK_STATUS_RET(SKTFinalize(), "Skt finalize failed"); return SUCCESS; } // Create super kernel factory skt::SuperKernelFactory *factory = &skt::SuperKernelFactory::GetInstance(); // Init super kernel factory - if (factory->Init() != SUCCESS) { - GELOGE(RT_FAILED, "SuperKernelLaunch: SuperKernelFactory init failed"); - return RT_FAILED; + Status ge_ret = factory->Init(); + if (ge_ret != SUCCESS) { + GELOGE(ge_ret, "SuperKernelLaunch: SuperKernelFactory init failed"); + return ge_ret; } // Call the fuse API - skt::SuperKernel *superKernel = nullptr; - if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { - GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); - return RT_FAILED; + std::unique_ptr superKernel = nullptr; + ge_ret = factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel); + if (ge_ret != SUCCESS) { + GELOGE(ge_ret, "SuperKernelLaunch: fuse call failed"); + return ge_ret; } // Launch a super kernel - if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { - GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); - return RT_FAILED; + skt_dump_flag_ = GetDumpFlag(); + ge_ret = superKernel->Launch(skt_info_.last_stream, skt_dump_flag_); + if (ge_ret != SUCCESS) { + GELOGE(ge_ret, "SuperKernelLaunch: launch failed"); + return ge_ret; } GELOGI("SuperKernelLaunch: success[skt_kernel_list size[%zu] skt_arg_list[%zu]]", skt_kernel_list.size(), skt_arg_list.size()); + // record skt addr for release + superkernel_dev_nav_table_ = superKernel->GetNavTablePtr(); + superkernel_device_args_addr_ = superKernel->GetDeviceArgsPtr(); GE_CHK_STATUS_RET(SKTFinalize(), "Skt finalize failed"); return SUCCESS; } @@ -250,8 +277,11 @@ Status KernelTaskInfo::SaveSuperKernelInfo() { skt_info_.last_args_size = args_size_; skt_info_.last_sm_desc = sm_desc_; skt_info_.last_dump_flag = dump_flag_; + skt_info_.dump_flag_list.push_back(dump_flag_); + skt_info_.op_desc_list.push_back(op_desc_); + skt_info_.dump_args_list.push_back(reinterpret_cast(skt_dump_args_)); skt_info_.last_group_key = group_key_; - skt_info_.last_dump_args = reinterpret_cast(dump_args_); + skt_info_.last_dump_args = reinterpret_cast(skt_dump_args_); skt_info_.last_op = op_desc_; // last node in a stream, just launch if (IsMarkedLastNode()) { @@ -318,23 +348,24 @@ Status KernelTaskInfo::SuperKernelDistribute() { // 1.launch before ret = SuperKernelLaunch(); if (ret != SUCCESS) { - GELOGE(FAILED, "Call SuperKernelLaunch failed!"); - return FAILED; + GELOGE(ret, "Call SuperKernelLaunch failed!"); + return ret; } // 2.launch current rtError_t rt_ret = rtKernelLaunchWithFlag(stub_func_, block_dim_, args_, args_size_, static_cast(sm_desc_), stream_, dump_flag_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return FAILED; + return rt_ret; } + call_save_dump_ = true; UpdateTaskId(); GELOGI("Current Common Task Distribute [taskid:%u]", task_id_); } else { ret = SaveSuperKernelInfo(); if (ret != SUCCESS) { - GELOGE(FAILED, "Call SuperKernelLaunch failed!"); - return FAILED; + GELOGE(ret, "Call SuperKernelLaunch failed!"); + return ret; } GELOGI("Save Current task [block_dim:%u, size:%zu].", block_dim_, skt_info_.kernel_list.size()); } @@ -356,6 +387,7 @@ Status KernelTaskInfo::Distribute() { rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast(so_name_.c_str()), reinterpret_cast(kernel_name_.c_str()), 1, args_, args_size_, nullptr, stream_, dump_flag_); + call_save_dump_ = true; } else { /* default: not skt launch */ GELOGI( @@ -369,11 +401,12 @@ Status KernelTaskInfo::Distribute() { // call rtKernelLaunch for current task rt_ret = rtKernelLaunchWithFlag(stub_func_, block_dim_, args_, args_size_, static_cast(sm_desc_), stream_, dump_flag_); + call_save_dump_ = true; } } if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } // set for task_id_ UpdateTaskId(); @@ -392,13 +425,33 @@ Status KernelTaskInfo::UpdateArgs() { vector workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc_); vector io_addrs; - io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); - io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); - io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); - - GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(io_addrs, args_offset_), - "update known node %s zero copy addr failed.", op_desc_->GetName().c_str()); + if (!op_desc_->HasAttr(ATTR_DYNAMIC_SHAPE_FIXED_ADDR)) { + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + io_addrs.insert(io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); + io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + } else { + string peer_input_name; + if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name)) { + uint32_t output_index = davinci_model_->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > output_data_addrs.size()) { + GELOGE(FAILED, "The output data addr size[%zu] and output index[%u] are inconsistent.", + output_data_addrs.size(), output_index); + return FAILED; + } + io_addrs.insert(io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + for (size_t i = 0; i < output_data_addrs.size(); ++i) { + if (i == output_index) { + void *fixed_addr = davinci_model_->GetCurrentFixedAddr(fixed_addr_offset_); + io_addrs.emplace_back(fixed_addr); + continue; + } + io_addrs.emplace_back(output_data_addrs[i]); + } + io_addrs.insert(io_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); + } + } + davinci_model_->SetTotalIOAddrs(io_addrs); GELOGI("KernelTaskInfo::UpdateArgs success."); return SUCCESS; } @@ -407,24 +460,31 @@ Status KernelTaskInfo::Release() { if (davinci_model_ != nullptr && davinci_model_->IsKnownNode()) { return SUCCESS; } - FreeRtMem(&args_); - FreeRtMem(&flowtable_); - FreeRtMem(&custom_info_.input_descs); - FreeRtMem(&custom_info_.input_addrs); - FreeRtMem(&custom_info_.output_descs); - FreeRtMem(&custom_info_.output_addrs); - FreeRtMem(&custom_info_.attr_handle); - FreeRtMem(&aicpu_ext_info_addr_); + rtContext_t ctx = nullptr; + rtError_t ret = rtCtxGetCurrent(&ctx); + + if (ret == RT_ERROR_NONE) { + FreeRtMem(&args_); + FreeRtMem(&superkernel_device_args_addr_); + FreeRtMem(&superkernel_dev_nav_table_); + FreeRtMem(&flowtable_); + FreeRtMem(&custom_info_.input_descs); + FreeRtMem(&custom_info_.input_addrs); + FreeRtMem(&custom_info_.output_descs); + FreeRtMem(&custom_info_.output_addrs); + FreeRtMem(&custom_info_.attr_handle); + FreeRtMem(&aicpu_ext_info_addr_); + } if (ctx_.argsOffset != nullptr) { delete[] ctx_.argsOffset; ctx_.argsOffset = nullptr; } - rtError_t ret = (sm_desc_ != nullptr) ? rtMemFreeManaged(sm_desc_) : RT_ERROR_NONE; + ret = (sm_desc_ != nullptr) ? rtMemFreeManaged(sm_desc_) : RT_ERROR_NONE; if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", static_cast(ret)); - return FAILED; + return RT_ERROR_TO_GE_STATUS(ret); } sm_desc_ = nullptr; @@ -454,13 +514,13 @@ Status KernelTaskInfo::UpdateL2Data(const domi::KernelDef &kernel_def) { rtError_t rt_ret = rtMemAllocManaged(&sm_desc_, sm_desc.size(), RT_MEMORY_SPM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtMemcpy(sm_desc_, sm_desc.size(), sm_desc.data(), sm_desc.size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; @@ -472,6 +532,29 @@ Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel args_offset_ = davinci_model->GetTotalArgsSize(); davinci_model->SetTotalArgsSize(args_size); GELOGI("kernel task name , args_size %u, args_offset %u", args_size, args_offset_); + + // get opcontext stored in model + const domi::KernelContext &context = kernel_def.context(); + // get opdesc + op_desc_ = davinci_model->GetOpByIndex(context.op_index()); + GE_CHECK_NOTNULL(op_desc_); + // alloc fixed addr + string peer_input_name; + if (AttrUtils::GetStr(op_desc_, ATTR_DYNAMIC_SHAPE_FIXED_ADDR, peer_input_name) && !peer_input_name.empty()) { + uint32_t output_index = davinci_model->GetFixedAddrOutputIndex(peer_input_name); + if (output_index > op_desc_->GetOutputsSize()) { + GELOGE(FAILED, "The output size[%zu] and output index[%u] are inconsistent.", op_desc_->GetOutputsSize(), + output_index); + return FAILED; + } + fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(peer_input_name); + auto tensor_desc = op_desc_->GetOutputDesc(output_index); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(peer_input_name, tensor_size); + GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr offset %ld", tensor_size, + fixed_addr_offset_); + } return SUCCESS; } @@ -514,14 +597,14 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } // copy orign args rt_ret = rtMemcpy(args_, args_size_, kernel_def.args().data(), args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } vector args_info(args_size_); errno_t sec_ret = memcpy_s(args_info.data(), args_size_, kernel_def.args().data(), args_size_); @@ -540,7 +623,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne kAddrLen * tensor_device_addrs.size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } sec_ret = memcpy_s(args_info.data() + offset, args_size_ - offset, tensor_device_addrs.data(), kAddrLen * tensor_device_addrs.size()); @@ -548,23 +631,22 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne GELOGE(FAILED, "memcpy failed, ret: %d", sec_ret); return FAILED; } - - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + skt_dump_args_ = static_cast(args_) + offset; + if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = static_cast(args_) + offset; } + Status ge_ret = UpdateL2Data(kernel_def); // update origin l2 data - if (UpdateL2Data(kernel_def) != SUCCESS) { - return RT_FAILED; + if (ge_ret != SUCCESS) { + return ge_ret; } vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); + virtual_io_addrs.insert(virtual_io_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_info.data(), args_, args_size_, offset); GELOGD("Do InitTVMTask end"); @@ -602,7 +684,6 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel const std::vector output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), ModelUtils::GetOutputDescs(op_desc)); - if (ret != SUCCESS) { GELOGE(ret, "StoreInputOutputTensor Failed"); return ret; @@ -624,13 +705,13 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel rtError_t rt_ret = rtMalloc(&custom_info_.attr_handle, op_attr_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtMemcpy(custom_info_.attr_handle, op_attr_size, buffer.GetData(), op_attr_size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } // args @@ -657,21 +738,19 @@ Status KernelTaskInfo::InitAICPUCustomTask(uint32_t op_index, const domi::Kernel rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtMemcpy(args_, kernel_def.args_size(), kernel_def.args().data(), kernel_def.args_size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, input_data_addrs.data(), custom_info_.input_addrs, - virtual_in_addrs.size() * kAddrLen, 0); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, output_data_addrs.data(), custom_info_.output_addrs, + davinci_model_->SetZeroCopyAddr(op_desc, input_data_addrs, input_data_addrs.data(), custom_info_.input_addrs, + input_data_addrs.size() * kAddrLen, 0); + davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, output_data_addrs.data(), custom_info_.output_addrs, output_data_addrs.size() * kAddrLen, 0); return SUCCESS; } @@ -712,7 +791,8 @@ Status KernelTaskInfo::InitCceTask(const domi::KernelDef &kernel_def) { ctx_.genVariableBaseSize = davinci_model_->TotalVarMemSize(); ctx_.l2ctrlSize = sm_contrl_size; - if (UpdateCceArgs(sm_desc, flowtable, kernel_def) != SUCCESS) { + ret = UpdateCceArgs(sm_desc, flowtable, kernel_def); + if (ret != SUCCESS) { GELOGE(ret, "update cce args fail"); return ret; } @@ -728,7 +808,7 @@ Status KernelTaskInfo::InitCceTask(const domi::KernelDef &kernel_def) { rtError_t rt_ret = rtMalloc(&args_, kernel_def.args_size(), RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "cce task physical memory.", kernel_def.args_size()) @@ -736,7 +816,7 @@ Status KernelTaskInfo::InitCceTask(const domi::KernelDef &kernel_def) { rtMemcpy(args_, kernel_def.args_size(), kernel_def.args().data(), kernel_def.args_size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } // L2 @@ -744,13 +824,13 @@ Status KernelTaskInfo::InitCceTask(const domi::KernelDef &kernel_def) { rt_ret = rtMemAllocManaged(&sm_desc_, sm_desc.size(), RT_MEMORY_SPM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtMemcpy(sm_desc_, sm_desc.size(), sm_desc.data(), sm_desc.size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } return SUCCESS; @@ -801,6 +881,9 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k GELOGE(init_ret, "Init aicpu task ext info failed, ext_info size=%zu", ext_info.size()); return init_ret; } + GELOGI("Node[%s] type[%s] kernel_ext_info size=%zu, aicpu_ext_info_addr_=%p", op_desc_->GetName().c_str(), + op_desc_->GetType().c_str(), ext_info.size(), aicpu_ext_info_addr_); + aicpu_param_head->extInfoAddr = reinterpret_cast(aicpu_ext_info_addr_); aicpu_param_head->extInfoLength = reinterpret_cast(ext_info.size()); @@ -808,7 +891,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k rtError_t rt_ret = rtMalloc(static_cast(&args_), args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api(rtMalloc) failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "cce task physical memory.", args_size_) @@ -816,22 +899,16 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k rt_ret = rtMemcpy(args_, args_size_, args_addr.get(), args_size_, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api(rtMemcpy) failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } - if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), - op_desc->GetName())) { + if (davinci_model_->GetDumpProperties().IsLayerNeedDump(davinci_model_->Name(), davinci_model_->OmName(), + op_desc->GetName())) { dump_flag_ = RT_KERNEL_DUMPFLAG; dump_args_ = static_cast(args_) + sizeof(aicpu::AicpuParamHead); } - vector virtual_io_addrs; // use virtual address for zero copy key. - const vector virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); - const vector virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); - virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); - davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, args_addr.get(), args_, args_size_, - sizeof(aicpu::AicpuParamHead)); + davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, args_addr.get(), args_, args_size_, sizeof(aicpu::AicpuParamHead)); return SUCCESS; } @@ -843,12 +920,12 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { auto rt_ret = rtMalloc(&aicpu_ext_info_addr_, ext_info.size(), RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMalloc ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtMemcpy(aicpu_ext_info_addr_, ext_info.size(), ext_info.c_str(), ext_info.size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpy ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); - return FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; @@ -865,7 +942,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d rtError_t rt_ret = rtMalloc(&custom_info_.input_descs, sizeof(opTensor_t) * input_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } for (std::size_t i = 0; i < input_size; ++i) { @@ -873,7 +950,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d const_cast(&input_descs[i]), sizeof(opTensor_t), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -881,7 +958,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d rt_ret = rtMalloc(&custom_info_.input_addrs, sizeof(opTensor_t) * input_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } if (!input_data_addrs.empty()) { @@ -889,7 +966,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -897,14 +974,14 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d rt_ret = rtMalloc(&custom_info_.output_descs, sizeof(opTensor_t) * output_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } for (std::size_t i = 0; i < output_size; ++i) { rt_ret = rtMemcpy(static_cast(custom_info_.output_descs) + i, sizeof(opTensor_t), const_cast(&input_descs[i]), sizeof(opTensor_t), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -912,7 +989,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d rt_ret = rtMalloc(&custom_info_.output_addrs, sizeof(opTensor_t) * output_size, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } if (!output_data_addrs.empty()) { @@ -920,7 +997,7 @@ Status KernelTaskInfo::StoreInputOutputTensor(const std::vector &input_d RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } } @@ -982,8 +1059,8 @@ Status KernelTaskInfo::UpdateCceArgs(std::string &sm_desc, std::string &flowtabl Status status = CceUpdateKernelArgs(context, data_base_addr, weight_base_addr, var_base_addr, sm_desc, flowtable, kernel_def); if (status != SUCCESS) { - GELOGE(FAILED, "Call cce api failed"); - return FAILED; + GELOGE(status, "Call cce api failed"); + return status; } return SUCCESS; } @@ -1049,14 +1126,14 @@ Status KernelTaskInfo::SetFlowtable(std::string &flowtable, const domi::KernelDe rtError_t rt_ret = rtMalloc(&flowtable_, flowtable.size(), RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "flowtable refresh of cce scence.", flowtable.size()) rt_ret = rtMemcpy(flowtable_, flowtable.size(), flowtable.data(), flowtable.size(), RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } // modify flowtable addr in args diff --git a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h index 41ed5728..cc8edc07 100644 --- a/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/kernel_task_info.h @@ -61,6 +61,8 @@ class KernelTaskInfo : public TaskInfo { sm_desc_ = nullptr; flowtable_ = nullptr; args_ = nullptr; + superkernel_device_args_addr_ = nullptr; + superkernel_dev_nav_table_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -88,6 +90,8 @@ class KernelTaskInfo : public TaskInfo { uint32_t GetSktTaskID() override { return skt_id_; } + bool CallSaveDumpInfo() override { return call_save_dump_; }; + cce::ccOpContext ctx_; FusionOpInfo fusion_op_info_; @@ -130,6 +134,7 @@ class KernelTaskInfo : public TaskInfo { void UpdateSKTTaskId(); Status SKTFinalize(); Status SuperKernelLaunch(); + uint32_t GetDumpFlag(); Status SaveSuperKernelInfo(); bool IsMarkedLastNode(); bool IsMarkedFirstNode(); @@ -153,17 +158,23 @@ class KernelTaskInfo : public TaskInfo { OpDescPtr op_desc_; DavinciModel *davinci_model_; uint32_t args_offset_ = 0; + int64_t fixed_addr_offset_ = 0; + bool call_save_dump_ = false; // aicpu ext_info device mem void *aicpu_ext_info_addr_ = nullptr; // For super kernel + void *skt_dump_args_ = nullptr; uint32_t skt_id_; std::string stub_func_name_; bool is_l1_fusion_enable_; bool is_n_batch_spilt_; int64_t group_key_; bool has_group_key_; + uint32_t skt_dump_flag_ = RT_KERNEL_DEFAULT; + void *superkernel_device_args_addr_ = nullptr; + void *superkernel_dev_nav_table_ = nullptr; struct AICPUCustomInfo { void *input_descs = nullptr; @@ -183,6 +194,9 @@ class KernelTaskInfo : public TaskInfo { void *last_sm_desc; std::vector kernel_list; std::vector arg_list; + std::vector dump_flag_list; + std::vector op_desc_list; + std::vector dump_args_list; uint32_t last_dump_flag; int64_t last_group_key; uintptr_t last_dump_args; diff --git a/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc index c157b1df..75f6c121 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/label_goto_ex_task_info.cc @@ -59,7 +59,7 @@ Status LabelGotoExTaskInfo::Distribute() { rtError_t rt_ret = rtLabelGotoEx(label_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("LabelGotoExTaskInfo Distribute Success."); diff --git a/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc index e8888eef..de6a1d65 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/label_set_task_info.cc @@ -59,7 +59,7 @@ Status LabelSetTaskInfo::Distribute() { rtError_t rt_ret = rtLabelSet(label_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("LabelSetTaskInfo Distribute Success."); diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc index 818307eb..efefd3e2 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.cc @@ -16,8 +16,8 @@ #include "graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h" -#include "graph/load/new_model_manager/davinci_model.h" #include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/davinci_model.h" namespace ge { constexpr uint8_t kLabelSwitchIndexNum = 1; @@ -59,7 +59,13 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo op_desc->GetName().c_str(), input_data_addr.size(), kLabelSwitchIndexNum); return INTERNAL_ERROR; } - index_value_ = input_data_addr[0]; + + if (davinci_model->IsKnownNode()) { + index_value_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_); + } else { + index_value_ = input_data_addr[0]; + } + davinci_model->DisableZeroCopy(index_value_); std::vector label_idx_list; @@ -92,13 +98,13 @@ Status LabelSwitchByIndexTaskInfo::Init(const domi::TaskDef &task_def, DavinciMo rtError_t rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } rt_ret = rtLabelListCpy(label_list_.data(), label_list_.size(), args_, args_size_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("LabelSwitchByIndexTaskInfo Init success, branch max: %u.", branch_max_); @@ -124,5 +130,28 @@ Status LabelSwitchByIndexTaskInfo::Distribute() { return SUCCESS; } +Status LabelSwitchByIndexTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto label_switch = task_def.label_switch_by_index(); + uint32_t op_index = label_switch.op_index(); + GELOGI("Begin to calculate args, op_index is: %u", op_index); + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (op_desc->GetInputsSize() != kLabelSwitchIndexNum) { + GELOGE(FAILED, "Label switch op only have one data input. Now input size is %zu", op_desc->GetInputsSize()); + return FAILED; + } + string input_tensor_name = op_desc->GetInputNameByIndex(0); + fixed_addr_offset_ = davinci_model->GetFixedAddrsSize(input_tensor_name); + auto tensor_desc = op_desc->GetInputDesc(0); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); + GELOGI("Calculate stream switchn task args , tensor_size %ld, fixed_addr_offset %ld", tensor_size, + fixed_addr_offset_); + return SUCCESS; +} + REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_LABEL_SWITCH_BY_INDEX, LabelSwitchByIndexTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h index 1a644736..4cb39c95 100644 --- a/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/label_switch_by_index_task_info.h @@ -22,7 +22,8 @@ namespace ge { class LabelSwitchByIndexTaskInfo : public TaskInfo { public: - LabelSwitchByIndexTaskInfo() : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0) {} + LabelSwitchByIndexTaskInfo() + : index_value_(nullptr), branch_max_(0), args_(nullptr), args_size_(0), fixed_addr_offset_(0) {} ~LabelSwitchByIndexTaskInfo() override; @@ -30,13 +31,15 @@ class LabelSwitchByIndexTaskInfo : public TaskInfo { Status Distribute() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + private: void *index_value_; // switch index input. uint32_t branch_max_; // max branch count. void *args_; // label info memory. uint32_t args_size_; // label info length. - std::vector label_list_; + int64_t fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_LABEL_SWITCH_BY_INDEX_TASK_INFO_H_ \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc index e9d99189..8cac9f82 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc @@ -19,11 +19,15 @@ #include "framework/common/debug/ge_log.h" #include "graph/load/new_model_manager/davinci_model.h" +namespace { +const uint32_t kAlignBytes = 64; +} + namespace ge { Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); + GELOGI("MemcpyAddrAsyncTaskInfo Init Start"); if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); + GELOGE(PARAM_INVALID, "davinci_model is null"); return PARAM_INVALID; } @@ -32,120 +36,67 @@ Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel return ret; } - auto memcpy_async_def = task_def.memcpy_async(); - uint32_t op_index = memcpy_async_def.op_index(); - OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); + const auto &memcpy_async = task_def.memcpy_async(); + OpDescPtr op_desc = davinci_model->GetOpByIndex(memcpy_async.op_index()); if (op_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "Init MemcpyAddrAsyncTaskInfo error, index is out of range!"); + GELOGE(INTERNAL_ERROR, "Task op index:%u out of range", memcpy_async.op_index()); return INTERNAL_ERROR; } - uint64_t logic_dst = memcpy_async_def.dst(); - uint64_t logic_src = memcpy_async_def.src(); - - dst_max_ = memcpy_async_def.dst_max(); - - uint64_t update_base_addr = 0; - ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.src(), src_); if (ret != SUCCESS) { return ret; } - src_ = reinterpret_cast(update_base_addr + logic_src); - if (src_ == nullptr) { - GELOGE(PARAM_INVALID, "src_ is null!"); - return PARAM_INVALID; - } - uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); - uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); - dst_ = reinterpret_cast(reinterpret_cast(mem_base + (logic_dst - logic_mem_base))); - if (dst_ == nullptr) { - GELOGE(PARAM_INVALID, "dst_ is null!"); - return PARAM_INVALID; + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; } vector io_addrs; io_addrs.emplace_back(src_); io_addrs.emplace_back(dst_); - count_ = memcpy_async_def.count(); - kind_ = memcpy_async_def.kind(); - // malloc args memory size_t args_size = sizeof(void *) * io_addrs.size(); - rtError_t rt_ret = rtMalloc(&args_, args_size, RT_MEMORY_HBM); + rtError_t rt_ret = rtMalloc(&args_, args_size + kAlignBytes, RT_MEMORY_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } + args_align_ = reinterpret_cast((reinterpret_cast(args_) / kAlignBytes + 1) * kAlignBytes); // copy orign src/dst - GELOGI("src_args:%p, destMax:%zu, src_:%p, dst_args:%p, dst_:%p, count=%zu", args_, args_size, src_, - static_cast(args_) + args_size, dst_, io_addrs.size()); - rt_ret = rtMemcpy(args_, args_size, io_addrs.data(), args_size, RT_MEMCPY_HOST_TO_DEVICE); + GELOGI("src_args:%p, destMax:%zu, src_:%p, dst_args:%p, dst_:%p, count=%zu", args_align_, args_size, src_, + static_cast(args_align_) + args_size, dst_, io_addrs.size()); + rt_ret = rtMemcpy(args_align_, args_size, io_addrs.data(), args_size, RT_MEMCPY_HOST_TO_DEVICE); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api for src failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } - // Just dest addr need zero copy. - davinci_model->SetZeroCopyAddr(op_desc, {dst_}, io_addrs.data(), args_, args_size, sizeof(void *)); - - GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", - reinterpret_cast(reinterpret_cast(logic_src)), - reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_, args_, - reinterpret_cast(reinterpret_cast(args_) + args_size)); + count_ = memcpy_async.count(); + kind_ = memcpy_async.kind(); + dst_max_ = memcpy_async.dst_max(); + GELOGI("InitMemcpyAddrAsyncTaskInfo, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu, args:%p, size:%zu", + memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_, args_align_, args_size); + davinci_model->SetZeroCopyAddr(op_desc, io_addrs, io_addrs.data(), args_align_, args_size, 0); return SUCCESS; } Status MemcpyAddrAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); - GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); + GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start, dst_max:%lu, count:%lu, kind:%u", dst_max_, count_, kind_); - rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast(reinterpret_cast(args_) + sizeof(void *)), - dst_max_, args_, count_, static_cast(kind_), stream_); + rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast(reinterpret_cast(args_align_) + sizeof(void *)), + dst_max_, args_align_, count_, static_cast(kind_), stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } return SUCCESS; } -Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, - uint64_t &base_addr) { - GE_CHECK_NOTNULL(davinci_model); - uint64_t data_base_addr = - static_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); - uint64_t weight_base_addr = static_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - - davinci_model->GetRtWeightAddr(); - uint64_t var_base_addr = - static_cast(reinterpret_cast(davinci_model->VarMemBase())) - davinci_model->GetRtVarAddr(); - - uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); - uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); - uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); - uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); - uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); - uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); - - if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { - base_addr = data_base_addr; - GELOGI("The update_addr is data address."); - } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { - base_addr = weight_base_addr; - GELOGI("The update_addr is weight address."); - } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { - base_addr = var_base_addr; - GELOGI("The update_addr is variable address."); - } else if (update_addr != 0) { - base_addr = 0; - GELOGE(PARAM_INVALID, "The update_addr is abnormal."); - return PARAM_INVALID; - } - return SUCCESS; -} - REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h index 9252e43a..90aad9b7 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h @@ -16,12 +16,14 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ + #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { class MemcpyAddrAsyncTaskInfo : public TaskInfo { public: - MemcpyAddrAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), args_(nullptr), count_(0), kind_(0) {} + MemcpyAddrAsyncTaskInfo() + : dst_(nullptr), dst_max_(0), src_(nullptr), args_(nullptr), args_align_(nullptr), count_(0), kind_(0) {} ~MemcpyAddrAsyncTaskInfo() override { src_ = nullptr; @@ -32,9 +34,8 @@ class MemcpyAddrAsyncTaskInfo : public TaskInfo { if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); } + args_ = nullptr; } - - args_ = nullptr; } Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; @@ -42,12 +43,11 @@ class MemcpyAddrAsyncTaskInfo : public TaskInfo { Status Distribute() override; private: - Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); - - void *dst_; + uint8_t *dst_; uint64_t dst_max_; - void *src_; + uint8_t *src_; void *args_; + void *args_align_; uint64_t count_; uint32_t kind_; }; diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc index 82eabe69..1cc18a85 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.cc @@ -21,9 +21,9 @@ namespace ge { Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { - GELOGI("MemcpyAsyncTaskInfo Init Start."); + GELOGI("MemcpyAsyncTaskInfo Init Start"); if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); + GELOGE(PARAM_INVALID, "davinci_model is null"); return PARAM_INVALID; } @@ -32,76 +32,79 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da return ret; } - auto memcpy_async_def = task_def.memcpy_async(); - uint64_t logic_dst = memcpy_async_def.dst(); - uint64_t logic_src = memcpy_async_def.src(); - - dst_max_ = memcpy_async_def.dst_max(); - - uint64_t update_base_addr = 0; - ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); + memcpy_async = task_def.memcpy_async(); + count_ = memcpy_async.count(); + kind_ = memcpy_async.kind(); + dst_max_ = memcpy_async.dst_max(); + if (davinci_model->IsKnownNode()) { + src_ = reinterpret_cast(davinci_model_->GetCurrentArgsAddr(args_offset_)); + dst_ = reinterpret_cast(reinterpret_cast(src_) + sizeof(void *)); + // for zero copy + kind_ = RT_MEMCPY_ADDR_DEVICE_TO_DEVICE; + GELOGI("MemcpyAsyncTaskInfo src_ %p, dst_ %p, args_offset %u.", src_, dst_, args_offset_); + return SUCCESS; + } + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.src(), src_); if (ret != SUCCESS) { return ret; } - src_ = reinterpret_cast(update_base_addr + logic_src); - davinci_model->DisableZeroCopy(src_); - uint64_t mem_base = reinterpret_cast(davinci_model->MemBase()); - uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); - dst_ = reinterpret_cast(mem_base + (logic_dst - logic_mem_base)); + ret = ModelUtils::GetRtAddress(davinci_model->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; + } - count_ = memcpy_async_def.count(); - kind_ = memcpy_async_def.kind(); - GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", - reinterpret_cast(reinterpret_cast(logic_src)), - reinterpret_cast(reinterpret_cast(logic_dst)), src_, dst_); + GELOGI("MemcpyAsyncTaskInfo Init Success, logic[0x%lx, 0x%lx], src:%p, dst:%p, max:%lu, count:%lu", + memcpy_async.src(), memcpy_async.dst(), src_, dst_, dst_max_, count_); + davinci_model->DisableZeroCopy(src_); + davinci_model->DisableZeroCopy(dst_); return SUCCESS; } Status MemcpyAsyncTaskInfo::Distribute() { - GELOGI("MemcpyAsyncTaskInfo Distribute Start. dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); + GELOGI("MemcpyAsyncTaskInfo Distribute Start. dst_max:%lu, count:%lu, kind:%u", dst_max_, count_, kind_); rtError_t rt_ret = rtMemcpyAsync(dst_, dst_max_, src_, count_, static_cast(kind_), stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } - GELOGI("MemcpyAsyncTaskInfo Distribute Success."); + GELOGI("MemcpyAsyncTaskInfo Distribute Success"); return SUCCESS; } -Status MemcpyAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr) { - GE_CHECK_NOTNULL(davinci_model); - uint64_t data_base_addr = - reinterpret_cast(reinterpret_cast(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); - uint64_t weight_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->WeightsMemBase())) - - davinci_model->GetRtWeightAddr(); - uint64_t var_base_addr = reinterpret_cast(reinterpret_cast(davinci_model->VarMemBase())) - - davinci_model->GetRtVarAddr(); - - uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); - uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); - uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); - uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); - uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); - uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); - - if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { - base_addr = data_base_addr; - GELOGI("The update_addr is data address."); - } else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { - base_addr = weight_base_addr; - GELOGI("The update_addr is weight address."); - } else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { - base_addr = var_base_addr; - GELOGI("The update_addr is variable address."); - } else if (update_addr != 0) { - base_addr = 0; - GELOGE(PARAM_INVALID, "The update_addr is abnormal."); - return PARAM_INVALID; +Status MemcpyAsyncTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + // the num of src and dst size is 2 + uint32_t args_size = sizeof(void *) * 2; + args_offset_ = davinci_model->GetTotalArgsSize(); + davinci_model->SetTotalArgsSize(args_size); + davinci_model_ = davinci_model; + GELOGI("MemcpyAsyncTaskInfo kernel args_size %u, args_offset %u", args_size, args_offset_); + return SUCCESS; +} + +Status MemcpyAsyncTaskInfo::UpdateArgs() { + GELOGI("MemcpyAsyncTaskInfo::UpdateArgs in."); + GE_CHECK_NOTNULL(davinci_model_); + Status ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.src(), src_); + if (ret != SUCCESS) { + return ret; + } + + ret = ModelUtils::GetRtAddress(davinci_model_->GetRuntimeParam(), memcpy_async.dst(), dst_); + if (ret != SUCCESS) { + return ret; } + + vector io_addrs; + io_addrs.emplace_back(reinterpret_cast(src_)); + io_addrs.emplace_back(reinterpret_cast(dst_)); + + davinci_model_->SetTotalIOAddrs(io_addrs); + + GELOGI("MemcpyAsyncTaskInfo::UpdateArgs success."); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h index 02872f34..c3daa862 100644 --- a/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/memcpy_async_task_info.h @@ -16,6 +16,7 @@ #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ + #include "graph/load/new_model_manager/task_info/task_info.h" namespace ge { @@ -32,14 +33,19 @@ class MemcpyAsyncTaskInfo : public TaskInfo { Status Distribute() override; - private: - Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); + Status UpdateArgs() override; - void *dst_; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + + private: + uint8_t *dst_; uint64_t dst_max_; - void *src_; + uint8_t *src_; uint64_t count_; uint32_t kind_; + DavinciModel *davinci_model_ = nullptr; + uint32_t args_offset_ = 0; + domi::MemcpyAsyncDef memcpy_async; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ASYNC_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc index 1232ddb2..fd5f4f4c 100644 --- a/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/profiler_trace_task_info.cc @@ -47,7 +47,7 @@ Status ProfilerTraceTaskInfo::Distribute() { rtError_t rt_ret = rtProfilerTrace(log_id_, notify_, flat_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("ProfilerTraceTaskInfo Distribute Success."); diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc index c30cad09..f48f64e3 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_active_task_info.cc @@ -74,7 +74,7 @@ Status StreamActiveTaskInfo::Distribute() { rtError_t rt_ret = rtStreamActive(active_stream_, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("StreamActiveTaskInfo Distribute Success. activeStreamID:%p.", active_stream_); diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc index a1d2f143..45db2be5 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.cc @@ -42,16 +42,11 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d auto stream_switch_def = task_def.stream_switch(); uint32_t op_index = stream_switch_def.op_index(); - // get StreamSwitch op OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); GE_CHECK_NOTNULL(op_desc); auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (!input_data_addr.empty() && input_data_addr.size() >= STREAM_SWITCH_INPUT_NUM) { - input_ptr_ = input_data_addr[0]; - value_ptr_ = input_data_addr[1]; - } - + SetInputAndValuePtr(davinci_model, input_data_addr); uint32_t cond = 0; if (!AttrUtils::GetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, cond)) { GELOGE(INTERNAL_ERROR, "StreamSwitchOp get attr STREAM_SWITCH_COND fail."); @@ -109,12 +104,48 @@ Status StreamSwitchTaskInfo::Distribute() { rtError_t rt_ret = rtStreamSwitchEx(input_ptr_, cond_, value_ptr_, true_stream_, stream_, data_type_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("StreamSwitchTaskInfo Distribute Success. cond:%d, stream:%p, datatype:%d.", cond_, true_stream_, data_type_); return SUCCESS; } +Status StreamSwitchTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto stream_switch_def = task_def.stream_switch(); + uint32_t op_index = stream_switch_def.op_index(); + GELOGI("Begin to calculate args, op_index is: %u", op_index); + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (op_desc->GetInputsSize() != STREAM_SWITCH_INPUT_NUM) { + GELOGE(FAILED, "Stream switch op only have one data input. Now input size is %zu", op_desc->GetInputsSize()); + return FAILED; + } + for (uint32_t i = 0; i < STREAM_SWITCH_INPUT_NUM; ++i) { + string input_tensor_name = op_desc->GetInputNameByIndex(i); + int64_t fixed_addr_offset = davinci_model->GetFixedAddrsSize(input_tensor_name); + fixed_addr_offset_.emplace_back(fixed_addr_offset); + auto tensor_desc = op_desc->GetInputDesc(i); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); + GELOGI("Calculate stream switch task args , tensor size is %ld, fixed addr[%u] offset %ld", tensor_size, i, + fixed_addr_offset); + } + return SUCCESS; +} +void StreamSwitchTaskInfo::SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs) { + if (davinci_model->IsKnownNode() && fixed_addr_offset_.size() == STREAM_SWITCH_INPUT_NUM) { + input_ptr_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_[0]); + value_ptr_ = davinci_model->GetCurrentFixedAddr(fixed_addr_offset_[1]); + } else { + if (!input_data_addrs.empty() && input_data_addrs.size() >= STREAM_SWITCH_INPUT_NUM) { + input_ptr_ = input_data_addrs[0]; + value_ptr_ = input_data_addrs[1]; + } + } +} REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH, StreamSwitchTaskInfo); } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h index 07509c7c..e6e8339a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switch_task_info.h @@ -39,13 +39,18 @@ class StreamSwitchTaskInfo : public TaskInfo { Status Distribute() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + private: + void SetInputAndValuePtr(DavinciModel *davinci_model, const vector &input_data_addrs); void *input_ptr_; rtCondition_t cond_; void *value_ptr_; rtStream_t true_stream_; uint32_t true_stream_id_; rtSwitchDataType_t data_type_; + static const uint32_t kInputNum = 2; + vector fixed_addr_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCH_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc index 29b107bd..d134dfdd 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.cc @@ -22,20 +22,15 @@ #include "graph/load/new_model_manager/model_utils.h" namespace { -const uint32_t kDynamicBtachParamNum = 1; -const uint32_t kDynamicResolutionParamNum = 2; -} // namespace +const uint8_t kStreamSwitchnInputNum = 1; +} namespace ge { Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { GELOGI("StreamSwitchNTaskInfo Init Start."); - if (davinci_model == nullptr) { - GELOGE(PARAM_INVALID, "davinci_model is null!"); - return PARAM_INVALID; - } + GE_CHECK_NOTNULL(davinci_model); - Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); - if (ret != SUCCESS) { + if (SetStream(task_def.stream_id(), davinci_model->GetStreamList()) != SUCCESS) { return FAILED; } @@ -48,10 +43,6 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * // set size_ input_size_ = stream_switchn_def.size(); - if (input_size_ != kDynamicBtachParamNum && input_size_ != kDynamicResolutionParamNum) { - GELOGE(FAILED, "The size of dynamic batch or imagesize input is 1 or 2, now it is %u.", input_size_); - return FAILED; - } // set value_ptr_ auto value = stream_switchn_def.target_value(); @@ -75,14 +66,16 @@ Status StreamSwitchNTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel * GELOGE(FAILED, "Get true stream ptr of switchN op failed."); return FAILED; } - - // set input_ptr_ - auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); - if (input_data_addr.empty()) { - GELOGE(FAILED, "Input data addr is nullptr."); - return FAILED; + if (davinci_model->IsKnownNode()) { + input_ptr_ = davinci_model->GetCurrentFixedAddr(args_offset_); + } else { + auto input_data_addr = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); + if (input_data_addr.empty()) { + GELOGE(FAILED, "Input data addr is nullptr."); + return FAILED; + } + input_ptr_ = input_data_addr[0]; } - input_ptr_ = input_data_addr[0]; davinci_model->DisableZeroCopy(input_ptr_); GELOGI("StreamSwitchNTaskInfo Init Success, inputSize:%u, elementSize:%d, trueStreamID:%ld.", input_size_, element_size_, op_desc->GetStreamId()); @@ -96,7 +89,7 @@ Status StreamSwitchNTaskInfo::Distribute() { rtStreamSwitchN(input_ptr_, input_size_, value_ptr_, true_stream_ptr_, element_size_, stream_, data_type_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } GELOGI("StreamSwitchNTaskInfo Distribute Success. inputSize:%u, elementSize:%d, datatype:%d.", input_size_, @@ -140,5 +133,26 @@ Status StreamSwitchNTaskInfo::GetTrueStreamPtr(const OpDescPtr &op_desc, Davinci return SUCCESS; } +Status StreamSwitchNTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { + GE_CHECK_NOTNULL(davinci_model); + auto stream_switchn_def = task_def.stream_switch_n(); + uint32_t op_index = stream_switchn_def.op_index(); + GELOGI("Begin to calculate args, op_index is: %u", op_index); + auto op_desc = davinci_model->GetOpByIndex(op_index); + GE_CHECK_NOTNULL(op_desc); + GELOGI("Calc opType[%s] args size. Node name is [%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); + if (op_desc->GetInputsSize() != kStreamSwitchnInputNum) { + GELOGE(FAILED, "Stream switchn op only have one data input. Now input size is %zu", op_desc->GetInputsSize()); + return FAILED; + } + string input_tensor_name = op_desc->GetInputNameByIndex(0); + args_offset_ = davinci_model->GetFixedAddrsSize(input_tensor_name); + auto tensor_desc = op_desc->GetInputDesc(0); + int64_t tensor_size = 0; + GE_CHK_STATUS(TensorUtils::GetSize(tensor_desc, tensor_size)); + davinci_model->SetTotalFixedAddrsSize(input_tensor_name, tensor_size); + GELOGI("Calculate stream switchn task args , tensor_size %ld, args_offset %ld", tensor_size, args_offset_); + return SUCCESS; +} REGISTER_TASK_INFO(RT_MODEL_TASK_STREAM_SWITCH_N, StreamSwitchNTaskInfo); } // namespace ge \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h index d1002da7..1a96243a 100644 --- a/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/stream_switchn_task_info.h @@ -29,7 +29,8 @@ class StreamSwitchNTaskInfo : public TaskInfo { value_ptr_(nullptr), true_stream_ptr_(nullptr), element_size_(0), - data_type_(RT_SWITCH_INT64) {} + data_type_(RT_SWITCH_INT64), + args_offset_(0) {} ~StreamSwitchNTaskInfo() override {} @@ -37,6 +38,8 @@ class StreamSwitchNTaskInfo : public TaskInfo { Status Distribute() override; + Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; + private: Status GetTrueStreamPtr(const OpDescPtr &op_desc, DavinciModel *davinci_model); void *input_ptr_; @@ -47,6 +50,7 @@ class StreamSwitchNTaskInfo : public TaskInfo { rtSwitchDataType_t data_type_; vector true_stream_list_; vector value_list_; + int64_t args_offset_; }; } // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_STREAM_SWITCHN_TASK_INFO_H_ diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc index b8fc77ac..100a4fea 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc @@ -26,13 +26,15 @@ Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { reinterpret_cast(reinterpret_cast(this->GetNavTableSize()))}; rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failied. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failied. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, dump_flag); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); - return FAILED;) + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) return SUCCESS; } } // namespace skt diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h index 1c31acd1..b7e76af0 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel.h @@ -34,22 +34,13 @@ class SuperKernel { public: SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} - ~SuperKernel() { - // free memory when all releasing - if (device_args_addr_ != nullptr) { - GE_CHK_RT(rtFree(device_args_addr_)); - GELOGI("SKT: super_kernel args addr free."); - } - if (dev_nav_table_ != nullptr) { - GE_CHK_RT(rtFree(dev_nav_table_)); - GELOGI("SKT: super_kernel args addr free."); - } - } + ~SuperKernel() = default; Status Launch(rtStream_t stream, uint32_t dump_flag); const void *GetFuncStub() const { return func_stub_; } - const void *GetNavTablePtr() const { return dev_nav_table_; } uint64_t GetNavTableSize() const { return nav_table_size_; } uint32_t GetBlockDim() const { return block_dim_; } + void *GetNavTablePtr() const { return dev_nav_table_; } + void *GetDeviceArgsPtr() const { return device_args_addr_; } }; } // namespace skt } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc index d2ad474a..ca42b4e2 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc @@ -33,30 +33,19 @@ Status SuperKernelFactory::Init() { } rtError_t rt_ret; rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtGetFunctionByName " "failed. stub_func: %s, please export LD_LIBRARY_PATH for " "libcce_aicore.so", this->sk_stub_name_.c_str()); - return FAILED;) + return RT_ERROR_TO_GE_STATUS(rt_ret);) rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); - return FAILED;) - if (this->use_physical_address_ != nullptr) { - void *skt_func = nullptr; - rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - return FAILED;) - GELOGD( - "SKT: fuseKernels super_kernel_template subFunc %p, device func " - "address %p, device physic PC %p", - this->func_stub_, this->func_ptr_, skt_func); - } else { - GELOGD( - "SKT: fuseKernels super_kernel_template subFunc %p, device func " - "address %p", - this->func_stub_, this->func_ptr_); - } + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) + GELOGD( + "SKT: fuseKernels super_kernel_template subFunc %p, device func " + "address %p", + this->func_stub_, this->func_ptr_); } is_init_ = true; @@ -71,7 +60,8 @@ Status SuperKernelFactory::Uninitialize() { } Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list, - const std::vector &args_addr_list, uint32_t block_dim, SuperKernel *&h) { + const std::vector &args_addr_list, uint32_t block_dim, + std::unique_ptr &h) { // Iterate through the ops to be fused // Each subkernel to be fused contains 2 fields: fn address offset, args // address. @@ -101,70 +91,29 @@ Status SuperKernelFactory::FuseKernels(const std::vector &stub_func_list rtError_t rt_ret; void *hbm_nav_table_addr = nullptr; - if (this->use_physical_address_ != nullptr) { - for (unsigned i = 0; i < stub_func_list.size(); i++) { - void *sub_device_func = nullptr; - rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); - return FAILED;) - void *sub_device_func_pys = nullptr; - void *args_addr_pys = nullptr; - rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - return FAILED;) - rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - return FAILED;) - GELOGD( - "SKT: fuseKernels subFunc %p, device func address %p, device " - "physic func address %p", - stub_func_list[i], sub_device_func, sub_device_func_pys); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func_pys)) / 4; - GELOGD("SKT: CALL offset %lu", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_pys)); - - GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); - } - - void *hbm_nav_table_addr_pys = nullptr; - rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) - rt_ret = - rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - - GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); - // Create the necessary metadata for the super kernel - h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); - } else { - for (unsigned i = 0; i < stub_func_list.size(); i++) { - void *sub_device_func = nullptr; - rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); - return FAILED;) - GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); - // store two uint64_t address - // address divided by 4 because of 32bits encoding, call offset will *4 when calculating - nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; - GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); - nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); - GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); - } - rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) - rt_ret = - rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); - GE_CHK_RT(rtFree(hbm_nav_table_addr)); return FAILED;) - // Create the necessary metadata for the super kernel - h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); + for (unsigned i = 0; i < stub_func_list.size(); i++) { + void *sub_device_func = nullptr; + rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtGetAddrByFun failed. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) + GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); + // store two uint64_t address + // address divided by 4 because of 32bits encoding, call offset will *4 when calculating + nav_table[i * 2] = reinterpret_cast(reinterpret_cast(sub_device_func)) / 4; + GELOGD("SKT: CALL offet %lu", nav_table[i * 2]); + nav_table[i * 2 + 1] = reinterpret_cast(reinterpret_cast(args_addr_list[i])); + GELOGD("SKT: fuseKernels args base address %lu", nav_table[i * 2 + 1]); } + rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMalloc failed. error: 0x%X", rt_ret); + return RT_ERROR_TO_GE_STATUS(rt_ret);) + rt_ret = + rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); + GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "rtMemcpy failed. error: 0x%X", rt_ret); + GE_CHK_RT(rtFree(hbm_nav_table_addr)); return RT_ERROR_TO_GE_STATUS(rt_ret);) + // Create the necessary metadata for the super kernel + h = + std::unique_ptr(new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim)); return SUCCESS; } } // namespace skt diff --git a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h index d8b7ff26..7db44eec 100644 --- a/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.h @@ -29,7 +29,6 @@ class SuperKernelFactory { void *func_ptr_ = nullptr; void *handle_ = nullptr; std::string sk_stub_name_ = "_Z21super_kernel_templatePmm"; - const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); bool is_init_ = false; SuperKernelFactory(){}; ~SuperKernelFactory() { @@ -48,7 +47,7 @@ class SuperKernelFactory { Status Init(); Status Uninitialize(); Status FuseKernels(const std::vector &stub_func_list, const std::vector &args_addr_list, - uint32_t block_dim, SuperKernel *&h); + uint32_t block_dim, std::unique_ptr &h); }; } // namespace skt } // namespace ge diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info.h b/src/ge/graph/load/new_model_manager/task_info/task_info.h index 5d2c89eb..f69511e6 100644 --- a/src/ge/graph/load/new_model_manager/task_info/task_info.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info.h @@ -72,6 +72,8 @@ class TaskInfo { virtual uint32_t GetTaskID() { return 0xFFFFFFFF; } + virtual bool CallSaveDumpInfo() { return false; } + virtual uint32_t GetStreamId() { return 0xFFFFFFFF; } virtual uintptr_t GetDumpArgs() { return 0; } diff --git a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h index b6954016..5b220960 100644 --- a/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h +++ b/src/ge/graph/load/new_model_manager/task_info/task_info_factory.h @@ -86,5 +86,5 @@ class TaskInfoFactory { return ptr; \ } \ TaskInfoFactory::Registerar g_##type##_Task_Info_Creator(type, Creator_##type##_Task_Info); -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_TASK_INFO_FACTORY_H_ diff --git a/src/ge/graph/load/new_model_manager/zero_copy_offset.cc b/src/ge/graph/load/new_model_manager/zero_copy_offset.cc new file mode 100644 index 00000000..18b958ef --- /dev/null +++ b/src/ge/graph/load/new_model_manager/zero_copy_offset.cc @@ -0,0 +1,220 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/load/new_model_manager/zero_copy_offset.h" + +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/load/new_model_manager/model_utils.h" +#include "graph/load/new_model_manager/zero_copy_task.h" + +namespace ge { +namespace { +const uint32_t kDataIndex = 0; +} // namespace + +ZeroCopyOffset::ZeroCopyOffset() {} + +ZeroCopyOffset::~ZeroCopyOffset() {} + +Status ZeroCopyOffset::InitInputDataInfo(const vector &output_size_list, + const vector &virtual_addr_list, const OpDescPtr &op_desc, + bool &fusion_flag) { + GELOGI("[ZCPY] Start to InitInputDataInfo of %s, total_data_size is %ld, virtual_addr is %p", + op_desc->GetName().c_str(), output_size_list[kDataIndex], virtual_addr_list[kDataIndex]); + if (output_size_list.empty() || virtual_addr_list.empty() || (output_size_list.size() != virtual_addr_list.size())) { + GELOGE(PARAM_INVALID, "Data[%s] init failed: Output size is %zu, Output addr is %zu", op_desc->GetName().c_str(), + output_size_list.size(), virtual_addr_list.size()); + return PARAM_INVALID; + } + + basic_addr_ = virtual_addr_list[kDataIndex]; + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); + GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, + "basic_offset_size should be equal to relative_offset_size"); + GELOGI("[ZCPY] zero_copy_basic_offset size is %zu", zero_copy_basic_offset_.size()); + + int64_t virtual_addr_offset = op_desc->GetOutputOffset().at(kDataIndex); + GELOGI("virtual_addr_offset is %ld.", virtual_addr_offset); + IsL2Fusion(zero_copy_basic_offset_, virtual_addr_offset, fusion_flag); + + uint32_t out_count = 0; + data_size_ = output_size_list[kDataIndex]; + if (!fusion_flag) { + GELOGI("[ZCPY] %s not set l2_fusion.", op_desc->GetName().c_str()); + out_count++; + data_info_.emplace_back(output_size_list[kDataIndex], virtual_addr_list[kDataIndex]); + relative_offset_.emplace_back(0); + GELOGI("[ZCPY] %s size is %ld, virtual_addr is %p.", op_desc->GetName().c_str(), output_size_list[kDataIndex], + virtual_addr_list[kDataIndex]); + } else { + GELOGI("[ZCPY] set l2_fusion for %s.", op_desc->GetName().c_str()); + for (size_t index = 0; index < zero_copy_basic_offset_.size(); ++index) { + if (zero_copy_basic_offset_.at(index) == virtual_addr_offset) { + out_count++; + uint64_t out_offset = + reinterpret_cast(virtual_addr_list[kDataIndex]) + zero_copy_relative_offset_.at(index); + int64_t real_data_size = ModelUtils::GetOutputSize(op_desc).at(kDataIndex); + data_info_.emplace_back(real_data_size, reinterpret_cast(reinterpret_cast(out_offset))); + relative_offset_.emplace_back(zero_copy_relative_offset_.at(index)); + GELOGI("[ZCPY] virtual_addr: %p has been l2-fusion to %lu, need copy data_size is %ld.", basic_addr_, + out_offset, real_data_size); + } + } + } + data_count_ = out_count; + return SUCCESS; +} + +Status ZeroCopyOffset::InitOutputDataInfo(const vector &input_size_list, + const vector &virtual_addr_list, const OpDescPtr &op_desc, + const size_t &idx, bool &fusion_flag) { + GELOGI("[ZCPY] Start to InitOutputDataInfo of %s.", op_desc->GetName().c_str()); + int64_t size = input_size_list[idx]; + auto tensor_desc = op_desc->GetInputDescPtr(idx); + GE_CHECK_NOTNULL(tensor_desc); + if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, size) != GRAPH_SUCCESS) { + GELOGE(FAILED, "GetTensorSizeInBytes failed!"); + return FAILED; + } + + GELOGI("Tensor data size: GetSize=%ld, GetTensorSizeInBytes=%ld", input_size_list[idx], size); + + basic_addr_ = virtual_addr_list[idx]; + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); + GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, + "basic_offset_size should be equal to relative_offset_size"); + int64_t virtual_addr_offset = op_desc->GetInputOffset().at(idx); + GELOGI("virtual_addr_offset is %ld.", virtual_addr_offset); + IsL2Fusion(zero_copy_basic_offset_, virtual_addr_offset, fusion_flag); + + uint32_t in_count = 0; + data_size_ = size; + if (!fusion_flag) { + GELOGI("[ZCPY] %s not set l2-fusion.", op_desc->GetName().c_str()); + in_count++; + data_info_.emplace_back(size, virtual_addr_list[idx]); + // op_desc not set l2fusion when fusion_flag is false + relative_offset_.emplace_back(0); + GELOGI("[ZCPY] %s size is %ld, virtual_addr is %p.", op_desc->GetName().c_str(), size, virtual_addr_list[idx]); + } else { + GELOGI("[ZCPY] set l2-fusion for %s.", op_desc->GetName().c_str()); + for (size_t index = 0; index < zero_copy_basic_offset_.size(); ++index) { + if (zero_copy_basic_offset_.at(index) == virtual_addr_offset) { + in_count++; + uint64_t in_offset = reinterpret_cast(virtual_addr_list[idx]) + zero_copy_relative_offset_.at(index); + int64_t real_data_size = ModelUtils::GetInputSize(op_desc).at(idx); + data_info_.emplace_back(real_data_size, reinterpret_cast(reinterpret_cast(in_offset))); + relative_offset_.emplace_back(zero_copy_relative_offset_.at(index)); + GELOGI("[ZCPY] virtual_addr: %p has been l2-fusion from %lu, need copy data_size is %ld.", basic_addr_, + in_offset, real_data_size); + } + } + } + data_count_ = in_count; + return SUCCESS; +} + +void ZeroCopyOffset::IsL2Fusion(const vector &fusion_basic_addrs, const int64_t &tensor_offset, + bool &fusion_flag) { + for (size_t fusion_count = 0; fusion_count < fusion_basic_addrs.size(); ++fusion_count) { + if (fusion_basic_addrs.at(fusion_count) == tensor_offset) { + fusion_flag = true; + break; + } + } +} + +void ZeroCopyOffset::SetInputOutsideAddrs(const vector &output_offset_list, void *addr, const size_t &index, + bool fusion_flag, std::vector &real_virtual_addrs) { + GELOGI("[ZCPY] Start to SetInputOutsideAddrs for virtual_addr %p.", addr); + uint32_t out_count = 0; + if (!fusion_flag) { + GELOGI("[ZCPY] not set l2-fusion for virtual_adr %p.", addr); + out_count++; + std::map> addr_mapping; + addr_mapping[addr] = {}; + outside_addrs_.emplace_back(addr_mapping); + real_virtual_addrs.emplace_back(addr); + } else { + GELOGI("[ZCPY] set l2-fusion for virtual_addr %p.", addr); + int64_t output_offset = output_offset_list.at(index); + for (size_t i = 0; i < zero_copy_basic_offset_.size(); ++i) { + if (zero_copy_basic_offset_.at(i) == output_offset) { + out_count++; + void *virtual_addr = + reinterpret_cast(reinterpret_cast(addr) + zero_copy_relative_offset_.at(i)); + std::map> addr_mapping; + addr_mapping[virtual_addr] = {}; + outside_addrs_.emplace_back(addr_mapping); + real_virtual_addrs.emplace_back(virtual_addr); + GELOGI("[ZCPY] virtual_addr %p has been fusion to virtual_addr %p.", addr, virtual_addr); + } + } + } + addr_count_ = out_count; +} + +void ZeroCopyOffset::SetOutputOutsideAddrs(const int64_t &input_offset, const bool &fusion_flag, void *addr, + std::vector &tensor_addrs) { + GELOGI("[ZCPY] Start to SetOutputOutsideAddrs for virtual_addr %p.", addr); + uint32_t out_count = 0; + if (!fusion_flag) { + GELOGI("[ZCPY] not set l2-fusion for virtual_addr %p.", addr); + out_count++; + std::map> addr_mapping; + addr_mapping[addr] = {}; + outside_addrs_.emplace_back(addr_mapping); + tensor_addrs.emplace_back(addr); + } else { + GELOGI("[ZCPY] set l2-fusion for virtual_addr %p.", addr); + for (size_t i = 0; i < zero_copy_basic_offset_.size(); ++i) { + if (zero_copy_basic_offset_.at(i) == input_offset) { + out_count++; + void *virtual_addr = + reinterpret_cast(reinterpret_cast(addr) + zero_copy_relative_offset_.at(i)); + std::map> addr_mapping; + addr_mapping[virtual_addr] = {}; + outside_addrs_.emplace_back(addr_mapping); + tensor_addrs.emplace_back(virtual_addr); + GELOGI("[ZCPY] virtual_addr %p has been fusion to virtual_addr %p.", addr, virtual_addr); + } + } + } + addr_count_ = out_count; +} + +bool ZeroCopyOffset::SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset) { + const auto addr_val = reinterpret_cast(outside_addr); + bool set_batch_label_flag = false; + for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) { + auto &addrs_mapping_list = GetOutsideAddrs(); + auto args_addrs = addrs_mapping_list[out_count].find(outside_addr); + if (args_addrs != addrs_mapping_list[out_count].end()) { + GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), "Input args invalid."); + void *args_val = static_cast(args) + offset; + args_addrs->second.push_back(args_val); + GELOGI("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val, + args, offset); + set_batch_label_flag = true; + } + } + return set_batch_label_flag; +} + +} // namespace ge diff --git a/src/ge/graph/load/new_model_manager/zero_copy_offset.h b/src/ge/graph/load/new_model_manager/zero_copy_offset.h new file mode 100644 index 00000000..eb2cdb4d --- /dev/null +++ b/src/ge/graph/load/new_model_manager/zero_copy_offset.h @@ -0,0 +1,84 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ +#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ + +#include +#include +#include +#include + +#include "external/ge/ge_api_error_codes.h" +#include "framework/common/ge_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/load/new_model_manager/zero_copy_task.h" +#include "graph/utils/attr_utils.h" +#include "graph/utils/tensor_utils.h" +#include "runtime/mem.h" +#include "task_info/task_info.h" + +using std::map; +using std::set; +using std::string; +using std::vector; + +namespace ge { +class ZeroCopyOffset { + public: + ZeroCopyOffset(); + ~ZeroCopyOffset(); + + Status InitInputDataInfo(const vector &output_size_list, const vector &virtual_addr_list, + const OpDescPtr &op_desc, bool &fusion_flag); + void SetInputOutsideAddrs(const vector &output_offset_list, void *addr, const size_t &index, + bool fusion_flag, std::vector &real_virtual_addrs); + + void IsL2Fusion(const vector &fusion_basic_addrs, const int64_t &tensor_addr, bool &fusion_flag); + Status InitOutputDataInfo(const vector &input_size_list, const vector &virtual_addr_list, + const OpDescPtr &op_desc, const size_t &idx, bool &fusion_flag); + void SetOutputOutsideAddrs(const int64_t &input_offset, const bool &fusion_flag, void *addr, + std::vector &tensor_addrs); + bool SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *outside_addr, void *args, size_t offset); + + // basic_addr of l2-fusion + void *GetBasicAddr() const { return basic_addr_; } + // total num of out_of_data/in_of_phonyconcat + uint32_t GetDataCount() const { return data_count_; } + uint32_t GetAddrCount() const { return addr_count_; } + // value of *data_info_ from davinci_model + std::vector> GetDataInfo() const { return data_info_; } + // relative_offset from zero_copy_relative_offset_ + std::vector GetRelativeOffset() const { return relative_offset_; } + // data_size of Data/Netoutput + int64_t GetDataSize() const { return data_size_; } + // value of *outside_addrs_ from davinci_model + std::vector>> &GetOutsideAddrs() { return outside_addrs_; } + + private: + void *basic_addr_ = nullptr; + uint32_t data_count_ = 0; + std::vector> data_info_; + vector relative_offset_; + int64_t data_size_ = 0; + uint32_t addr_count_ = 0; + std::vector>> outside_addrs_; + + std::vector zero_copy_basic_offset_; + std::vector zero_copy_relative_offset_; +}; +} // namespace ge +#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_ZERO_COPY_OFFSET_H_ \ No newline at end of file diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.cc b/src/ge/graph/load/new_model_manager/zero_copy_task.cc index 42734a87..5b595d76 100644 --- a/src/ge/graph/load/new_model_manager/zero_copy_task.cc +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.cc @@ -16,9 +16,9 @@ #include "graph/load/new_model_manager/zero_copy_task.h" -#include "graph/load/new_model_manager/model_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" +#include "graph/load/new_model_manager/model_utils.h" namespace ge { const char *const kDefaultBatchLable = "Batch_default"; @@ -48,8 +48,8 @@ Status ZeroCopyTask::SetTaskArgsOffset(uintptr_t addr, size_t offset) { it->second.push_back(offset); } - GELOGI("[ZCPY] %s set task, addr: 0x%lx, args: %p, size: %zu, offset: %zu", name_.c_str(), addr, args_addr_, - args_size_, offset); + GELOGI("[ZCPY] %s set task, virtual_addr: 0x%lx, args_addr: %p, size: %zu, offset: %zu", name_.c_str(), addr, + args_addr_, args_size_, offset); return SUCCESS; } @@ -65,7 +65,8 @@ void ZeroCopyTask::SetOriginalArgs(const void *info, size_t size) { const uint8_t *data = static_cast(info); args_info_.assign(data, data + size); - GELOGI("[ZCPY] %s set info, args: %p, args size: %zu, info size: %zu", name_.c_str(), args_addr_, args_size_, size); + GELOGI("[ZCPY] %s set info from virtual_addr: %p, args_addr: %p, args size: %zu, info size: %zu", name_.c_str(), info, + args_addr_, args_size_, size); } /** @@ -110,13 +111,13 @@ bool ZeroCopyTask::CheckDynamicBatch(const map> &batch_ad * @ingroup ge * @brief Set user data addr to Task param. * @param [in] addr: virtual address value from Op. - * @param [in] data: data buffer from user. + * @param [in] buffer_addr: real_data_buffer_addr from user. * @param [in] batch_addrs: dynamic batch addr info. * @param [in] batch_label: batch label. * @return: void */ -Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, const DataBuffer &data, - const map> &batch_addrs, const string &batch_label) { +Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const map> &batch_addrs, + const string &batch_label) { for (auto pair : task_addr_offset_) { if (pair.first != addr) { continue; @@ -128,15 +129,9 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, const DataBuffer &data, continue; } - auto dst_addr = static_cast(data.data); - auto dst_size = static_cast(data.length); - if (ModelUtils::ConvertVirtualAddressToPhysical(dst_addr, dst_size, dst_addr) != SUCCESS) { - GELOGE(FAILED, "[ZCPY] Convert virtual address to physical for dst_addr failed."); - return FAILED; - } - - GELOGI("[ZCPY] %s update task, args: %p, size: %zu, offset: %zu, addr: 0x%lx, length: %u", name_.c_str(), - args_addr_, args_size_, offset, addr, data.length); + auto dst_addr = static_cast(buffer_addr); + GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx", name_.c_str(), + args_addr_, args_size_, offset, addr); *(uintptr_t *)(args_info + offset) = reinterpret_cast(dst_addr); is_updated_ = true; } @@ -168,11 +163,11 @@ Status ZeroCopyTask::DistributeParam(rtStream_t stream) { } if (rt_err != RT_ERROR_NONE) { - GELOGE(FAILED, "[ZCPY] %s distribute task param failed, error=0x%x", name_.c_str(), rt_err); - return FAILED; + GELOGE(RT_FAILED, "[ZCPY] %s distribute task param failed, error=0x%x", name_.c_str(), rt_err); + return RT_ERROR_TO_GE_STATUS(rt_err); } - GELOGI("[ZCPY] %s refresh task args success, args: %p, size: %zu, args_info_: %p, length: %zu", name_.c_str(), + GELOGI("[ZCPY] %s refresh task args success, args_addr: %p, size: %zu, args_info_: %p, length: %zu", name_.c_str(), args_addr_, args_size_, args_info_.data(), args_info_.size()); return SUCCESS; } diff --git a/src/ge/graph/load/new_model_manager/zero_copy_task.h b/src/ge/graph/load/new_model_manager/zero_copy_task.h index 9d3f5b03..d2a91ce7 100644 --- a/src/ge/graph/load/new_model_manager/zero_copy_task.h +++ b/src/ge/graph/load/new_model_manager/zero_copy_task.h @@ -66,12 +66,12 @@ class ZeroCopyTask { * @ingroup ge * @brief Set user data addr to Task param. * @param [in] addr: virtual address value from Op. - * @param [in] data: data buffer from user. + * @param [in] buffer_addr: data buffer_addr from user. * @param [in] batch_addrs: dynamic batch addr info. * @param [in] batch_label: batch label. * @return: 0 SUCCESS / others FAILED */ - ge::Status UpdateTaskParam(uintptr_t addr, const DataBuffer &data, const map> &batch_addrs, + ge::Status UpdateTaskParam(uintptr_t addr, void *buffer_addr, const map> &batch_addrs, const string &batch_label); /** diff --git a/src/ge/graph/load/output/output.cc b/src/ge/graph/load/output/output.cc deleted file mode 100644 index d922ce7c..00000000 --- a/src/ge/graph/load/output/output.cc +++ /dev/null @@ -1,175 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/load/output/output.h" - -#include - -#include "common/properties_manager.h" -#include "graph/load/new_model_manager/davinci_model.h" -#include "graph/manager/graph_var_manager.h" -#include "graph/utils/op_desc_utils.h" -#include "graph/utils/tensor_utils.h" - -namespace ge { -Output::Output(const OpDescPtr &op_desc, DavinciModel *model) - : base_(nullptr), - var_base_(nullptr), - logic_base_(0), - logic_var_base_(0), - model_(model), - op_desc_(op_desc), - input_num_(0) {} - -Output::~Output() { - var_base_ = nullptr; - base_ = nullptr; - model_ = nullptr; -} - -/// -/// @ingroup domi -/// @brief Initialize input/output params -/// @return Status -/// -Status Output::Init() { - if (op_desc_ == nullptr || model_ == nullptr) { - GELOGE(INTERNAL_ERROR, "The op_desc_ or model_ is nullptr."); - return INTERNAL_ERROR; - } - - base_ = model_->MemBase(); - var_base_ = model_->VarMemBase(); - logic_base_ = model_->GetRtBaseAddr(); - logic_var_base_ = model_->GetRtVarAddr(); - - input_num_ = op_desc_->GetInputsSize(); - v_input_size_.clear(); - v_input_data_addr_.clear(); - - auto input_vector = op_desc_->GetInputOffset(); - if (input_num_ != input_vector.size()) { - GELOGE(INTERNAL_ERROR, "input desc size: %zu != input offset size: %zu.", input_num_, input_vector.size()); - return INTERNAL_ERROR; - } - - for (size_t i = 0; i < input_num_; i++) { - int64_t tensor_size = 0; - auto input_desc = op_desc_->GetInputDescPtr(i); - GE_CHECK_NOTNULL(input_desc); - Status ret = TensorUtils::GetSize(*input_desc, tensor_size); - if (ret != GRAPH_SUCCESS) { - GELOGE(ret, "Get size from TensorDesc failed, op : %s, input index : %zu", op_desc_->GetName().c_str(), i); - return ret; - } - v_input_size_.push_back(tensor_size); - - if (VarManager::Instance(model_->SessionId())->IsVarAddr(input_vector[i])) { - v_input_data_addr_.push_back(static_cast(var_base_ + input_vector[i] - logic_var_base_)); - } else { - v_input_data_addr_.push_back(static_cast(base_ + input_vector[i])); - } - } - - GELOGI("Init output:%lu, %lu, %lu", input_num_, v_input_size_.size(), v_input_data_addr_.size()); - - return SUCCESS; -} - -/// -/// @ingroup domi -/// @brief Copy Op Output to user space. -/// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. -/// @return Status -/// -Status Output::CopyResult(OutputData &rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share) { - uint32_t data_count = 0; - if (input_num_ > rslt.blobs.size() - data_begin) { - GELOGE(FAILED, "Tensor num %zu, data_buf num: %zu.", input_num_, rslt.blobs.size() - data_begin); - return FAILED; - } else if (input_num_ < rslt.blobs.size() - data_begin) { - GELOGW("Tensor num %zu, data_buf num: %zu.", input_num_, rslt.blobs.size() - data_begin); - } - - for (size_t i = 0; i < input_num_; i++) { - DataBuffer data_buf = rslt.blobs[data_begin + data_count]; - Status ret = SetDataBuf(data_buf, data_count, i, support_mem_share); - if (ret != SUCCESS) { - GELOGE(ret, "Copy data to host error. index: %zu", i); - return ret; - } - data_index = data_begin + data_count; - } - - return SUCCESS; -} - -Status Output::SetDataBuf(DataBuffer &data_buf, uint32_t &data_count, size_t i, bool support_mem_share) { - if (data_buf.length == 0) { - ++data_count; - GELOGD("Length of data_buffer is zero, No need to copy. output op : %s, output tensor index : %zu!", - op_desc_->GetName().c_str(), i); - return SUCCESS; - } - - auto tensor_desc = op_desc_->GetInputDescPtr(static_cast(i)); - if (tensor_desc == nullptr) { - GELOGE(FAILED, "tensor_desc is null"); - return FAILED; - } - - if (data_buf.isDataSupportMemShare && support_mem_share) { - GELOGI("No need to copy input data, user's output data buffer can be shared."); - } else { - // Copy result to Databuf - int64_t size = v_input_size_[i]; - GELOGI("Tensor data size before: %ld", size); - - graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(*tensor_desc, size); - if (graph_status != ge::GRAPH_SUCCESS) { - GELOGE(graph_status, "GetTensorSizeInBytes failed!"); - return FAILED; - } - - if (data_buf.length < size) { - GELOGE(FAILED, "Tensor data size: %ld data_buf length: %ld", size, data_buf.length); - return FAILED; - } else if (data_buf.length > size) { - GELOGW("Tensor data size: %ld data_buf length: %ld", size, data_buf.length); - } - - rtError_t rt_ret = rtMemcpy(data_buf.data, size, v_input_data_addr_[i], size, RT_MEMCPY_DEVICE_TO_HOST); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(rt_ret, "rtmemcpy error"); - return FAILED; - } - GELOGI("Tensor data size: %ld data_buf length: %ld", size, data_buf.length); - } - - ++data_count; - GELOGD("Successfully copy the output tensor memory to buffer, output op : %s, output tensor index : %zu!", - op_desc_->GetName().c_str(), i); - - return SUCCESS; -} - -void Output::GetOutputData(vector &v_data_addr, vector &v_data_size) { - for (size_t i = 0; i < input_num_; ++i) { - v_data_addr.push_back(v_input_data_addr_[i]); - v_data_size.push_back(v_input_size_[i]); - } -} -} // namespace ge diff --git a/src/ge/graph/load/output/output.h b/src/ge/graph/load/output/output.h deleted file mode 100644 index d93b8de9..00000000 --- a/src/ge/graph/load/output/output.h +++ /dev/null @@ -1,94 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#ifndef GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ -#define GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ - -#include -#include - -#include "common/debug/log.h" -#include "common/op/attr_value_util.h" -#include "common/op/ge_op_utils.h" -#include "common/types.h" -#include "common/util.h" -#include "common/ge_types.h" -#include "graph/load/new_model_manager/davinci_model.h" -#include "graph/op_desc.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -using std::string; -using std::vector; - -// The base class for all op -class Output { - public: - Output(const OpDescPtr &op_desc, DavinciModel *model); - virtual ~Output(); - - /// - /// @ingroup domi - /// @brief Initialize input/output params - /// @return Status - /// - virtual Status Init(); - - /// - /// @ingroup domi - /// @brief Copy Op Output to user space. - /// @brief when model running, Add one DataOp as input node, Add one Output Op as output node. - /// @return Status - /// - virtual Status CopyResult(OutputData &rslt, uint32_t data_begin, uint32_t &data_index, bool support_mem_share); - - /// - /// @ingroup domi - /// @brief Trans Output data to fp16 - /// @return Status - /// - Status SetDataBuf(DataBuffer &data_buf, uint32_t &data_count, size_t i, bool support_mem_share); - - /// - /// @ingroup domi - /// @brief Get Output data and size. - /// @return void - /// - void GetOutputData(vector &v_data_addr, vector &v_data_size); - - // Copy assignment operator and copy constructor are deleted - Output &operator=(const Output &output) = delete; - Output(const Output &output) = delete; - - protected: - // Model's base address - uint8_t *base_; - uint8_t *var_base_; - uint64_t logic_base_; - uint64_t logic_var_base_; - // The DavinciModel which ops belong to - DavinciModel *model_; - - ConstOpDescPtr op_desc_; - - // Input descriptions - size_t input_num_; - vector v_input_data_addr_; // init as:buf_base + op_def_->input(i)); - vector v_input_size_; -}; -} // namespace ge - -#endif // GE_GRAPH_LOAD_OUTPUT_OUTPUT_H_ diff --git a/src/ge/graph/manager/block_memory.h b/src/ge/graph/manager/block_memory.h new file mode 100644 index 00000000..e2bf74b2 --- /dev/null +++ b/src/ge/graph/manager/block_memory.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_MANAGER_BLOCK_MEMORY_H_ +#define GE_GRAPH_MANAGER_BLOCK_MEMORY_H_ +namespace ge { +struct Block; +typedef bool (*Comparison)(const Block *, const Block *); +using BlockBin = std::set; + +struct Block { + uint32_t device_id; // npu device id + size_t size; // block size in bytes + BlockBin *bin; // owning block bin + uint8_t *ptr; // memory address + bool allocated; // in-use flag + Block *prev; // prev block if split from a larger allocation + Block *next; // next block if split from a larger allocation + + Block(uint32_t device, size_t size, BlockBin *bin, uint8_t *ptr) + : device_id(device), size(size), bin(bin), ptr(ptr), allocated(false), prev(nullptr), next(nullptr) {} + + // constructor for search key + Block(uint32_t device, size_t size, uint8_t *ptr) + : device_id(device), size(size), bin(nullptr), ptr(ptr), allocated(false), prev(nullptr), next(nullptr) {} + + bool IsSplit() const { return (prev != nullptr) || (next != nullptr); } +}; +} // namespace ge +#endif // GE_GRAPH_MANAGER_BLOCK_MEMORY_H_ diff --git a/src/ge/graph/manager/graph_caching_allocator.cc b/src/ge/graph/manager/graph_caching_allocator.cc index 5df6769b..4ba39ca8 100644 --- a/src/ge/graph/manager/graph_caching_allocator.cc +++ b/src/ge/graph/manager/graph_caching_allocator.cc @@ -34,9 +34,6 @@ const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, 26 * kGByteSize}; static bool BlockComparator(const Block *left, const Block *right) { - if (left->device_id != right->device_id) { - return left->device_id < right->device_id; - } if (left->size != right->size) { return left->size < right->size; } @@ -137,11 +134,6 @@ uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device } if (ptr == nullptr) { GELOGE(FAILED, "Malloc failed device id = %u, size= %zu", device_id, size); - } else { - std::lock_guard lock(mutex_); - block->allocated = true; - allocated_blocks_[block->ptr] = block; - GELOGI("Malloc device id = %u, size= %zu", device_id, size); } return ptr; } @@ -225,9 +217,16 @@ Block *CachingAllocator::FindFreeBlock(size_t size, uint8_t *org_ptr, uint32_t d if (block != nullptr) { GELOGI("Find block size = %zu", block->size); if (ShouldSplit(block, size)) { - return SplitBlock(block, size, *bin, device_id); + block = SplitBlock(block, size, *bin, device_id); + } + + if (block->ptr != nullptr) { + block->allocated = true; + allocated_blocks_[block->ptr] = block; + GELOGI("Malloc device id = %u, size= %zu", device_id, size); } } + return block; } return nullptr; @@ -267,20 +266,20 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { return ge::FAILED; } } - if (AddToBlockBin(memory_addr, memory_size) != ge::SUCCESS) { + if (AddToBlockBin(memory_addr, memory_size, device_id) != ge::SUCCESS) { (void)memory_allocator_->FreeMemory(memory_addr); return ge::FAILED; } return ge::SUCCESS; } -Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size) { +Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id) { BlockBin *bin = GetBlockBin(size); if (bin == nullptr) { GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); return ge::FAILED; } - Block *block = new (std::nothrow) Block(0, size, bin, nullptr); + Block *block = new (std::nothrow) Block(device_id, size, bin, nullptr); if (block == nullptr) { GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); return ge::FAILED; @@ -339,5 +338,4 @@ void CachingAllocator::FreeBlockBins() { } } } - } // namespace ge diff --git a/src/ge/graph/manager/graph_caching_allocator.h b/src/ge/graph/manager/graph_caching_allocator.h index 75864ce7..94a5066a 100644 --- a/src/ge/graph/manager/graph_caching_allocator.h +++ b/src/ge/graph/manager/graph_caching_allocator.h @@ -32,7 +32,6 @@ #include "runtime/mem.h" namespace ge { - constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold constexpr size_t kKByteSize = 1024; @@ -69,6 +68,10 @@ class CachingAllocator { public: explicit CachingAllocator(rtMemType_t memory_type); + CachingAllocator(const CachingAllocator &) = delete; + + CachingAllocator &operator=(const CachingAllocator &) = delete; + virtual ~CachingAllocator() = default; /// @@ -137,9 +140,10 @@ class CachingAllocator { /// @brief add memory to right bin based on size /// @param [in] memory ptr /// @param [in] memory size + /// @param [in] device_id device id /// @return Status result of function /// - Status AddToBlockBin(uint8_t *ptr, size_t size); + Status AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id); /// /// @ingroup ge_graph @@ -206,7 +210,5 @@ class CachingAllocator { // block bins by different block size BlockBin *free_block_bins_[kNumBins]; }; - -}; // namespace ge - +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/graph_manager.cc b/src/ge/graph/manager/graph_manager.cc index dd4855b6..bdf2143c 100644 --- a/src/ge/graph/manager/graph_manager.cc +++ b/src/ge/graph/manager/graph_manager.cc @@ -43,7 +43,9 @@ #include "graph/manager/util/rt_context_util.h" #include "graph/partition/dynamic_shape_partition.h" #include "graph/passes/addn_pass.h" +#include "graph/passes/bitcast_pass.h" #include "graph/passes/atomic_addr_clean_pass.h" +#include "graph/passes/attach_stream_label_pass.h" #include "graph/passes/cast_remove_pass.h" #include "graph/passes/common_subexpression_elimination_pass.h" #include "graph/passes/compile_nodes_pass.h" @@ -57,15 +59,18 @@ #include "graph/passes/flow_ctrl_pass.h" #include "graph/passes/hccl_group_pass.h" #include "graph/passes/hccl_memcpy_pass.h" -#include "graph/passes/identify_reference_pass.h" #include "graph/passes/identity_pass.h" +#include "graph/passes/input_output_connection_identify_pass.h" #include "graph/passes/iterator_op_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" +#include "graph/passes/mark_graph_unknown_status_pass.h" #include "graph/passes/merge_pass.h" +#include "graph/passes/merge_to_stream_merge_pass.h" #include "graph/passes/multi_batch_pass.h" #include "graph/passes/next_iteration_pass.h" #include "graph/passes/permute_pass.h" #include "graph/passes/prune_pass.h" +#include "graph/passes/ref_identity_delete_op_pass.h" #include "graph/passes/replace_with_empty_const_pass.h" #include "graph/passes/reshape_recovery_pass.h" #include "graph/passes/reshape_remove_pass.h" @@ -74,7 +79,7 @@ #include "graph/passes/switch_data_edges_bypass.h" #include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_logic_remove_pass.h" -#include "graph/passes/switch_op_pass.h" +#include "graph/passes/switch_to_stream_switch_pass.h" #include "graph/passes/transop_breadth_fusion_pass.h" #include "graph/passes/transop_depth_fusion_pass.h" #include "graph/passes/transop_nearby_allreduce_fusion_pass.h" @@ -85,6 +90,7 @@ #include "graph/passes/variable_prepare_op_pass.h" #include "graph/passes/variable_ref_delete_op_pass.h" #include "graph/passes/variable_ref_useless_control_out_delete_pass.h" +#include "graph/passes/end_of_sequence_add_control_pass.h" #include "graph/utils/tensor_adapter.h" #include "inc/pass_manager.h" #include "init/gelib.h" @@ -347,12 +353,13 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr return SUCCESS; } -#define GM_RUN_AND_DUMP(name, func, ...) \ +#define GM_RUN_AND_DUMP_PERF(name, func, ...) \ do { \ - GE_RUN(GraphManager, func, __VA_ARGS__); \ + GE_RUN_PERF(GraphManager, func, __VA_ARGS__); \ GE_DUMP(compute_graph, "PreRunAfter" name); \ GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ } while (0) + Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id) { GE_CHECK_NOTNULL(graph_node); @@ -365,30 +372,30 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorGetName().c_str()); GE_DUMP(compute_graph, "PreRunBegin"); - GM_RUN_AND_DUMP("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); - GM_RUN_AND_DUMP("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); - GM_RUN_AND_DUMP("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, - session_id); - GM_RUN_AND_DUMP("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); + GM_RUN_AND_DUMP_PERF("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); + GM_RUN_AND_DUMP_PERF("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); + GM_RUN_AND_DUMP_PERF("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, + session_id); + GM_RUN_AND_DUMP_PERF("OptimizeOriginalGraph", graph_optimize_.OptimizeOriginalGraph, compute_graph); - GM_RUN_AND_DUMP("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); - GM_RUN_AND_DUMP("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); + GM_RUN_AND_DUMP_PERF("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); + GM_RUN_AND_DUMP_PERF("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); GE_RUN(GraphManager, graph_preparer_.RecordAIPPInfo, compute_graph); if (IsTailingOptimization()) { - GM_RUN_AND_DUMP("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); + GM_RUN_AND_DUMP_PERF("OptimizeSwitchOp", graph_preparer_.SwitchOpOptimize, compute_graph); } - GM_RUN_AND_DUMP("Optimize1", OptimizeStage1, compute_graph); - GM_RUN_AND_DUMP("InferShape2", compute_graph->InferShapeInNeed); + GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph); + GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed); const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); if (unknown_shape_skip != nullptr) { PassManager graph_pass; GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::CtrlEdgeTransferPass", new (std::nothrow) CtrlEdgeTransferPass)) GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); } - - GM_RUN_AND_DUMP("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); - GM_RUN_AND_DUMP("Optimize2", OptimizeStage2, compute_graph); - GM_RUN_AND_DUMP("Build", Build, graph_node, compute_graph, ge_root_model, session_id); + GE_CHK_STATUS_RET(graph_optimize_.IdentifyReference(compute_graph), "Identify reference failed."); + GM_RUN_AND_DUMP_PERF("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); + GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); + GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); // when set incre build, save om model and var manager GeModelPtr ge_model = nullptr; @@ -397,7 +404,7 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vectorSetRunFlag(false); @@ -634,7 +641,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector &inputs, + GeRootModelPtr &ge_root_model, uint64_t session_id) { + // find graph + GraphNodePtr graph_node = nullptr; + Status ret = GetGraphNode(graph_id, graph_node); + if (ret != SUCCESS) { + GELOGE(ret, "[BuildGraph] graph not exist, graph_id = %u.", graph_id); + return ret; + } + + if (graph_node == nullptr) { + GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[BuildGraph] graph node is NULL, graphId = %u.", graph_id); + return GE_GRAPH_GRAPH_NODE_NULL; + } + auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); + GE_CHECK_NOTNULL(compute_graph); + + GM_RUN_AND_DUMP_PERF("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, + session_id); + + for (auto &node : compute_graph->GetAllNodes()) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { + vector node_vec = {node}; + + auto instance_ptr = ge::GELib::GetInstance(); + if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GE is not initialized"); + return GE_CLI_GE_NOT_INITIALIZED; + } + + OpsKernelInfoStorePtr kernel_info = + instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(op_desc->GetOpKernelLibName()); + if (kernel_info == nullptr) { + GELOGE(FAILED, "Get op kernel info store failed"); + return FAILED; + } + + ret = kernel_info->CompileOp(node_vec); + if (ret != SUCCESS) { + GELOGE(ret, "Compile op failed, op = %s, graph_id = %u.", op_desc->GetName().c_str(), graph_id); + return ret; + } + } + } + + GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); + + return SUCCESS; +} + Status GraphManager::BuildGraph(const GraphId &graph_id, const std::vector &inputs, GeRootModelPtr &ge_root_model, uint64_t session_id, bool async) { GELOGI("[BuildGraph] start to build graph, graph_id=%u.", graph_id); @@ -1613,7 +1672,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { SwitchDeadBranchElimination switch_dead_branch_elimination; SwitchLogicRemovePass switch_logic_remove_pass; MergePass merge_pass; - IdentifyReferencePass identify_reference_pass; CastRemovePass cast_remove_pass; TransposeTransDataPass transpose_transdata_pass; TransOpSymmetryEliminationPass symmetry_elimination_pass; @@ -1622,7 +1680,6 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); names_to_passes.emplace_back("MergePass", &merge_pass); - names_to_passes.emplace_back("IdentifyReferencePass", &identify_reference_pass); names_to_passes.emplace_back("CastRemovePass", &cast_remove_pass); names_to_passes.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); names_to_passes.emplace_back("TransOpSymmetryEliminationPass", &symmetry_elimination_pass); @@ -1638,14 +1695,32 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GELOGE(ret, "Run passes when OptimizeStage1_2 failed, ret:%u.", ret); return ret; } + // Calculate Op/Fe constantfolding cost + uint64_t op_constant_folding_cost = 0; + for (auto &it : constant_folding_pass.GetOpConstantFoldingPerfStatistic()) { + op_constant_folding_cost += it.second.second; + GELOGI("The time cost of %s constant folding is [%lu] micro second, calls is %lu.", it.first.c_str(), + it.second.second, it.second.first); + } + GEEVENT("[GEPERFTRACE] The time cost of extern constant folding is [%lu] micro second.", op_constant_folding_cost); + for (auto &it : constant_folding_pass.GetGeConstantFoldingPerfStatistic()) { + op_constant_folding_cost += it.second.second; + GELOGI("The time cost of %s constant folding is [%lu] micro second, calls is %lu.", it.first.c_str(), + it.second.second, it.second.first); + } GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OptimizeStage1_2"); PassManager graph_pass; - // the prune pass should between SwtichPass and SwitchOpPass + // the prune pass should between SwitchPass and SwitchToStreamSwitchPass GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::PrunePass", new (std::nothrow) PrunePass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::NextIterationPass", new (std::nothrow) NextIterationPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::ControlTriggerPass", new (std::nothrow) ControlTriggerPass)) - GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::SwitchOpPass", new (std::nothrow) SwitchOpPass)) + GE_CHK_STATUS_RET( + graph_pass.AddPass("OptimizeStage1_3::MergeToStreamMergePass", new (std::nothrow) MergeToStreamMergePass)) + GE_CHK_STATUS_RET( + graph_pass.AddPass("OptimizeStage1_3::SwitchToStreamSwitchPass", new (std::nothrow) SwitchToStreamSwitchPass)) + GE_CHK_STATUS_RET( + graph_pass.AddPass("OptimizeStage1_3::AttachStreamLabelPass", new (std::nothrow) AttachStreamLabelPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::IteratorOpPass", new (std::nothrow) IteratorOpPass)) GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeStage1_3::VariableRefUselessControlOutDeletePass", new (std::nothrow) VariableRefUselessControlOutDeletePass)) @@ -1657,10 +1732,9 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { GELOGE(ret, "Run passes when OptimizeStage1_3 failed, ret:%u.", ret); return ret; } - NamesToPass identity_remove_pass; GE_TIMESTAMP_START(identity_remove_pass); - IdentityPass identity_force_pass(true); // after SwitchOpPass + IdentityPass identity_force_pass(false); // after SwitchToStreamSwitchPass identity_remove_pass.emplace_back("IdentityPass", &identity_force_pass); ret = GEPass(compute_graph).Run(identity_remove_pass); GE_TIMESTAMP_END(identity_remove_pass, "GraphPrepare::IdentityRemovePass"); @@ -1692,9 +1766,11 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { ConstantFoldingPass constant_folding_pass; ReshapeRemovePass reshape_remove_pass; CondRemovePass condition_remove_pass; + BitcastPass bitcast_pass; names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); names_to_passes.emplace_back("CondRemovePass", &condition_remove_pass); + names_to_passes.emplace_back("BitcastPass", &bitcast_pass); GE_TIMESTAMP_START(names_to_passes); ret = GEPass(compute_graph).Run(names_to_passes); GE_TIMESTAMP_END(names_to_passes, "OptimizeStage2::MergedGraphNameToPasses"); @@ -1720,6 +1796,8 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::MultiBatchPass", new (std::nothrow) MultiBatchPass)) + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::RefIdentityDeleteOpPass", + new (std::nothrow) RefIdentityDeleteOpPass)) // the value of the attr is the original variable name the ref-variable ref from. // The attr will be used when allocating memory, // the node marked attr will be output to a variable instead of new-allocated memory. @@ -1729,19 +1807,31 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { new (std::nothrow) VariableRefDeleteOpPass)) GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::CompileNodesPass", new (std::nothrow) CompileNodesPass)) + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass( + "OptimizeStage2::AfterMergePasses::MarkGraphUnknownStatusPass", new (std::nothrow) MarkGraphUnknownStatusPass)) + GE_CHK_STATUS_RET( + pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::InputOutputConnectionIdentifyPass", + new (std::nothrow) InputOutputConnectionIdentifyPass)) // When the input node to be cleared is after a `Data` node, the atomic-clean-node should not be inserted. // So The ComputeGraph should not delete nodes after `AtomicAddrCleanPass` // to prevent unexpected deletion of nodes after a `Data` node GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::AtomicAddrCleanPass", new (std::nothrow) AtomicAddrCleanPass)) + GE_CHK_STATUS_RET( + pass_for_control_attr_optimize.AddPass("OptimizeStage2::AfterMergePasses::" + "EndOfSequenceAddControlPass", + new (std::nothrow) EndOfSequenceAddControlPass)) const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); if (unknown_shape_skip == nullptr) { // SubgraphPass solves memory_assign_conflicts by insert MemcpyAsync node, which depends on multi attrs and // graph-structure. So try not to add new pass after SubgraphPass. GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::SubgraphPass", - new (std::nothrow) SubgraphPass)); + new (std::nothrow) SubgraphPass)) } + // AttachStreamLabelPass modifies attr without changing structure of compute_graph + GE_CHK_STATUS_RET(pass_for_control_attr_optimize.AddPass("OptimizeStage2::ControlAttrOptimize::AttachStreamLabelPass", + new (std::nothrow) AttachStreamLabelPass)) GE_TIMESTAMP_START(pass_for_control_attr_optimize); ret = pass_for_control_attr_optimize.Run(compute_graph); @@ -1751,6 +1841,14 @@ Status GraphManager::OptimizeStage2(ge::ComputeGraphPtr &compute_graph) { return ret; } + // After while sub graph handle, mark all node rw type + auto result = graph_optimize_.HandleMemoryRWConflict(compute_graph); + if (result != SUCCESS) { + GELOGW( + "Mark node rw type failed. It will take some effect on memory_assign_conflicts handling." + "Please pay attention to it."); + } + ChangeConstTypeWhenTraining(compute_graph); ret = compute_graph->TopologicalSorting(); @@ -1777,8 +1875,6 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra GEPass ge_passes_for_shape(compute_graph); NamesToPass names_to_passes_for_shape; - IdentifyReferencePass identify_reference_pass; - names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); CastRemovePass cast_remove_pass; names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); TransposeTransDataPass transpose_transdata_pass; @@ -1821,6 +1917,8 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra after_merge_fusion_passes.AddPass("VariableRefDeleteOpPass", new (std::nothrow) VariableRefDeleteOpPass)); GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("SameTransdataBreadthFusionPass", new (std::nothrow) SameTransdataBreadthFusionPass)); + GE_CHK_STATUS_RET( + after_merge_fusion_passes.AddPass("MarkGraphUnknownStatusPass", new (std::nothrow) MarkGraphUnknownStatusPass)); GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass("AtomicAddrCleanPass", new (std::nothrow) AtomicAddrCleanPass)); GE_CHK_STATUS_RET(after_merge_fusion_passes.AddPass( "LinkGenMaskNodesPass", new (std::nothrow) LinkGenMaskNodesPass(options_.stream_max_parallel_num))); @@ -1866,7 +1964,10 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra GE_CHK_STATUS_RET(ret, "Remove isolated Constant failed, ret:%d.", ret); PassManager pass_for_optimize; - GE_CHK_STATUS_RET(pass_for_optimize.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass)); + const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); + if (unknown_shape_skip == nullptr) { + GE_CHK_STATUS_RET(pass_for_optimize.AddPass("SubgraphPass", new (std::nothrow) SubgraphPass)); + } GE_CHK_STATUS_RET(pass_for_optimize.AddPass("MultiBatchPass", new (std::nothrow) MultiBatchPass)); GE_CHK_STATUS_RET(pass_for_optimize.AddPass("CompileNodesPass", new (std::nothrow) CompileNodesPass)); GE_TIMESTAMP_START(pass_for_optimize); @@ -1906,7 +2007,7 @@ Status GraphManager::LoadGraphAsync(const GeRootModelPtr &ge_root_model, const G GE_CHECK_NOTNULL(graph_node->graph_run_async_listener_); Status ret = GraphLoader::LoadModelOnline(model_id_info.model_id, ge_root_model, graph_node->graph_run_async_listener_); - GE_TIMESTAMP_END(LoadGraph, "GraphManager::LoadGraphAsync"); + GE_TIMESTAMP_EVENT_END(LoadGraph, "GraphManager::LoadGraphAsync"); if (ret != SUCCESS) { GELOGE(ret, "[LoadGraphAsync] LoadGraphAsync Failed"); graph_node->SetRunFlag(false); @@ -2309,21 +2410,21 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra GELOGE(FAILED, "failed get dynamic shape partitioned flag on partitioned graph."); return FAILED; } - GE_TIMESTAMP_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); + GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); GE_TIMESTAMP_START(GraphPartition); ret = graph_partitioner_.Partition(compute_graph, GraphPartitioner::kPartitioning); if (ret != SUCCESS) { GELOGE(ret, "Graph partition Failed"); return ret; } - GE_TIMESTAMP_END(GraphPartition, "OptimizeSubgraph::Partition1"); + GE_TIMESTAMP_EVENT_END(GraphPartition, "OptimizeSubgraph::Partition1"); GE_TIMESTAMP_START(SetSubgraph); ret = SetSubgraph(session_id, compute_graph); if (ret != SUCCESS) { GELOGE(ret, "Graph set subgraph Failed"); return ret; } - GE_TIMESTAMP_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); + GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); ComputeGraphPtr merged_compute_graph = nullptr; std::vector merged_sub_graph_list; @@ -2342,7 +2443,7 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra sub_graph->SetSessionID(session_id); sub_graph->SetGraphID(graph_node->GetGraphId()); } - GE_TIMESTAMP_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); + GE_TIMESTAMP_EVENT_END(MergeSubgraph, "OptimizeSubgraph::MergeSubGraph"); GE_DUMP(merged_compute_graph, "mergedComputeGraph"); compute_graph = merged_compute_graph; if (!AttrUtils::SetBool(*compute_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, dynamic_shape_partitioned)) { @@ -2368,8 +2469,7 @@ Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &comp } bool is_always_dump = false; - PropertiesManager &properties_manager = PropertiesManager::Instance(); - if (!properties_manager.GetDumpOutputPath().empty()) { + if (!PropertiesManager::Instance().GetDumpProperties(session_id).GetDumpPath().empty()) { is_always_dump = true; } diff --git a/src/ge/graph/manager/graph_manager.h b/src/ge/graph/manager/graph_manager.h index 8ab28316..681efac8 100644 --- a/src/ge/graph/manager/graph_manager.h +++ b/src/ge/graph/manager/graph_manager.h @@ -102,6 +102,9 @@ class GraphManager { ge::Status BuildGraph(const GraphId &graph_id, const std::vector &inputs, GeRootModelPtr &models, uint64_t session_id = 0, bool async = false); + Status BuildGraphForUnregisteredOp(const GraphId &graph_id, const std::vector &inputs, + GeRootModelPtr &ge_root_model, uint64_t session_id); + /// /// @ingroup ge_graph /// @brief Save extra attribute to Model @@ -327,6 +330,6 @@ class GraphManager { std::mutex run_mutex_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MANAGER_H_ diff --git a/src/ge/graph/manager/graph_mem_allocator.h b/src/ge/graph/manager/graph_mem_allocator.h index 7bf82897..e4eeded3 100644 --- a/src/ge/graph/manager/graph_mem_allocator.h +++ b/src/ge/graph/manager/graph_mem_allocator.h @@ -190,6 +190,6 @@ class MemManager { std::map caching_allocator_map_; std::recursive_mutex allocator_mutex_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_MEM_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/graph_var_manager.cc b/src/ge/graph/manager/graph_var_manager.cc index 2982eb89..7ca0224b 100644 --- a/src/ge/graph/manager/graph_var_manager.cc +++ b/src/ge/graph/manager/graph_var_manager.cc @@ -91,7 +91,7 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen std::string var_key = VarKey(var_name, tensor_desc); GELOGD("VarResource::SaveVarAddr, var_key = %s", var_key.c_str()); if (var_addr_mgr_map_.count(var_key) == 0) { - uint64_t logic_address = VarManager::Instance(0)->GetVarMemLogicBase() + + uint64_t logic_address = VarManager::Instance(session_id_)->GetVarMemLogicBase() + reinterpret_cast(reinterpret_cast(address)); GELOGI("SaveVarAddr node_name %s, tensor_desc format %s, type %s.", var_name.c_str(), TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), @@ -274,7 +274,7 @@ MemResource::MemResource() : total_size_(0), var_mem_size_(0) {} Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uint64_t session_id, size_t &mem_offset) { size = (size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; uint64_t real_size = size; - total_size_ = VarManager::Instance(0)->GetVarMemMaxSize(); + total_size_ = VarManager::Instance(session_id)->GetVarMemMaxSize(); if (total_size_ < var_mem_size_) { GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_); return PARAM_INVALID; @@ -684,7 +684,8 @@ uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_ty if (mem_base == nullptr) { return nullptr; } - uint8_t *mem_addr = logic_addr + reinterpret_cast(mem_base) - VarManager::Instance(0)->GetVarMemLogicBase(); + uint8_t *mem_addr = + logic_addr + reinterpret_cast(mem_base) - VarManager::Instance(session_id_)->GetVarMemLogicBase(); return mem_addr; } diff --git a/src/ge/graph/manager/graph_var_manager.h b/src/ge/graph/manager/graph_var_manager.h index be839eee..2142d906 100644 --- a/src/ge/graph/manager/graph_var_manager.h +++ b/src/ge/graph/manager/graph_var_manager.h @@ -309,5 +309,5 @@ class VarManagerPool { std::mutex var_manager_mutex_; map var_manager_map_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_MANAGER_GRAPH_VAR_MANAGER_H_ diff --git a/src/ge/graph/manager/host_mem_manager.cc b/src/ge/graph/manager/host_mem_manager.cc new file mode 100644 index 00000000..1d35f7af --- /dev/null +++ b/src/ge/graph/manager/host_mem_manager.cc @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/manager/host_mem_manager.h" + +#include + +#include "graph/utils/tensor_utils.h" + +namespace ge { +Status HostMemoryAllocator::Allocate(std::size_t memory_size, uint8_t *memory_addr) { + GELOGI("HostMemoryAllocator::MallocMemory size= %zu.", memory_size); + return SUCCESS; +} + +Status HostMemoryAllocator::DeAllocate(uint8_t *memory_addr) { + if (rtFreeHost(memory_addr) != RT_ERROR_NONE) { + GELOGE(GE_GRAPH_FREE_FAILED, "MemoryAllocator::Free memory failed."); + return GE_GRAPH_FREE_FAILED; + } + memory_addr = nullptr; + return ge::SUCCESS; +} + +HostMemManager &HostMemManager::Instance() { + static HostMemManager mem_manager; + return mem_manager; +} + +Status HostMemManager::Initialize() { + std::lock_guard lock(mutex_); + allocator_ = std::unique_ptr(new (std::nothrow) HostMemoryAllocator()); + if (allocator_ == nullptr) { + GELOGE(GE_GRAPH_MALLOC_FAILED, "Host mem allocator init failed!"); + return GE_GRAPH_MALLOC_FAILED; + } + return SUCCESS; +} + +void HostMemManager::Finalize() noexcept { + std::lock_guard lock(mutex_); + + for (const auto &it : var_memory_base_map_) { + if (allocator_->DeAllocate(it.second.address) != SUCCESS) { + GELOGW("Host %s mem deAllocator failed!", it.first.c_str()); + } + } + var_memory_base_map_.clear(); +} + +Status HostMemManager::MallocMemoryForHostVar(const string &op_name, uint64_t tensor_size, uint8_t *&var_addr) { + std::lock_guard lock(mutex_); + if (var_memory_base_map_.find(op_name) != var_memory_base_map_.end()) { + GELOGI("Host mem for variable %s has been malloced", op_name.c_str()); + return SUCCESS; + } + GE_CHECK_NOTNULL(allocator_); + GE_CHK_STATUS(allocator_->Allocate(tensor_size, var_addr)); + HostMemInfo info(var_addr, tensor_size); + var_memory_base_map_[op_name] = info; + return SUCCESS; +} + +Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size) { + if (var_memory_base_map_.find(op_name) == var_memory_base_map_.end()) { + GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); + return INTERNAL_ERROR; + } + base_addr = reinterpret_cast(reinterpret_cast(var_memory_base_map_[op_name].address)); + data_size = var_memory_base_map_[op_name].data_size; + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/manager/host_mem_manager.h b/src/ge/graph/manager/host_mem_manager.h new file mode 100644 index 00000000..3a5a0602 --- /dev/null +++ b/src/ge/graph/manager/host_mem_manager.h @@ -0,0 +1,73 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_MANAGER_HOST_VAR_MANAGER_H_ +#define GE_GRAPH_MANAGER_HOST_VAR_MANAGER_H_ + +#include +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/ge_types.h" +#include "framework/common/l2_cache_optimize.h" +#include "graph/ge_tensor.h" +#include "graph/op_desc.h" +#include "graph/tensor.h" +#include "runtime/mem.h" + +namespace ge { +class HostMemoryAllocator { + public: + ~HostMemoryAllocator() = default; + + Status Allocate(std::size_t size, uint8_t *memory_addr); + Status DeAllocate(uint8_t *memory_addr); +}; + +struct HostMemInfo { + uint8_t *address; + uint64_t data_size; + HostMemInfo() : address(nullptr), data_size(0) {} + HostMemInfo(uint8_t *addr, uint64_t size) : address(addr), data_size(size) {} +}; + +class HostMemManager { + public: + HostMemManager() = default; + ~HostMemManager() { Finalize(); } + HostMemManager(const HostMemManager &) = delete; + HostMemManager &operator=(const HostMemManager &) = delete; + + static HostMemManager &Instance(); + Status Initialize(); + void Finalize() noexcept; + Status MallocMemoryForHostVar(const string &op_name, uint64_t tensor_size, uint8_t *&var_addr); + Status QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size); + + private: + std::unordered_map var_memory_base_map_; + std::unique_ptr allocator_; + mutable std::recursive_mutex mutex_; +}; +} // namespace ge + +#endif // GE_GRAPH_MANAGER_HOST_VAR_MANAGER_H_ diff --git a/src/ge/graph/manager/memory_api.cc b/src/ge/graph/manager/memory_api.cc new file mode 100644 index 00000000..0a98e983 --- /dev/null +++ b/src/ge/graph/manager/memory_api.cc @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "framework/memory/memory_api.h" + +#include + +#include "graph/manager/graph_mem_allocator.h" +#include "graph/manager/host_mem_manager.h" +#include "graph/manager/rdma_pool_allocator.h" +#include "hccl/base.h" +#include "hccl/hcom.h" + +namespace ge { +Status InitRdmaPool(size_t size, rtMemType_t mem_type) { + GELOGD("InitRdmaPool in"); + return MemManager::Instance().RdmaPoolInstance(mem_type).InitMemory(size); +} + +Status RdmaRemoteRegister(const std::vector &var_info, rtMemType_t mem_type) { + GELOGD("Start to register rdma memory with host var size %zu", var_info.size()); + uint64_t device_base = 0; + uint64_t device_size = 0; + GE_CHK_STATUS_RET(MemManager::Instance().RdmaPoolInstance(mem_type).GetBaseAddr(device_base, device_size)); + return SUCCESS; +} + +Status GetVarBaseAddrAndSize(const string &var_name, uint64_t &base_addr, uint64_t &var_size) { + GELOGD("GetVarBaseAddrAndSize in"); + return HostMemManager::Instance().QueryVarMemInfo(var_name, base_addr, var_size); +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/manager/model_manager/event_manager.h b/src/ge/graph/manager/model_manager/event_manager.h index bdf0633a..a20afead 100644 --- a/src/ge/graph/manager/model_manager/event_manager.h +++ b/src/ge/graph/manager/model_manager/event_manager.h @@ -92,6 +92,6 @@ class EventManager { std::vector event_list_; bool inited_; uint32_t current_idx_; -}; // EventManager -}; // namespace ge +}; // EventManager +} // namespace ge #endif // GE_GRAPH_MANAGER_MODEL_MANAGER_EVENT_MANAGER_H_ diff --git a/src/ge/graph/manager/rdma_pool_allocator.cc b/src/ge/graph/manager/rdma_pool_allocator.cc new file mode 100644 index 00000000..1daeafb8 --- /dev/null +++ b/src/ge/graph/manager/rdma_pool_allocator.cc @@ -0,0 +1,179 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/manager/rdma_pool_allocator.h" +#include "framework/common/debug/ge_log.h" +#include "graph/manager/graph_mem_allocator.h" + +namespace { +const size_t kAlignedSize = 512; +const float kSplitThreshold = 0.5; + +inline size_t GetAlignedBlockSize(size_t size) { + if (size == 0) { + return kAlignedSize; + } + return kAlignedSize * ((size + kAlignedSize - 1) / kAlignedSize); +} + +inline bool ShouldSplit(const ge::Block *block, size_t size) { + return static_cast(size) <= (static_cast(block->size) * kSplitThreshold); +} + +inline bool CanMerge(ge::Block *block) { return block != nullptr && !block->allocated; } +} // namespace + +namespace ge { +RdmaPoolAllocator::RdmaPoolAllocator(rtMemType_t memory_type) + : memory_type_(memory_type), block_bin_(BlockBin([](const Block *left, const Block *right) { + if (left->size != right->size) { + return left->size < right->size; + } + return reinterpret_cast(left->ptr) < reinterpret_cast(right->ptr); + })) {} + +Status RdmaPoolAllocator::Initialize() { + memory_allocator_ = MemManager::Instance(memory_type_); + if (memory_allocator_ == nullptr) { + return ge::FAILED; + } + return ge::SUCCESS; +} +void RdmaPoolAllocator::Finalize() { + for (auto it = allocated_blocks_.begin(); it != allocated_blocks_.end();) { + auto block = it->second; + allocated_blocks_.erase(it); + delete block; + } + for (auto it = block_bin_.begin(); it != block_bin_.end();) { + auto block = *it; + block_bin_.erase(it); + delete block; + } + + if (rdma_base_addr_ != nullptr) { + if (memory_allocator_->FreeMemory(rdma_base_addr_) != SUCCESS) { + GELOGW("Free rdma pool memory failed"); + } + } +} + +Status RdmaPoolAllocator::InitMemory(size_t mem_size, uint32_t device_id) { + if (rdma_base_addr_ != nullptr) { + GELOGE(GE_MULTI_INIT, "Rdma pool has been malloced"); + return GE_MULTI_INIT; + } + const std::string purpose = "Memory for rdma pool."; + std::lock_guard lock(mutex_); + rdma_base_addr_ = memory_allocator_->MallocMemory(purpose, mem_size, device_id); + if (rdma_base_addr_ == nullptr) { + GELOGE(GE_GRAPH_MALLOC_FAILED, "Rdma pool memory malloc failed"); + return GE_GRAPH_MALLOC_FAILED; + } + rdma_mem_size_ = mem_size; + // Init with a base block. + auto *base_block = new (std::nothrow) Block(device_id, mem_size, rdma_base_addr_); + if (base_block == nullptr) { + GELOGE(GE_GRAPH_MALLOC_FAILED, "Block malloc failed"); + return GE_GRAPH_MALLOC_FAILED; + } + block_bin_.insert(base_block); + return SUCCESS; +} + +uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { + auto aligned_size = GetAlignedBlockSize(size); + Block key(device_id, aligned_size, nullptr); + std::lock_guard lock(mutex_); + auto it = block_bin_.lower_bound(&key); + if (it != block_bin_.end()) { + Block *block = *it; + block_bin_.erase(it); + block->allocated = true; + if (block->ptr == nullptr) { + GELOGE(INTERNAL_ERROR, "Rdmapool memory address is nullptr."); + return nullptr; + } + allocated_blocks_.emplace(block->ptr, block); + GELOGI("Find block size = %zu", block->size); + + if (ShouldSplit(block, aligned_size)) { + auto *new_block = + new (std::nothrow) Block(device_id, block->size - aligned_size, nullptr, block->ptr + aligned_size); + if (new_block == nullptr) { + GELOGW("Block split failed"); + return block->ptr; + } + new_block->next = block->next; + if (block->next != nullptr) { + block->next->prev = new_block; + } + new_block->prev = block; + block->next = new_block; + block->size = aligned_size; + block_bin_.insert(new_block); + } + return block->ptr; + } + return nullptr; +} + +Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { + GELOGI("Free device id = %u", device_id); + if (memory_addr == nullptr) { + GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); + return GE_GRAPH_FREE_FAILED; + } + + std::lock_guard lock(mutex_); + auto it = allocated_blocks_.find(memory_addr); + if (it == allocated_blocks_.end()) { + GELOGE(PARAM_INVALID, "Invalid memory pointer"); + return PARAM_INVALID; + } + Block *block = it->second; + block->allocated = false; + allocated_blocks_.erase(it); + block_bin_.insert(block); + // Each time merge with its pre and next. + MergeBlockNearby(block, block->next); + MergeBlockNearby(block->prev, block); + return SUCCESS; +} + +void RdmaPoolAllocator::MergeBlockNearby(Block *pre_block, Block *block) { + if (!(CanMerge(pre_block) && CanMerge(block))) { + return; + } + pre_block->size += block->size; + pre_block->next = block->next; + if (block->next != nullptr) { + block->next->prev = pre_block; + } + block_bin_.erase(block); + delete block; +} + +Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { + if (rdma_base_addr_ == nullptr) { + GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); + return INTERNAL_ERROR; + } + base_addr = reinterpret_cast(reinterpret_cast(rdma_base_addr_)); + mem_size = rdma_mem_size_; + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/manager/rdma_pool_allocator.h b/src/ge/graph/manager/rdma_pool_allocator.h new file mode 100644 index 00000000..59d33916 --- /dev/null +++ b/src/ge/graph/manager/rdma_pool_allocator.h @@ -0,0 +1,71 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_MANAGER_RDMA_POOL_ALLOCATOR_H_ +#define GE_GRAPH_MANAGER_RDMA_POOL_ALLOCATOR_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "framework/common/ge_inner_error_codes.h" +#include "graph/manager/block_memory.h" +#include "graph/node.h" +#include "runtime/mem.h" + +namespace ge { +class MemoryAllocator; + +class RdmaPoolAllocator { + public: + explicit RdmaPoolAllocator(rtMemType_t memory_type); + + RdmaPoolAllocator(const RdmaPoolAllocator &) = delete; + + RdmaPoolAllocator &operator=(const RdmaPoolAllocator &) = delete; + + ~RdmaPoolAllocator() { Finalize(); } + + Status Initialize(); + void Finalize(); + + Status InitMemory(size_t mem_size, uint32_t device_id = 0); + + uint8_t *Malloc(size_t size, uint32_t device_id = 0); + + Status Free(uint8_t *memory_addr, uint32_t device_id = 0); + + Status GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size); + + private: + void MergeBlockNearby(Block *pre_block, Block *block); + + rtMemType_t memory_type_; + size_t rdma_mem_size_ = 0; // Total rdma memory size to be allocated. + uint8_t *rdma_base_addr_ = nullptr; + MemoryAllocator *memory_allocator_ = nullptr; + BlockBin block_bin_; // Save all rdma blocks. + std::unordered_map allocated_blocks_; + // lock around all operations + mutable std::recursive_mutex mutex_; +}; +} // namespace ge + +#endif // GE_GRAPH_MANAGER_RDMA_POOL_ALLOCATOR_H_ diff --git a/src/ge/graph/manager/trans_var_data_utils.cc b/src/ge/graph/manager/trans_var_data_utils.cc index e8444c53..60a0d0db 100644 --- a/src/ge/graph/manager/trans_var_data_utils.cc +++ b/src/ge/graph/manager/trans_var_data_utils.cc @@ -397,10 +397,11 @@ Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeT uint8_t *src_addr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); - uint8_t *mem_addr = src_addr - - static_cast(reinterpret_cast(VarManager::Instance(0)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); + uint8_t *mem_addr = + src_addr - + static_cast(reinterpret_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + + static_cast( + reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); GE_CHK_RT_RET(rtMallocHost(reinterpret_cast(host_addr), src_tensor_size)); GE_CHK_RT_RET(rtMemcpy(*host_addr, src_tensor_size, mem_addr, src_tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); @@ -413,10 +414,11 @@ Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8 const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { uint8_t *dst_addr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, dst_tensor_desc, &dst_addr)); - uint8_t *mem_addr = dst_addr - - static_cast(reinterpret_cast(VarManager::Instance(0)->GetVarMemLogicBase())) + - static_cast( - reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); + uint8_t *mem_addr = + dst_addr - + static_cast(reinterpret_cast(VarManager::Instance(session_id)->GetVarMemLogicBase())) + + static_cast( + reinterpret_cast(VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM))); GE_CHK_RT_RET(rtMemcpy(mem_addr, addr_size, host_addr, addr_size, RT_MEMCPY_HOST_TO_DEVICE)); GELOGI("SyncTensorToDevice var_name %s, addr_size %u", var_name.c_str(), addr_size); @@ -442,7 +444,7 @@ Status TransVarDataUtils::TransAllVarData(const vector &variable_nodes, rtError_t rt_ret = rtCtxSetCurrent(ctx); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret); - return RT_FAILED; + return RT_ERROR_TO_GE_STATUS(rt_ret); } uint32_t allocated_graph_id = 0; Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id); diff --git a/src/ge/graph/manager/util/hcom_util.cc b/src/ge/graph/manager/util/hcom_util.cc index 4f6fe591..5f31c982 100644 --- a/src/ge/graph/manager/util/hcom_util.cc +++ b/src/ge/graph/manager/util/hcom_util.cc @@ -24,7 +24,6 @@ #include "graph/utils/type_utils.h" namespace ge { - Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos) { GE_CHECK_NOTNULL(op_desc); @@ -101,6 +100,12 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); + // dynamic shape hccl op get size from output tensor desc + if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { + GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); + GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size), + "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); + } GE_IF_BOOL_EXEC( op_desc->GetType() == HCOMREDUCESCATTER, int32_t rank_size = 0; @@ -114,6 +119,8 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType total_size = total_size + block_size; continue;); int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); + GELOGD("hcom util node %s inputsize %ld, shapesize %ld, datasize %d.", op_desc->GetName().c_str(), input_size, + shape_size, size); GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), "Product of shape size and size beyond INT64_MAX"); GE_IF_BOOL_EXEC(is_allgather, block_size = shape_size * size;); diff --git a/src/ge/graph/manager/util/hcom_util.h b/src/ge/graph/manager/util/hcom_util.h index 40aac3e5..e31e3ef0 100644 --- a/src/ge/graph/manager/util/hcom_util.h +++ b/src/ge/graph/manager/util/hcom_util.h @@ -144,8 +144,6 @@ class HcomOmeUtil { /// static Status GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, std::vector &kernel_hccl_infos); - - private: /// /// @ingroup domi_ome /// @brief GetHcomCount @@ -154,6 +152,8 @@ class HcomOmeUtil { /// static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, int &count); + + private: /// /// @ingroup domi_ome /// @brief GetHorovodCount diff --git a/src/ge/graph/manager/util/rt_context_util.cc b/src/ge/graph/manager/util/rt_context_util.cc index 05120f6a..e6344539 100644 --- a/src/ge/graph/manager/util/rt_context_util.cc +++ b/src/ge/graph/manager/util/rt_context_util.cc @@ -19,13 +19,30 @@ #include "framework/common/debug/ge_log.h" namespace ge { -void RtContextUtil::AddrtContext(rtContext_t context) { rtContexts_.emplace_back(context); } +void RtContextUtil::AddRtContext(uint64_t session_id, rtContext_t context) { + std::lock_guard lock(ctx_mutex_); + rt_contexts_[session_id].emplace_back(context); +} + +void RtContextUtil::DestroyRtContexts(uint64_t session_id) { + std::lock_guard lock(ctx_mutex_); + auto &contexts = rt_contexts_[session_id]; + DestroyRtContexts(session_id, contexts); +} + +void RtContextUtil::DestroyAllRtContexts() { + std::lock_guard lock(ctx_mutex_); + for (auto &ctx_pair : rt_contexts_) { + DestroyRtContexts(ctx_pair.first, ctx_pair.second); + } + rt_contexts_.clear(); +} -void RtContextUtil::DestroyrtContexts() { - GELOGI("The size of runtime context handle is %zu.", rtContexts_.size()); - for (auto &rtContext : rtContexts_) { +void RtContextUtil::DestroyRtContexts(uint64_t session_id, std::vector &contexts) { + GELOGI("Runtime context handle number of session %lu is %zu.", session_id, contexts.size()); + for (auto &rtContext : contexts) { (void)rtCtxDestroy(rtContext); } - rtContexts_.clear(); + contexts.clear(); } } // namespace ge diff --git a/src/ge/graph/manager/util/rt_context_util.h b/src/ge/graph/manager/util/rt_context_util.h index 93db9882..58cc0803 100644 --- a/src/ge/graph/manager/util/rt_context_util.h +++ b/src/ge/graph/manager/util/rt_context_util.h @@ -18,6 +18,8 @@ #define GE_GRAPH_MANAGER_UTIL_RT_CONTEXT_UTIL_H_ #include +#include +#include #include "runtime/context.h" @@ -29,13 +31,14 @@ class RtContextUtil { return instance; } - void AddrtContext(rtContext_t context); + void AddRtContext(uint64_t session_id, rtContext_t context); const rtContext_t GetNormalModeContext() const { return before_prerun_ctx_; } void SetNormalModeContext(rtContext_t context) { before_prerun_ctx_ = context; } - void DestroyrtContexts(); + void DestroyRtContexts(uint64_t session_id); + void DestroyAllRtContexts(); RtContextUtil &operator=(const RtContextUtil &) = delete; RtContextUtil(const RtContextUtil &RtContextUtil) = delete; @@ -44,8 +47,12 @@ class RtContextUtil { RtContextUtil() = default; ~RtContextUtil() {} - std::vector rtContexts_; + void DestroyRtContexts(uint64_t session_id, std::vector &contexts); + + std::map> rt_contexts_; rtContext_t before_prerun_ctx_ = nullptr; + + std::mutex ctx_mutex_; }; } // namespace ge diff --git a/src/ge/graph/optimize/graph_optimize.cc b/src/ge/graph/optimize/graph_optimize.cc index b42c2e01..09acae33 100644 --- a/src/ge/graph/optimize/graph_optimize.cc +++ b/src/ge/graph/optimize/graph_optimize.cc @@ -299,4 +299,36 @@ void GraphOptimize::TranFrameOp(ComputeGraphPtr &compute_graph) { } } } + +Status GraphOptimize::IdentifyReference(ComputeGraphPtr &compute_graph) { + for (auto &node : compute_graph->GetAllNodes()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto input_name_index = op_desc->GetAllInputName(); + bool is_ref = false; + for (const auto &name_index : input_name_index) { + const int out_index = op_desc->GetOutputIndexByName(name_index.first); + if (out_index != -1) { + auto input_desc = op_desc->GetInputDesc(name_index.second); + input_desc.SetRefPortByIndex({name_index.second}); + op_desc->UpdateInputDesc(name_index.second, input_desc); + GELOGI("SetRefPort: set op[%s] input desc[%u-%s] ref.", op_desc->GetName().c_str(), name_index.second, + name_index.first.c_str()); + auto output_desc = op_desc->GetOutputDesc(static_cast(out_index)); + output_desc.SetRefPortByIndex({name_index.second}); + op_desc->UpdateOutputDesc(static_cast(out_index), output_desc); + GELOGI("SetRefPort: set op[%s] output desc[%u-%s] ref.", op_desc->GetName().c_str(), out_index, + name_index.first.c_str()); + is_ref = true; + } + } + if (is_ref) { + AttrUtils::SetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + GELOGI("param [node] %s is reference node, set attribute %s to be true.", node->GetName().c_str(), + ATTR_NAME_REFERENCE.c_str()); + } + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/optimize/graph_optimize.h b/src/ge/graph/optimize/graph_optimize.h index 72709932..f3eb2009 100644 --- a/src/ge/graph/optimize/graph_optimize.h +++ b/src/ge/graph/optimize/graph_optimize.h @@ -60,13 +60,20 @@ class GraphOptimize { const std::map> &GetSummaryOutputIndexes() const { return summary_output_indexes_; - } + } // lint !e1073 void ClearSummaryOutputIndexes() { summary_output_indexes_.clear(); } // handle summary node before preRun graph Status HandleSummaryOp(ComputeGraphPtr &compute_graph); + // Identify reference node before optimize subgraph + Status IdentifyReference(ComputeGraphPtr &compute_graph); + + Status HandleMemoryRWConflict(ComputeGraphPtr &compute_graph); + + Status CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_conflict); + void TranFrameOp(ComputeGraphPtr &compute_graph); private: @@ -85,5 +92,5 @@ class GraphOptimize { std::map> summary_output_indexes_ = {}; std::string func_bin_path_; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_OPTIMIZE_GRAPH_OPTIMIZE_H_ diff --git a/src/ge/graph/optimize/mem_rw_conflict_optimize.cc b/src/ge/graph/optimize/mem_rw_conflict_optimize.cc new file mode 100644 index 00000000..f75565ba --- /dev/null +++ b/src/ge/graph/optimize/mem_rw_conflict_optimize.cc @@ -0,0 +1,712 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include +#include + +#include "common/ge/ge_util.h" +#include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/optimize/graph_optimize.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +namespace { +using namespace ge; +const int kIdentityAnchorIndex = 0; +// rw type of input. +enum class InputRWType { + kReadOnly, // Normal op input only read + kWriteable, // Op like Assign/ApplyMomentum + kScopeWriteable, // Op like hcom_allreduce, it will modify input ,but not expect take effect on pre ouput + kInvalidRWType +}; +// rw type of output +enum class OutputRWType { + kReadOnly, // 1.const output 2.not ref output but has several peer output + kSoftRead, // not ref output but only has one output node + kWriteable, // ref output. Like Assign/ApplyMomentum + kInvalidRWType +}; +// input and output rw_type of one node. key is anchor_idx, value is rw_type +struct NodeInputOutputRWType { + map input_rw_type_map; + map output_rw_type_map; +}; +// input and output rw_type of node in current graph +map node_rwtype_map_; + +/// +/// @brief Convert input rw_type enum to string. For log print. +/// @param rw_type +/// @return rw_type_name +/// +static std::string InputRWTypeToSerialString(InputRWType rw_type) { + const static char *names[4] = {"ReadOnly", "Writeable", "ScopeWriteable", "InvalidRWType"}; + return names[static_cast(rw_type)]; +} + +/// +/// @brief Convert output rw_type enum to string. For log print. +/// @param rw_type +/// @return rw_type_name +/// +static std::string OutputRWTypeToSerialString(OutputRWType rw_type) { + const static char *names[4] = {"ReadOnly", "SoftRead", "Writeable", "InvalidRWType"}; + return names[static_cast(rw_type)]; +} + +OutputRWType GetSingleNodeOutputRWTypeByIndex(const Node &node, uint32_t index) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return OutputRWType::kInvalidRWType; + } + if (op_desc->GetType() == VARIABLE) { + return OutputRWType::kWriteable; + } + // check if it is ref output + auto input_names = op_desc->GetAllInputName(); + for (auto &input_name_2_idx : input_names) { + if (op_desc->GetOutputNameByIndex(index) == input_name_2_idx.first) { + return OutputRWType::kWriteable; + } + } + // check if it is ref switch + std::string type; + if ((node.GetType() == FRAMEWORK_OP_TYPE) && AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type) && + (type == REFSWITCH)) { + return OutputRWType::kWriteable; + } + + if (op_desc->GetType() == CONSTANT || op_desc->GetType() == CONSTANTOP) { + return OutputRWType::kReadOnly; + } + auto out_data_anchor = node.GetOutDataAnchor(index); + if (out_data_anchor == nullptr) { + return OutputRWType::kInvalidRWType; + } + if (out_data_anchor->GetPeerInDataNodesSize() > 1) { + return OutputRWType::kReadOnly; + } else { + return OutputRWType::kSoftRead; + } +} + +/// +/// @brief Get input rw_type of one node with sub graph. It will return rw_type after solve conflict scene. +/// @param rw_type_set +/// @return +/// +InputRWType GetInputRwTypeInConflict(std::set rw_type_set) { + // for input rw type calc + int total_rw_type = 0; + for (auto rw : rw_type_set) { + total_rw_type += rw; + } + switch (total_rw_type) { + case 0: + return InputRWType::kReadOnly; + case 2: + return InputRWType::kScopeWriteable; + case 3: + return InputRWType::kWriteable; + case 5: + return InputRWType::kInvalidRWType; + default: + return InputRWType::kInvalidRWType; + } +} + +NodePtr CreateIdentityAfterSrcNode(const Node &src_node, int out_anchor_idx) { + if (src_node.GetOpDesc() == nullptr) { + return nullptr; + } + static std::atomic identity_num(0); + auto next_num = identity_num.fetch_add(1); + // 1. create new identity op desc + string identity_name = src_node.GetName() + "_" + IDENTITY + std::to_string(next_num); + auto identity_opdesc = MakeShared(identity_name, IDENTITY); + if (identity_opdesc == nullptr) { + GELOGE(OUT_OF_MEMORY, "Failed to insert identity node, name %s", identity_name.c_str()); + return nullptr; + } + auto data_desc = src_node.GetOpDesc()->GetOutputDesc(out_anchor_idx); + // 2. add input_desc & output_desc for new identity + Status ret = identity_opdesc->AddInputDesc(data_desc); + if (ret != SUCCESS) { + GELOGE(ret, "Add Input desc failed for new identity %s.", identity_name.c_str()); + return nullptr; + } + ret = identity_opdesc->AddOutputDesc(data_desc); + if (ret != SUCCESS) { + GELOGE(ret, "Add Output desc failed for new Identity %s.", identity_name.c_str()); + return nullptr; + } + GELOGI("Insert new Identity node %s.", identity_name.c_str()); + auto graph = src_node.GetOwnerComputeGraph(); + if (graph == nullptr) { + GELOGE(GRAPH_PARAM_INVALID, "Node %s owner compute graph is null.", src_node.GetName().c_str()); + return nullptr; + } + return graph->AddNode(identity_opdesc); +} + +OutputRWType GetOutputRWTypeByIndex(const Node &node, uint32_t index) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return OutputRWType::kInvalidRWType; + } + if (op_desc->GetType() == WHILE) { + return OutputRWType::kSoftRead; + } + vector subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + // single node without sub graph + return GetSingleNodeOutputRWTypeByIndex(node, index); + } else { + // node with sub graph + auto output_node_vec = NodeUtils::GetSubgraphOutputNodes(node); + auto output_rw_type = OutputRWType::kInvalidRWType; + if (output_node_vec.size() == 1) { + // find rw type from map. + auto iter = node_rwtype_map_.find(output_node_vec.at(0)->GetName()); + if (iter == node_rwtype_map_.end()) { + GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.", + output_node_vec.at(0)->GetName().c_str()); + return OutputRWType::kInvalidRWType; + } + auto index_2_output_rw_type = iter->second.output_rw_type_map.find(index); + if (index_2_output_rw_type == iter->second.output_rw_type_map.end()) { + GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.", + output_node_vec.at(0)->GetName().c_str()); + return OutputRWType::kInvalidRWType; + } + output_rw_type = index_2_output_rw_type->second; + } else { + output_rw_type = OutputRWType::kSoftRead; + } + // check peer input + auto out_data_anchor = node.GetOutDataAnchor(index); + if (out_data_anchor == nullptr) { + return OutputRWType::kInvalidRWType; + } + if (out_data_anchor->GetPeerInDataNodesSize() > 1) { + return OutputRWType::kReadOnly; + } else { + return output_rw_type; + } + } +} + +InputRWType GetSingleNodeInputRWTypeByIndex(const Node &node, uint32_t index) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return InputRWType::kInvalidRWType; + } + if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMALLGATHER || + op_desc->GetType() == HCOMREDUCESCATTER) { + return InputRWType::kScopeWriteable; + } + // check if it is ref input + auto output_names = op_desc->GetAllOutputName(); + for (auto &output_name_2_idx : output_names) { + if (op_desc->GetInputNameByIndex(index) == output_name_2_idx.first) { + return InputRWType::kWriteable; + } + } + // check if it is ref switch todo + std::string type; + if ((node.GetType() == FRAMEWORK_OP_TYPE) && (AttrUtils::GetStr(op_desc, ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) && + (type == REFSWITCH) && (index == 0)) { + return InputRWType::kWriteable; + } + + return InputRWType::kReadOnly; +} + +InputRWType GetInputRWTypeByIndex(const Node &node, uint32_t index) { + auto op_desc = node.GetOpDesc(); + if (op_desc == nullptr) { + return InputRWType::kInvalidRWType; + } + if (op_desc->GetType() == WHILE) { + return InputRWType::kScopeWriteable; + } + vector subgraph_names = op_desc->GetSubgraphInstanceNames(); + if (subgraph_names.empty()) { + // single node without sub graph + return GetSingleNodeInputRWTypeByIndex(node, index); + } else { + // node with sub graph + std::set node_rw_type_set; + auto data_node_vec = NodeUtils::GetSubgraphDataNodesByIndex(node, index); + // get all input data node in subgraph + std::set anchor_rw_type_set; + for (const auto &data_node : data_node_vec) { + // Data only has 1 out data anchor. Here just take first out data anchor. And index 0 is valid. + auto out_data_anchor = data_node->GetOutDataAnchor(0); + if (out_data_anchor == nullptr) { + continue; + } + auto data_op_desc = data_node->GetOpDesc(); + if (data_op_desc == nullptr) { + continue; + } + // find rw type from map. + auto iter = node_rwtype_map_.find(data_op_desc->GetName()); + if (iter == node_rwtype_map_.end()) { + GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.", + data_op_desc->GetName().c_str()); + return InputRWType::kInvalidRWType; + } + auto input_rw_type = iter->second.input_rw_type_map.find(out_data_anchor->GetIdx()); + if (input_rw_type == iter->second.input_rw_type_map.end()) { + GELOGW("Can not find rw type of node %s from map.It could take some effect on following preprocess.", + data_op_desc->GetName().c_str()); + return InputRWType::kInvalidRWType; + } + anchor_rw_type_set.emplace(static_cast(input_rw_type->second)); + } + return GetInputRwTypeInConflict(anchor_rw_type_set); + } +} + +/// +/// @brief Reverse traversal all subgraph and mark rw_type for Data/Netoutput. +/// @param sub_graph_vecgs +/// +Status MarkRWTypeForSubgraph(vector> sub_graph_vec) { + for (auto iter = sub_graph_vec.rbegin(); iter != sub_graph_vec.rend(); ++iter) { + auto parent_node = (*iter)->GetParentNode(); + if (parent_node == nullptr) { + GELOGD("Current sub graph has no parent node. Ignore it."); + continue; + } + if (parent_node->GetType() == WHILE) { + continue; + } + for (const auto &node : (*iter)->GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + if (node->GetType() == DATA) { + // calc all input_rw_type of peer output , as input_rw_type of DATA. Index 0 is valid. + auto out_data_anchor = node->GetOutDataAnchor(0); + GE_CHECK_NOTNULL(out_data_anchor); + std::set anchor_rw_type_set; + for (const auto peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + auto peer_in_node = peer_in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_node); + auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_anchor->GetIdx()); + GELOGD("Input rw type of Node %s %dth input anchor is %s", peer_in_node->GetName().c_str(), + peer_in_anchor->GetIdx(), InputRWTypeToSerialString(input_rw_type).c_str()); + anchor_rw_type_set.emplace(static_cast(input_rw_type)); + } + auto anchor_rw_type = GetInputRwTypeInConflict(anchor_rw_type_set); + GELOGD("Input rw type of Node %s is %s", node->GetName().c_str(), + InputRWTypeToSerialString(anchor_rw_type).c_str()); + map input_rw_type_map{std::make_pair(0, anchor_rw_type)}; + NodeInputOutputRWType data_rw_type{input_rw_type_map}; + node_rwtype_map_.emplace(std::make_pair(node->GetName(), data_rw_type)); + } + + if (node->GetType() == NETOUTPUT) { + // calc all output_rw_type of peer input , as output_rw_type of DATA + map output_rw_type_map; + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + GE_CHECK_NOTNULL(in_data_anchor); + auto pre_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + auto pre_node = pre_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(pre_node); + + auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); + GELOGD("Output rw type of Node %s %dth output anchor is %s", pre_node->GetName().c_str(), + pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str()); + if (pre_output_rw_type == OutputRWType::kWriteable) { + // insert identity + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(pre_out_anchor, in_data_anchor, identity_node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to insert identity"); + return ret; + } + GELOGI("InsertNode %s between %s and %s successfully.", identity_node->GetName().c_str(), + pre_node->GetName().c_str(), node->GetName().c_str()); + } + output_rw_type_map.emplace(std::make_pair(in_data_anchor->GetIdx(), OutputRWType::kSoftRead)); + } + NodeInputOutputRWType output_rw_type{{}, output_rw_type_map}; + node_rwtype_map_.emplace(std::make_pair(node->GetName(), output_rw_type)); + } + } + } + return SUCCESS; +} + +/// +/// @brief Check identity is near subgraph. +/// Eg. As output of Data node in subgraph +/// or as input of Netoutput of subgraph +/// or as input of one node with subgraph +/// or as output of one node with subgraph +/// @param node +/// @return is_near_subgraph +/// +bool CheckIdentityIsNearSubgraph(const NodePtr &node) { + for (const auto &in_node : node->GetInDataNodes()) { + auto in_node_opdesc = in_node->GetOpDesc(); + if (in_node_opdesc == nullptr) { + continue; + } + // near entrance of subgraph + if (in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) { + return true; + } + // near subgraph + if (!in_node_opdesc->GetSubgraphInstanceNames().empty()) { + return true; + } + } + + for (const auto &out_node : node->GetOutDataNodes()) { + auto out_node_opdesc = out_node->GetOpDesc(); + if (out_node_opdesc == nullptr) { + continue; + } + // near output of subgraph + if (out_node->GetType() == NETOUTPUT && NodeUtils::IsSubgraphOutput(out_node)) { + return true; + } + // near subgraph + if (!out_node_opdesc->GetSubgraphInstanceNames().empty()) { + return true; + } + } + return false; +} +enum ConflictResult { DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY }; +vector> output_2_input_rwtype = {{DO_NOTHING, WRONG_GRAPH, INSERT_IDENTITY}, + {DO_NOTHING, WRONG_GRAPH, DO_NOTHING}, + {DO_NOTHING, DO_NOTHING, INSERT_IDENTITY}}; +ConflictResult GetConflictResultBetweenNode(const OutputRWType output_rw_type, const InputRWType input_rw_type) { + if (output_rw_type == OutputRWType::kInvalidRWType || input_rw_type == InputRWType::kInvalidRWType) { + return WRONG_GRAPH; + } + auto n = static_cast(output_rw_type); + auto m = static_cast(input_rw_type); + // no need to check index or container, because container and index is all defined. + return output_2_input_rwtype[n][m]; +} + +/// +/// @brief Keep identity_node which near subgraph or has multi output +/// @param node +/// @return +/// +Status RemoveNoUseIdentity(const NodePtr &node) { + if (node->GetInDataNodes().empty()) { + return SUCCESS; + } + if (node->GetOutDataNodesSize() > 1) { + return SUCCESS; + } + if (node->GetOutDataNodesSize() == 1 && node->GetOutDataNodes().at(0)->GetType() == STREAMMERGE) { + return SUCCESS; + } + if (CheckIdentityIsNearSubgraph(node)) { + return SUCCESS; + } + auto out_data_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex); + GE_CHECK_NOTNULL(out_data_anchor); + GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex)); + auto pre_out_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_anchor); + auto pre_node = pre_out_anchor->GetOwnerNode(); + auto pre_output_rw_type = GetOutputRWTypeByIndex(*pre_node, pre_out_anchor->GetIdx()); + + ConflictResult conflict_result = WRONG_GRAPH; + if (!out_data_anchor->GetPeerInDataAnchors().empty()) { + auto peer_in_data_anchor = out_data_anchor->GetPeerInDataAnchors().at(0); + GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_node = peer_in_data_anchor->GetOwnerNode(); + auto peer_input_rw_type = GetInputRWTypeByIndex(*peer_node, peer_in_data_anchor->GetIdx()); + + GELOGD("Pre Node %s %dth output rw type is %s, peer node %s %dth input rw type is %s.", pre_node->GetName().c_str(), + pre_out_anchor->GetIdx(), OutputRWTypeToSerialString(pre_output_rw_type).c_str(), + peer_node->GetName().c_str(), peer_in_data_anchor->GetIdx(), + InputRWTypeToSerialString(peer_input_rw_type).c_str()); + conflict_result = GetConflictResultBetweenNode(pre_output_rw_type, peer_input_rw_type); + } else { + // identity node has no out data node, it can be removed + conflict_result = DO_NOTHING; + } + + switch (conflict_result) { + case DO_NOTHING: { + GELOGI("No need insert Identity. Node %s need to remove.", node->GetName().c_str()); + auto ret = GraphUtils::IsolateNode(node, {0}); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str()); + return ret; + } + ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to isolate node %s.", node->GetName().c_str()); + return ret; + } + GELOGI("Pre node is %s and %dth output rw type is %s. Isolate and remove Identity node %s.", + pre_node->GetName().c_str(), pre_out_anchor->GetIdx(), + OutputRWTypeToSerialString(pre_output_rw_type).c_str(), node->GetName().c_str()); + return SUCCESS; + } + default: + return SUCCESS; + } + return SUCCESS; +} + +Status SplitIdentity(const NodePtr &node) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() != IDENTITY) { + return SUCCESS; + } + auto out_data_anchor = node->GetOutDataAnchor(kIdentityAnchorIndex); + GE_CHECK_NOTNULL(out_data_anchor); + if (out_data_anchor->GetPeerInDataNodesSize() <= 1) { + return SUCCESS; + } + // get pre node and next node of identity + GE_CHECK_NOTNULL(node->GetInDataAnchor(kIdentityAnchorIndex)); + auto pre_out_data_anchor = node->GetInDataAnchor(kIdentityAnchorIndex)->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(pre_out_data_anchor); + auto pre_node = pre_out_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(pre_node); + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + // 1.check peer in node RW type. + GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_in_data_node = peer_in_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_data_node); + auto input_rw_type = GetInputRWTypeByIndex(*peer_in_data_node, peer_in_data_anchor->GetIdx()); + auto ret = out_data_anchor->Unlink(peer_in_data_anchor); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to unlink from %s %dth out to %s.", node->GetName().c_str(), out_data_anchor->GetIdx(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return ret; + } + if (input_rw_type == InputRWType::kScopeWriteable || input_rw_type == InputRWType::kWriteable) { + auto identity_node = CreateIdentityAfterSrcNode(*pre_node, pre_out_data_anchor->GetIdx()); + GE_CHECK_NOTNULL(identity_node); + ret = GraphUtils::AddEdge(pre_out_data_anchor, identity_node->GetInDataAnchor(kIdentityAnchorIndex)); + if (ret != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", + pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return INTERNAL_ERROR; + } + ret = GraphUtils::AddEdge(identity_node->GetOutDataAnchor(kIdentityAnchorIndex), peer_in_data_anchor); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to insert Identity between node %s and %s", + pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return INTERNAL_ERROR; + } + // 2. copy in-control-edge from dst to Identity + GraphUtils::CopyInCtrlEdges(peer_in_data_node, identity_node); + GELOGI("Node %s intput rw type is %s. Insert Identity between %s and %s.", peer_in_data_node->GetName().c_str(), + InputRWTypeToSerialString(input_rw_type).c_str(), pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + } else { + // link identity pre node to next node directly + // todo control edge + if (GraphUtils::AddEdge(pre_out_data_anchor, peer_in_data_anchor) != SUCCESS) { + GELOGW("Fail to link data edge from node %s to %s.", pre_out_data_anchor->GetOwnerNode()->GetName().c_str(), + peer_in_data_anchor->GetOwnerNode()->GetName().c_str()); + return FAILED; + } + GELOGI("Node %s intput rw type is %s, link data edge from Identity input node %s to out node %s directly.", + peer_in_data_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str(), + pre_node->GetName().c_str(), peer_in_data_node->GetName().c_str()); + } + } + // 2.isolate Identity node with no data output + if (node->GetOutDataNodesSize() == 0) { + auto ret = GraphUtils::IsolateNode(node, {}); + if (ret != SUCCESS) { + GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str()); + return FAILED; + } + ret = GraphUtils::RemoveNodeWithoutRelink(node->GetOwnerComputeGraph(), node); + if (ret != SUCCESS) { + GELOGE(FAILED, "IsolateAndDelete identity node %s.", node->GetName().c_str()); + return FAILED; + } + GELOGI("IsolateAndDelete identity node %s.", node->GetName().c_str()); + } + return SUCCESS; +} + +Status InsertIdentityAsNeeded(const NodePtr &node) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (node->GetOutDataNodesSize() == 0 || node->GetInDataNodes().empty()) { + return SUCCESS; + } + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_data_anchor); + auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx()); + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_in_node = peer_in_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_node); + auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx()); + GELOGD("Node %s output rw type is %s, Node %s input rw type is %s", node->GetName().c_str(), + OutputRWTypeToSerialString(output_rw_type).c_str(), peer_in_node->GetName().c_str(), + InputRWTypeToSerialString(input_rw_type).c_str()); + auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type); + switch (conflict_result) { + case DO_NOTHING: + case WRONG_GRAPH: + GELOGD("No need insert Identity."); + continue; + case INSERT_IDENTITY: + auto identity_node = CreateIdentityAfterSrcNode(*node, out_data_anchor->GetIdx()); + if (identity_node == nullptr) { + GELOGE(FAILED, "Create identity node failed."); + return FAILED; + } + auto ret = GraphUtils::InsertNodeBetweenDataAnchors(out_data_anchor, peer_in_data_anchor, identity_node); + if (ret != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to insert reshape between node %s and %s", node->GetName().c_str(), + peer_in_node->GetName().c_str()); + return INTERNAL_ERROR; + } + GELOGI("Insert Identity between %s and %s to handle memory conflict.", node->GetName().c_str(), + peer_in_node->GetName().c_str()); + continue; + } + } + } + return SUCCESS; +} +} // namespace + +namespace ge { +Status GraphOptimize::CheckRWConflict(ComputeGraphPtr &compute_graph, bool &has_conflict) { + node_rwtype_map_.clear(); + auto sub_graph_vec = compute_graph->GetAllSubgraphs(); + if (sub_graph_vec.empty()) { + GELOGD("No sub graph here. Ignore memory conflict handle."); + return SUCCESS; + } + // 1.loop all subgraph, mark rw type from inside to outside + Status ret = MarkRWTypeForSubgraph(sub_graph_vec); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to mark rw type for subgraph."); + return ret; + } + has_conflict = false; + for (const auto &node : compute_graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (node->GetOutDataNodesSize() == 0) { + return SUCCESS; + } + if (node->GetType() == WHILE) { + return SUCCESS; + } + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + GE_CHECK_NOTNULL(out_data_anchor); + auto output_rw_type = GetOutputRWTypeByIndex(*node, out_data_anchor->GetIdx()); + for (const auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHECK_NOTNULL(peer_in_data_anchor); + auto peer_in_node = peer_in_data_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_node); + if (peer_in_node->GetType() == WHILE) { + return SUCCESS; + } + auto input_rw_type = GetInputRWTypeByIndex(*peer_in_node, peer_in_data_anchor->GetIdx()); + auto conflict_result = GetConflictResultBetweenNode(output_rw_type, input_rw_type); + switch (conflict_result) { + case DO_NOTHING: + GELOGD("No rw conflict."); + continue; + case WRONG_GRAPH: + has_conflict = true; + GELOGI("Node %s output rw type is %s, next node %s input_rw_type is %s.It is wrong graph.", + node->GetName().c_str(), OutputRWTypeToSerialString(output_rw_type).c_str(), + peer_in_node->GetName().c_str(), InputRWTypeToSerialString(input_rw_type).c_str()); + return SUCCESS; + case INSERT_IDENTITY: + GELOGD("There is rw conflict. It will handle later."); + continue; + } + } + } + } + return SUCCESS; +} +Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { + node_rwtype_map_.clear(); + auto sub_graph_vec = compute_graph->GetAllSubgraphs(); + if (sub_graph_vec.empty()) { + GELOGD("No sub graph here. Ignore memory conflict handle."); + return SUCCESS; + } + GE_DUMP(compute_graph, "BeforeHandleMemConflict"); + // 1.loop all subgraph, mark rw type from inside to outside + Status ret = MarkRWTypeForSubgraph(sub_graph_vec); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to mark rw type for subgraph."); + return ret; + } + // 2.loop all node, including node in subgraph and handle memory rw conflict + for (auto &node : compute_graph->GetAllNodes()) { + // ignore data / netoutput of subgraph + if (node->GetType() == DATA && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { + continue; + } + if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { + continue; + } + if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) { + // split identity + ret = SplitIdentity(node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to split identity node %s.", node->GetName().c_str()); + return ret; + } + // remove no use identity + ret = RemoveNoUseIdentity(node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to remove useless identity node %s.", node->GetName().c_str()); + return ret; + } + } + // insert Identity + ret = InsertIdentityAsNeeded(node); + if (ret != SUCCESS) { + GELOGE(ret, "Fail to insert Identity node."); + return ret; + } + } + GE_DUMP(compute_graph, "AfterHandleMemConflict"); + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/optimize/summary_optimize.cc b/src/ge/graph/optimize/summary_optimize.cc index 8b38d602..a8325da3 100644 --- a/src/ge/graph/optimize/summary_optimize.cc +++ b/src/ge/graph/optimize/summary_optimize.cc @@ -80,7 +80,8 @@ Status GraphOptimize::HandleSummaryOp(ComputeGraphPtr &compute_graph) { del_nodes.emplace_back(node_ptr); } } - summary_output_indexes_.insert({compute_graph->GetGraphID(), summary_output_indexes}); + GE_IF_BOOL_EXEC(!summary_output_indexes.empty(), + summary_output_indexes_.insert({compute_graph->GetGraphID(), summary_output_indexes})); // add output nodes for summary std::vector> out_nodes_info; diff --git a/src/ge/graph/partition/dynamic_shape_partition.cc b/src/ge/graph/partition/dynamic_shape_partition.cc index 6a396eef..903159b9 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.cc +++ b/src/ge/graph/partition/dynamic_shape_partition.cc @@ -62,15 +62,16 @@ Status DynamicShapePartitioner::Partition() { } GELOGD("Start dynamic shape partition graph %s.", root_graph_->GetName().c_str()); - REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes."); + REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes, root grah name:%s.", + root_graph_->GetName().c_str()); if (unknown_shape_nodes_.empty()) { GELOGD("Skip dynamic shape partition of graph %s as all nodes are known shape.", root_graph_->GetName().c_str()); REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), - "Failed set dynamic shape partitioned flag on root graph."); + "Failed set dynamic shape partitioned flag on root graph %s.", root_graph_->GetName().c_str()); return SUCCESS; } REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true), - "Failed set dynamic shape partitioned flag on root graph."); + "Failed set dynamic shape partitioned flag on root graph %s.", root_graph_->GetName().c_str()); DumpGraph("_Before_DSP"); auto status = PartitionImpl(); @@ -102,26 +103,32 @@ void DynamicShapePartitioner::PruneUniqueClusters() { if (unique_clusters_.count(cluster) != 0) { continue; } - unique_clusters_.insert(cluster); + if (unique_clusters_.insert(cluster).second) { + sorted_unique_clusters_.emplace_back(cluster); + } } + auto comp_func = [](std::shared_ptr clu_a, std::shared_ptr clu_b) -> bool { + return clu_a->Id() < clu_b->Id(); + }; + std::sort(sorted_unique_clusters_.begin(), sorted_unique_clusters_.end(), comp_func); } Status DynamicShapePartitioner::BuildPartitionFrame() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : sorted_unique_clusters_) { REQUIRE_SUCCESS(cluster->BuildFrame(), "Failed build frame of cluster[%lu].", cluster->Id()); } return SUCCESS; } Status DynamicShapePartitioner::CombinePartitionFrame() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : sorted_unique_clusters_) { REQUIRE_SUCCESS(cluster->CombinePartitionFrame(), "Failed combine frame of cluster[%lu].", cluster->Id()); } return SUCCESS; } Status DynamicShapePartitioner::BuildPartitionSubgraph() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : sorted_unique_clusters_) { REQUIRE_SUCCESS(cluster->BuildPartitionSubgraph(), "Failed build subgraph of cluster[%lu].", cluster->Id()); } return SUCCESS; @@ -134,10 +141,10 @@ std::string DynamicShapePartitioner::DebugString() const { size_t netoutput = 0; std::stringstream ss; ss << "All unknown shape nodes:" << std::endl; - for (auto node : unknown_shape_nodes_) { + for (const auto &node : unknown_shape_nodes_) { ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; } - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { if (cluster->IsUnknownShape()) { unknown++; } else if (cluster->IsKnownShape()) { @@ -150,7 +157,7 @@ std::string DynamicShapePartitioner::DebugString() const { } ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known << ", unknown:" << unknown << ", netoutput:" << netoutput << std::endl; - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { ss << " " << cluster->DebugString() << std::endl; } return ss.str(); @@ -158,25 +165,25 @@ std::string DynamicShapePartitioner::DebugString() const { void DynamicShapePartitioner::DumpGraph(const std::string &suffix) { GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); - for (auto sub_graph : root_graph_->GetAllSubgraphs()) { + for (const auto &sub_graph : root_graph_->GetAllSubgraphs()) { GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); } } void DynamicShapePartitioner::ClearResource() { - for (auto cluster : unique_clusters_) { + for (const auto &cluster : unique_clusters_) { cluster->Clear(); } node_2_cluster_.clear(); ordered_cluster_.clear(); unique_clusters_.clear(); + sorted_unique_clusters_.clear(); unknown_shape_nodes_.clear(); root_graph_.reset(); } Status DynamicShapePartitioner::MarkUnknownShapeNodes() { - auto graph = root_graph_; - for (auto &node : graph->GetDirectNode()) { + for (auto &node : root_graph_->GetDirectNode()) { REQUIRE_SUCCESS(CollectSpreadUnknownShapeNodes(node), "Failed collect spread unknown shape nodes %s.", node->GetName().c_str()); } @@ -186,7 +193,7 @@ Status DynamicShapePartitioner::MarkUnknownShapeNodes() { Status DynamicShapePartitioner::InitClusters() { auto graph = root_graph_; size_t rank = 0; - for (const auto node : graph->GetDirectNode()) { + for (const auto &node : graph->GetDirectNode()) { Cluster::Type type = Cluster::DATA; if (node->GetType() == DATA) { type = Cluster::DATA; @@ -208,7 +215,7 @@ Status DynamicShapePartitioner::InitClusters() { cluster->AddInput(node_2_cluster_[parent]); } } - for (const auto node : graph->GetDirectNode()) { + for (const auto &node : graph->GetDirectNode()) { GELOGD("Make cluster for node %s : %s.", node->GetName().c_str(), node_2_cluster_[node]->DebugString().c_str()); } return SUCCESS; @@ -220,8 +227,8 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { std::queue ready_clusters; std::unordered_map cluster_pending_count; std::unordered_set seen_clusters; - for (auto iter = node_2_cluster_.begin(); iter != node_2_cluster_.end(); iter++) { - auto cluster = iter->second; + for (auto &node : root_graph_->GetDirectNode()) { + auto &cluster = node_2_cluster_[node]; if (seen_clusters.count(cluster) != 0) { continue; } @@ -242,7 +249,7 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { if (cluster->IsKnownShape()) { ordered_cluster_.push_back(cluster); } - for (auto out_cluster : cluster->Outputs()) { + for (const auto &out_cluster : cluster->Outputs()) { if (cluster_pending_count[out_cluster] > 0 && --cluster_pending_count[out_cluster] == 0) { ready_clusters.push(out_cluster); } @@ -273,16 +280,16 @@ static std::string ToString(const std::vector &clusters) { Status DynamicShapePartitioner::MergeClusters() { // Merge unknown shape clusters - for (auto cluster : ordered_cluster_) { - for (auto in_cluster : cluster->Inputs()) { + for (const auto &cluster : ordered_cluster_) { + for (const auto &in_cluster : cluster->Inputs()) { if (!in_cluster->IsUnknownShape()) { continue; } auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), ToString(merged_clusters).c_str()); - for (auto merged_cluster : merged_clusters) { - for (auto node : merged_cluster->Nodes()) { + for (const auto &merged_cluster : merged_clusters) { + for (const auto &node : merged_cluster->Nodes()) { node_2_cluster_[node] = cluster; } } @@ -291,7 +298,7 @@ Status DynamicShapePartitioner::MergeClusters() { REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); // Merge known shape clusters - for (auto cluster : ordered_cluster_) { + for (const auto &cluster : ordered_cluster_) { if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) { auto in_cluster = *(cluster->Inputs().begin()); in_cluster->Merge(cluster); @@ -299,13 +306,13 @@ Status DynamicShapePartitioner::MergeClusters() { continue; } - for (auto in_cluster : cluster->Inputs()) { + for (const auto &in_cluster : cluster->Inputs()) { if (!in_cluster->IsKnownShape()) { continue; } if (cluster->TryMerge(in_cluster)) { GELOGD("Success merge known shape cluster from %lu to %lu.", in_cluster->Id(), cluster->Id()); - for (auto node : in_cluster->Nodes()) { + for (const auto &node : in_cluster->Nodes()) { node_2_cluster_[node] = cluster; } } @@ -333,7 +340,7 @@ Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (IsUnknownShapeTensor(out_tensor)) { GELOGD("Collect node %s as unknown as output %lu is unknown.", node->GetName().c_str(), anchor_index); is_unknown = true; - auto anchor = node->GetOutDataAnchor(anchor_index); + auto anchor = node->GetOutDataAnchor(static_cast(anchor_index)); for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { if (peer_anchor != nullptr) { GELOGD("Collect node %s as has unknown input from %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), @@ -349,7 +356,7 @@ Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { if (IsUnknownShapeTensor(in_tensor)) { GELOGD("Collect node %s as unknown as input %lu is unknown.", node->GetName().c_str(), anchor_index); is_unknown = true; - auto anchor = node->GetInDataAnchor(anchor_index); + auto anchor = node->GetInDataAnchor(static_cast(anchor_index)); const auto peer_anchor = anchor->GetPeerOutAnchor(); if (peer_anchor != nullptr) { GELOGD("Collect node %s as has unknown output to %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), @@ -453,15 +460,15 @@ std::string Cluster::DebugString() const { } ss << "[" << id_ << "](size:" << nodes_.size() << ")"; ss << "(" << min_ << "," << max_ << ")("; - for (auto cluster : in_clusters_) { + for (const auto &cluster : in_clusters_) { ss << cluster->id_ << ","; } ss << ")->("; - for (auto cluster : out_clusters_) { + for (const auto &cluster : out_clusters_) { ss << cluster->id_ << ","; } ss << ")|"; - for (auto node : nodes_) { + for (const auto &node : nodes_) { ss << (node->GetName() + "|"); } return ss.str(); @@ -507,12 +514,12 @@ void Cluster::Merge(ClusterPtr other) { in_clusters_.erase(other); out_clusters_.erase(other); auto in_clusters = other->in_clusters_; - for (auto cluster : in_clusters) { + for (const auto &cluster : in_clusters) { cluster->RemoveOutput(other); cluster->AddOutput(shared_from_this()); } auto out_clusters = other->out_clusters_; - for (auto cluster : out_clusters) { + for (const auto &cluster : out_clusters) { cluster->RemoveInput(other); cluster->AddInput(shared_from_this()); } @@ -529,7 +536,7 @@ bool Cluster::TryMerge(ClusterPtr other) { while (!forward_reached.empty()) { auto current_cluster = forward_reached.front(); forward_reached.pop(); - for (auto cluster : current_cluster->out_clusters_) { + for (const auto &cluster : current_cluster->out_clusters_) { if (cluster->max_ == max_ && current_cluster != other) { return false; } else if (cluster->min_ < max_) { @@ -557,7 +564,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { while (!forward_reached_queue.empty()) { auto current_cluster = forward_reached_queue.front(); forward_reached_queue.pop(); - for (auto cluster : current_cluster->out_clusters_) { + for (const auto &cluster : current_cluster->out_clusters_) { if (cluster->min_ < max_ && cluster->max_ != max_ && forward_reached_clusters.count(cluster) == 0) { forward_reached_clusters.insert(cluster); forward_reached_queue.push(cluster); @@ -567,7 +574,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { while (!backward_reached_queue.empty()) { auto current_cluster = backward_reached_queue.front(); backward_reached_queue.pop(); - for (auto cluster : current_cluster->in_clusters_) { + for (const auto &cluster : current_cluster->in_clusters_) { if (cluster->max_ > other->min_ && cluster->max_ != other->max_ && backward_reached_clusters.count(cluster) == 0) { backward_reached_clusters.insert(cluster); @@ -578,7 +585,7 @@ std::vector Cluster::MergeAllPathFrom(ClusterPtr other) { } } } - for (auto cluster : path_clusters) { + for (const auto &cluster : path_clusters) { Merge(cluster); } return path_clusters; @@ -598,11 +605,11 @@ void Cluster::AddFrameOutput(OutDataAnchorPtr anchor) { }; InDataAnchorPtr Cluster::GetFrameInDataAnchor(InDataAnchorPtr anchor) { - return partition_node_->GetInDataAnchor(inputs_index_[anchor]); + return partition_node_->GetInDataAnchor(static_cast(inputs_index_[anchor])); }; OutDataAnchorPtr Cluster::GetFrameOutDataAnchor(OutDataAnchorPtr anchor) { - return partition_node_->GetOutDataAnchor(outputs_index_[anchor]); + return partition_node_->GetOutDataAnchor(static_cast(outputs_index_[anchor])); }; InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_->GetInControlAnchor(); }; @@ -616,22 +623,25 @@ Status Cluster::BuildFrame() { auto node = nodes_.front(); auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != nullptr) { - for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; if (src_cluster->id_ != id_) { - auto src_cluster = partitioner_->node_2_cluster_[peer_out_control_anchor->GetOwnerNode()]; - GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor); + REQUIRE_GRAPH_SUCCESS( + GraphUtils::RemoveEdge(peer_out_control_anchor, in_control_anchor), + "Failed remove edge from node %s index %d to node %s index %d.", + peer_out_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(peer_out_control_anchor), + in_control_anchor->GetOwnerNode()->GetName().c_str(), AnchorUtils::GetIdx(in_control_anchor)); control_inputs_.insert(src_cluster); src_cluster->control_outputs_.insert(peer_out_control_anchor); } } } if (IsData()) { - for (auto anchor : node->GetAllOutDataAnchors()) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { AddFrameOutput(anchor); } } else { - for (auto anchor : node->GetAllInDataAnchors()) { + for (const auto &anchor : node->GetAllInDataAnchors()) { AddFrameInput(anchor); } } @@ -660,7 +670,7 @@ Status Cluster::BuildPartitionFrame() { "Failed set shape flag."); REQUIRE_GRAPH_SUCCESS(GraphUtils::RemoveJustNode(graph, node), "Failed remove root graph node."); REQUIRE_GRAPH_SUCCESS(node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph."); - for (auto anchor : node->GetAllInDataAnchors()) { + for (const auto &anchor : node->GetAllInDataAnchors()) { auto peer_out_anchor = anchor->GetPeerOutAnchor(); if (peer_out_anchor == nullptr) { continue; // Skip overhang input. @@ -674,7 +684,7 @@ Status Cluster::BuildPartitionFrame() { } auto in_control_anchor = node->GetInControlAnchor(); if (in_control_anchor != nullptr) { - for (auto peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { + for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { if (peer_out_control_anchor == nullptr) { continue; } @@ -689,9 +699,9 @@ Status Cluster::BuildPartitionFrame() { } } } - for (auto anchor : node->GetAllOutDataAnchors()) { + for (const auto &anchor : node->GetAllOutDataAnchors()) { auto peer_in_anchors = anchor->GetPeerInDataAnchors(); - for (auto peer_in_anchor : peer_in_anchors) { + for (const auto &peer_in_anchor : peer_in_anchors) { auto src_cluster = partitioner_->node_2_cluster_[peer_in_anchor->GetOwnerNode()]; if (src_cluster->id_ != id_) { AddFrameOutput(anchor); @@ -717,7 +727,7 @@ Status Cluster::BuildPartitionFrame() { } Status Cluster::CombinePartitionFrame() { - for (auto anchor : inputs_) { + for (const auto &anchor : inputs_) { auto peer_out_anchor = anchor->GetPeerOutAnchor(); auto src_cluster = partitioner_->node_2_cluster_[peer_out_anchor->GetOwnerNode()]; auto src_anchor = src_cluster->GetFrameOutDataAnchor(peer_out_anchor); @@ -729,7 +739,7 @@ Status Cluster::CombinePartitionFrame() { src_anchor->GetOwnerNode()->GetName().c_str(), src_anchor->GetIdx(), dst_anchor->GetOwnerNode()->GetName().c_str(), dst_anchor->GetIdx()); } - for (auto src_cluster : control_inputs_) { + for (const auto &src_cluster : control_inputs_) { auto src_anchor = src_cluster->GetFrameOutControlAnchor(); auto dst_anchor = GetFrameInControlAnchor(); REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(src_anchor, dst_anchor), "Failed add edge from %s:%d to %s:%d.", @@ -753,6 +763,9 @@ Status Cluster::BuildPartitionSubgraph() { REQUIRE_GRAPH_SUCCESS(data_op->AddOutputDesc(input_desc), "Failed add output desc."); REQUIRE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), "Failed set parent_node_index on subgraph data node."); + bool is_unknown_shape = IsUnknownShape(); + REQUIRE(AttrUtils::SetBool(data_op, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), + "Failed set _is_unknown_shape flag on data op %s.", data_op->GetName().c_str()); auto data_node = subgraph_->AddNode(data_op); REQUIRE_NOT_NULL(data_node, "Failed add data node to subgraph."); REQUIRE_GRAPH_SUCCESS(data_node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph of data node."); @@ -766,6 +779,9 @@ Status Cluster::BuildPartitionSubgraph() { } auto net_output_op = MakeShared(subgraph_->GetName() + "_" + NODE_NAME_NET_OUTPUT, ge::NETOUTPUT); REQUIRE_NOT_NULL(net_output_op, "Failed new memory for netoutput op."); + bool is_unknown_shape = IsUnknownShape(); + REQUIRE(AttrUtils::SetBool(net_output_op, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), + "Failed set _is_unknown_shape flag on net_output_op %s.", net_output_op->GetName().c_str()); for (size_t i = 0; i < outputs_.size(); ++i) { GeTensorDesc input_desc; REQUIRE_GRAPH_SUCCESS(net_output_op->AddInputDesc(input_desc), "Failed add input desc."); @@ -774,8 +790,8 @@ Status Cluster::BuildPartitionSubgraph() { REQUIRE_NOT_NULL(net_output_node, "Failed add netoutput node to subgraph."); REQUIRE_GRAPH_SUCCESS(net_output_node->SetOwnerComputeGraph(subgraph_), "Failed set owner graph of netoutput node."); parent_node_index = 0; - for (auto anchor : outputs_) { - auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(anchor->GetIdx()); + for (const auto &anchor : outputs_) { + auto output_desc = anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(static_cast(anchor->GetIdx())); REQUIRE(AttrUtils::SetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), "Failed set parent_node_index on subgraph netoutput's input."); REQUIRE_GRAPH_SUCCESS(net_output_op->UpdateInputDesc(parent_node_index, output_desc), @@ -786,7 +802,7 @@ Status Cluster::BuildPartitionSubgraph() { anchor->GetIdx()); parent_node_index++; } - for (auto anchor : control_outputs_) { + for (const auto &anchor : control_outputs_) { REQUIRE_GRAPH_SUCCESS(GraphUtils::AddEdge(anchor, net_output_node->GetInControlAnchor()), "Faile add control edge from %s:%d to netoutput node.", anchor->GetOwnerNode()->GetName().c_str(), anchor->GetIdx()); @@ -809,4 +825,4 @@ void Cluster::Clear() { } size_t Cluster::unique_id_ = 0; -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/graph/partition/dynamic_shape_partition.h b/src/ge/graph/partition/dynamic_shape_partition.h index 4cbd20b7..ba349b1c 100644 --- a/src/ge/graph/partition/dynamic_shape_partition.h +++ b/src/ge/graph/partition/dynamic_shape_partition.h @@ -150,6 +150,8 @@ class DynamicShapePartitioner { std::vector> ordered_cluster_; // Unique clusters left after merged clusters std::unordered_set> unique_clusters_; + // Unique clusters left after merged clusters sorted by rank + std::vector> sorted_unique_clusters_; // Nodes of root_graph_ that satisfy the unknowshape rules std::unordered_set unknown_shape_nodes_; }; diff --git a/src/ge/graph/partition/engine_place.cc b/src/ge/graph/partition/engine_place.cc index 74da0326..2d1a7f13 100644 --- a/src/ge/graph/partition/engine_place.cc +++ b/src/ge/graph/partition/engine_place.cc @@ -38,6 +38,7 @@ Status EnginePlacer::Run() { return FAILED; } // Assign engine for each node in the graph + instance_ptr->DNNEngineManagerObj().InitPerformanceStaistic(); for (const auto &node_ptr : compute_graph_->GetDirectNode()) { GE_CHECK_NOTNULL(node_ptr); GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); @@ -60,12 +61,15 @@ Status EnginePlacer::Run() { return FAILED; } } + for (auto &it : instance_ptr->DNNEngineManagerObj().GetCheckSupportCost()) { + GEEVENT("The time cost of %s::CheckSupported is [%lu] micro second.", it.first.c_str(), it.second); + } GELOGI("Engine placer ends."); return SUCCESS; } Status EnginePlacer::AssignEngineAndLog(ge::ConstNodePtr node_ptr, const std::string &engine_name) { - if (node_ptr == nullptr || node_ptr->GetOpDesc() == nullptr) { + if ((node_ptr == nullptr) || (node_ptr->GetOpDesc() == nullptr)) { GELOGE(FAILED, "node_ptr is null."); return FAILED; } diff --git a/src/ge/graph/partition/graph_partition.cc b/src/ge/graph/partition/graph_partition.cc index 50cd7e81..15f298c0 100644 --- a/src/ge/graph/partition/graph_partition.cc +++ b/src/ge/graph/partition/graph_partition.cc @@ -25,6 +25,7 @@ #include "framework/common/types.h" #include "graph/debug/ge_attr_define.h" #include "graph/manager/graph_manager_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/type_utils.h" @@ -231,33 +232,33 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co ComputeGraphPtr new_sub_graph = MakeShared(original_compute_graph->GetName()); GE_CHECK_NOTNULL(new_sub_graph); output_merged_compute_graph = new_sub_graph; - GE_TIMESTAMP_START(MergeGraphRemoveNode); + GE_TIMESTAMP_START(MergeSubGraphRemoveNode); if (RemoveNodeAndEdgeBetweenEndPld(output_merged_compute_graph, sub_graph_list) != ge::SUCCESS) { GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: merging sub-graphs failed"); return FAILED; } - GE_TIMESTAMP_END(MergeGraphRemoveNode, "GraphPartitioner::MergeGraphRemoveNodeAndEdge"); - GE_TIMESTAMP_START(MergeGraphTopologicalSorting); + GE_TIMESTAMP_END(MergeSubGraphRemoveNode, "GraphPartitioner::MergeGraphRemoveNodeAndEdge"); + GE_TIMESTAMP_START(MergeSubGraphTopologicalSorting); Status ret = output_merged_compute_graph->TopologicalSorting(); if (ret != SUCCESS) { GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[GraphPartitioner]: output_merged_compute_graph->TopologicalSorting failed"); return FAILED; } - GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); + GE_TIMESTAMP_END(MergeSubGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); // flush all nodes' engine of merged graph - GE_TIMESTAMP_START(MergeGraphEnginePlacerRun); + GE_TIMESTAMP_START(MergeSubGraphEnginePlacerRun); graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); if (graph_info_.engine_placer_.Run() != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); return FAILED; } - GE_TIMESTAMP_END(MergeGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); + GE_TIMESTAMP_END(MergeSubGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); GELOGI("Graph merge ends."); return SUCCESS; } Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_index, OpDescPtr &pld_op_desc) { - if (dst_node == nullptr || pld_op_desc == nullptr || dst_node->GetOpDesc() == nullptr) { + if ((dst_node == nullptr) || (pld_op_desc == nullptr) || (dst_node->GetOpDesc() == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -275,7 +276,7 @@ Status ge::GraphPartitioner::UpdatePldOpDesc(const NodePtr &dst_node, int input_ } Status ge::GraphPartitioner::UpdateEndOpDesc(const NodePtr &src_node, int output_index, OpDescPtr &end_op_desc) { - if (src_node == nullptr || end_op_desc == nullptr || src_node->GetOpDesc() == nullptr) { + if ((src_node == nullptr) || (end_op_desc == nullptr) || (src_node->GetOpDesc() == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -296,9 +297,9 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr const AnchorPtr &peer_in_anchor, const ge::ComputeGraphPtr &pld_graph, const ge::ComputeGraphPtr &end_graph) { - GE_CHECK_NOTNULL(out_anchor); GE_CHECK_NOTNULL(peer_in_anchor); GE_CHECK_NOTNULL(pld_graph); + GE_CHECK_NOTNULL(out_anchor); GE_CHECK_NOTNULL(end_graph); const auto &src_node = out_anchor->GetOwnerNode(); const auto &dst_node = peer_in_anchor->GetOwnerNode(); @@ -313,6 +314,12 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GELOGW("SetInt peerIndex failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetStr(end_op_desc, "parentOpType", dst_node->GetType()), GELOGW("SetStr parentOpType failed");) + GE_IF_BOOL_EXEC(!end_op_desc->SetExtAttr("parentNode", dst_node), GELOGW("SetEndExtAttr parentNode failed");) + OpDescPtr dst_node_op_desc = dst_node->GetOpDesc(); + GE_CHECK_NOTNULL(dst_node_op_desc); + GE_IF_BOOL_EXEC( + !AttrUtils::SetStr(end_op_desc, ATTR_NAME_END_REAR_NODE_ENGINE_NAME, dst_node_op_desc->GetOpEngineName()), + GELOGW("SetStr rearNodeEngineName failed");) // replace input_desc of end with owner node's desc int output_index = ge::AnchorUtils::GetIdx(out_anchor); bool is_need_update_desc = (output_index >= 0) && (graph_info_.mode_ == kPartitioning); @@ -361,6 +368,12 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr GELOGW("SetStr parentId failed");) GE_IF_BOOL_EXEC(!AttrUtils::SetInt(pld_op_desc, "anchorIndex", AnchorUtils::GetIdx(out_anchor)), GELOGW("SetInt anchorIndex failed");) + GE_IF_BOOL_EXEC(!pld_op_desc->SetExtAttr("parentNode", src_node), GELOGW("SetPldExtAttr parentNode failed");) + OpDescPtr src_node_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_node_op_desc); + GE_IF_BOOL_EXEC( + !AttrUtils::SetStr(pld_op_desc, ATTR_NAME_PLD_FRONT_NODE_ENGINE_NAME, src_node_op_desc->GetOpEngineName()), + GELOGW("SetStr frontNodeEngineName failed");) // do not care over flow graph_info_.num_of_pld_end_++; // replace output_desc of pld with input node's output desc @@ -395,14 +408,14 @@ graphStatus ge::GraphPartitioner::AddPlaceHolderEndInSrcDstGraph(const AnchorPtr return FAILED; } graph_info_.index_2_end_[graph_info_.num_of_pld_end_] = new_end_node; - graph_info_.end_2_pld_[new_end_node] = new_pld_node; graph_info_.pld_2_end_[new_pld_node] = new_end_node; + graph_info_.end_2_pld_[new_end_node] = new_pld_node; return SUCCESS; } Status ge::GraphPartitioner::LinkInput2EndRemoveOrginalLink(ge::NodePtr input_node, ge::ComputeGraphPtr src_graph, ge::ComputeGraphPtr dst_graph) { - if (input_node == nullptr || src_graph == nullptr || dst_graph == nullptr) { + if ((input_node == nullptr) || (src_graph == nullptr) || (dst_graph == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -442,7 +455,7 @@ Status ge::GraphPartitioner::LinkInput2EndRemoveOrginalLink(ge::NodePtr input_no Status ge::GraphPartitioner::PutInputNodesInSubGraph(const ge::ComputeGraphPtr &src_graph, const ge::ComputeGraphPtr &dst_graph) { - if (src_graph == nullptr || dst_graph == nullptr) { + if ((src_graph == nullptr) || (dst_graph == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return FAILED; } @@ -849,34 +862,34 @@ Status ge::GraphPartitioner::PartitionSubGraph(ge::ComputeGraphPtr compute_graph GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "[GraphPartitioner]: subGraphPtr->TopologicalSorting failed"); return FAILED; } - GE_TIMESTAMP_START(GraphPartitionInitialize); + GE_TIMESTAMP_START(PartitionSubGraphInitialize); if (Initialize(compute_graph) != SUCCESS) { GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: initialize failed"); return FAILED; } - GE_TIMESTAMP_END(GraphPartitionInitialize, "GraphPartitioner::PartitionInitialize"); - GE_TIMESTAMP_START(GraphPartitionMarkClusters); + GE_TIMESTAMP_END(PartitionSubGraphInitialize, "GraphPartitioner::PartitionInitialize"); + GE_TIMESTAMP_START(PartitionSubGraphMarkClusters); MarkClusters(); - GE_TIMESTAMP_END(GraphPartitionMarkClusters, "GraphPartitioner::PartitionMarkClusters"); - GE_TIMESTAMP_START(GraphPartitionSplitSubGraphs); + GE_TIMESTAMP_END(PartitionSubGraphMarkClusters, "GraphPartitioner::PartitionMarkClusters"); + GE_TIMESTAMP_START(PartitionSubGraphSplitSubGraphs); if (SplitSubGraphs(compute_graph) != SUCCESS) { GELOGE(FAILED, "[GraphPartitioner]: SplitSubGraphs failed"); return FAILED; } - GE_TIMESTAMP_END(GraphPartitionSplitSubGraphs, "GraphPartitioner::PartitionSplitSubGraphs"); - GE_TIMESTAMP_START(GraphPartitionSortSubGraphs); + GE_TIMESTAMP_END(PartitionSubGraphSplitSubGraphs, "GraphPartitioner::PartitionSplitSubGraphs"); + GE_TIMESTAMP_START(PartitionSubGraphSortSubGraphs); if (SortSubGraphs(compute_graph) != ge::SUCCESS) { GELOGE(GE_GRAPH_TOPO_SORT_FAILED, "Graph Partition SortSubGraphs failed."); return ge::FAILED; } - GE_TIMESTAMP_END(GraphPartitionSortSubGraphs, "GraphPartitioner::PartitionSortSubGraphs"); - GE_TIMESTAMP_START(GraphPartitionAddPartitionsToGraphNode); + GE_TIMESTAMP_END(PartitionSubGraphSortSubGraphs, "GraphPartitioner::PartitionSortSubGraphs"); + GE_TIMESTAMP_START(PartitionSubGraphAddPartitionsToGraphNode); vector output_subgraphs; if (AddPartitionsToGraphNode(output_subgraphs, compute_graph) != ge::SUCCESS) { GELOGE(GE_GRAPH_EMPTY_PARTITION, "Graph Partition AddPartitionsToGraphNode failed."); return ge::FAILED; } - GE_TIMESTAMP_END(GraphPartitionAddPartitionsToGraphNode, "GraphPartitioner::PartitionAddPartitionsToGraphNode"); + GE_TIMESTAMP_END(PartitionSubGraphAddPartitionsToGraphNode, "GraphPartitioner::PartitionAddPartitionsToGraphNode"); GELOGI("Graph Partition ends. Adding partitions to SubGraphInfo, got %zu sub graphs", output_subgraphs.size()); graph_info_.mode_ = kMerging; // do not care over flow @@ -923,7 +936,7 @@ Status ge::GraphPartitioner::AddPlaceHolderEnd(const AnchorPtr &out_anchor, cons Status ge::GraphPartitioner::SortSubGraphs(const ge::ComputeGraphPtr &compute_graph) { uint32_t rank = kRankOne; // rank 0 for data graph ComputeGraphPtr new_input_nodes_sub_graph = MakeShared("inputNodeGraph"); - if (new_input_nodes_sub_graph == nullptr || compute_graph == nullptr) { + if ((new_input_nodes_sub_graph == nullptr) || (compute_graph == nullptr)) { GELOGE(FAILED, "[GraphPartitioner]: new_input_nodes_sub_graph or compute_graph is null."); return FAILED; } @@ -965,7 +978,7 @@ Status ge::GraphPartitioner::SortSubGraphs(const ge::ComputeGraphPtr &compute_gr } AnchorPtr ge::GraphPartitioner::GetEndInAnchor(const AnchorPtr &src_anchor, const NodePtr &end_node) { - if (src_anchor == nullptr || end_node == nullptr) { + if ((src_anchor == nullptr) || (end_node == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return nullptr; } @@ -979,7 +992,7 @@ AnchorPtr ge::GraphPartitioner::GetEndInAnchor(const AnchorPtr &src_anchor, cons } AnchorPtr ge::GraphPartitioner::GetPldOutAnchor(const NodePtr &pld_node, const AnchorPtr &dst_anchor) { - if (pld_node == nullptr || dst_anchor == nullptr) { + if ((pld_node == nullptr) || (dst_anchor == nullptr)) { GELOGE(FAILED, "parameter ptr is null."); return nullptr; } @@ -992,16 +1005,16 @@ AnchorPtr ge::GraphPartitioner::GetPldOutAnchor(const NodePtr &pld_node, const A return pld_out_anchor; } -void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPtr &sub_graph_info) { - if (sub_graph_info == nullptr) { +void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPtr &subgraph_info) { + if (subgraph_info == nullptr) { GELOGE(FAILED, "parameter ptr is null."); return; } - auto sub_graph = sub_graph_info->GetSubGraph(); - GE_CHECK_NOTNULL_JUST_RETURN(sub_graph); + auto subgraph = subgraph_info->GetSubGraph(); + GE_CHECK_NOTNULL_JUST_RETURN(subgraph); NodetoNodeMap end_map; NodetoNodeMap pld_map; - for (const auto &node : sub_graph->GetDirectNode()) { + for (const auto &node : subgraph->GetDirectNode()) { if (node->GetType() == kEndType) { end_map[node] = graph_info_.end_2_pld_.at(node); } @@ -1009,8 +1022,8 @@ void ge::GraphPartitioner::AddEndPldInformationToSubGraphInfo(ge::SubGraphInfoPt pld_map[node] = graph_info_.pld_2_end_.at(node); } } - sub_graph_info->SetEnd2PldMap(end_map); - sub_graph_info->SetPld2EndMap(pld_map); + subgraph_info->SetEnd2PldMap(end_map); + subgraph_info->SetPld2EndMap(pld_map); } const Graph2SubGraphInfoList &ge::GraphPartitioner::GetSubGraphMap() { return graph_2_subgraph_list_; } diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.cc b/src/ge/graph/passes/atomic_addr_clean_pass.cc index 7d9b8dec..2c7fb9bb 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.cc +++ b/src/ge/graph/passes/atomic_addr_clean_pass.cc @@ -22,68 +22,44 @@ #include #include -#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/ge/ge_util.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/node_utils.h" #include "init/gelib.h" -namespace { -bool is_loop_graph = false; -} namespace ge { -namespace { -bool GraphShouldBeSkip(const ge::ComputeGraphPtr &graph) { - // Internal function, guaranteeing graph non-null - if (graph->GetParentGraph() == nullptr) { - return false; - } - return GraphUtils::IsUnknownShapeGraph(graph); -} -} // namespace - Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(AtomicAddrCleanPass); - if (graph == nullptr) { - GELOGE(PARAM_INVALID, "param [graph] must not be null."); - return PARAM_INVALID; - } - if (GraphShouldBeSkip(graph)) { - return SUCCESS; - } + GE_CHECK_NOTNULL(graph); GELOGD("AtomicAddrCleanPass begin."); // 1.Recoginze atomic and loop mark vector atomic_node_vec; for (NodePtr &node : graph->GetDirectNode()) { if (IsAtomicOp(node)) { - bool is_unknown = false; - auto ret_status = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); - if (ret_status != GRAPH_SUCCESS) { - GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), - node->GetType().c_str()); - continue; - } - if (is_unknown) { - GELOGI("Current node %s, type %s is unknown shape which should be skip.", node->GetName().c_str(), - node->GetType().c_str()); - continue; - } atomic_node_vec.push_back(node); } - if (!is_loop_graph && node->GetType() == LOOPCOND) { + if (!is_loop_graph_ && node->GetType() == LOOPCOND) { // there is loop in this graph GELOGD("There is no loop node. It will insert clean node follow atomic node."); - is_loop_graph = true; + is_loop_graph_ = true; } } if (atomic_node_vec.empty()) { GELOGI("There is no atomic node. Ignore atomicAddrClean pass."); return SUCCESS; } + + bool is_known_graph = graph->GetGraphUnknownFlag(); + if (is_known_graph) { + GELOGD("Graph[%s] is unknown graph. It will call fe interface to compile op.", graph->GetName().c_str()); + GE_CHK_STATUS_RET(CompileUnknownGraphOp(atomic_node_vec)); + return SUCCESS; + } + // 2.Insert clean node and link to atomic node Status ret; - if (is_loop_graph) { + if (is_loop_graph_) { ret = HandleLoopGraph(graph, atomic_node_vec); if (ret != SUCCESS) { return ret; @@ -95,7 +71,6 @@ Status AtomicAddrCleanPass::Run(ComputeGraphPtr graph) { } } GELOGD("AtomicAddrCleanPass end."); - GE_TIMESTAMP_END(AtomicAddrCleanPass, "GraphManager::AtomicAddrCleanPass"); return SUCCESS; } @@ -129,15 +104,28 @@ Status AtomicAddrCleanPass::HandleLoopGraph(ComputeGraphPtr &graph, const vector } Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vector &atomic_node_vec) { - GELOGD("Not loop graph. It will insert only 1 clean node."); + GELOGD("Not loop graph and unknown graph. It will insert only 1 clean node."); + + vector common_atomic_nodes; + auto ret = HandleDispersedAtomicNodes(graph, atomic_node_vec, common_atomic_nodes); + if (ret != SUCCESS) { + GELOGE(ret, "Handle dispersed atomic nodes failed, graph name is %s.", graph->GetName().c_str()); + return ret; + } + + if (common_atomic_nodes.empty()) { + GELOGI("common_atomic_nodes is empty"); + return SUCCESS; + } + // not loop graph , insert only one clean node in graph NodePtr clean_addr_node = InsertAtomicAddrCleanNode(graph); if (clean_addr_node == nullptr) { GELOGE(FAILED, "Insert AtomicAddrClean node failed. Ignore atomicAddrClean pass."); return FAILED; } - for (const auto &node : atomic_node_vec) { - auto ret = LinkToAtomicNode(node, clean_addr_node); + for (const auto &node : common_atomic_nodes) { + ret = LinkToAtomicNode(node, clean_addr_node); if (ret != SUCCESS) { GELOGE(ret, "Link control anchor failed from atomic node to atomic_addr_clean node."); return ret; @@ -149,7 +137,7 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect for (auto &in_anchor : node->GetAllInDataAnchors()) { GE_CHECK_NOTNULL(in_anchor->GetPeerOutAnchor()); NodePtr peer_in_node = in_anchor->GetPeerOutAnchor()->GetOwnerNode(); - Status ret = LinkToAtomicNode(peer_in_node, clean_addr_node); + ret = LinkToAtomicNode(peer_in_node, clean_addr_node); if (ret != SUCCESS) { GELOGE(ret, "Link failed, %s : %s", peer_in_node->GetName().c_str(), clean_addr_node->GetName().c_str()); return ret; @@ -159,6 +147,44 @@ Status AtomicAddrCleanPass::HandleNormalGraph(ComputeGraphPtr &graph, const vect return SUCCESS; } +Status AtomicAddrCleanPass::HandleDispersedAtomicNodes(ComputeGraphPtr &graph, + const std::vector &atomic_node_vec, + std::vector &common_atomic_nodes) { + int index = 0; + for (const auto &node : atomic_node_vec) { + vector node_anchors_connect_netoutput; + // If GetBool fail, attr is_connect_netoutput is an empty vector. + (void)ge::AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_NODE_CONNECT_OUTPUT, node_anchors_connect_netoutput); + if (!node_anchors_connect_netoutput.empty()) { + NodePtr dispersed_clean_addr_node = InsertAtomicAddrCleanNode(graph); + if (dispersed_clean_addr_node == nullptr) { + GELOGE(FAILED, "Insert AtomicAddrClean node failed. Ignore atomicAddrClean pass."); + return FAILED; + } + + auto dispersed_node_op_desc = dispersed_clean_addr_node->GetOpDesc(); + GE_CHECK_NOTNULL(dispersed_node_op_desc); + string node_name = dispersed_node_op_desc->GetName(); + std::ostringstream oss; + oss << node_name << "_" << index; + node_name = oss.str(); + dispersed_node_op_desc->SetName(node_name); + GELOGD("Inserted dispersed atomic clean node name is %s", node_name.c_str()); + ++index; + Status ret = LinkToAtomicNode(node, dispersed_clean_addr_node); + if (ret != SUCCESS) { + GELOGE(ret, "Link control anchor failed from atomic node: %s to atomic_addr_clean node: %s.", + node->GetName().c_str(), dispersed_clean_addr_node->GetName().c_str()); + return ret; + } + } else { + common_atomic_nodes.emplace_back(node); + } + } + + return SUCCESS; +} + NodePtr AtomicAddrCleanPass::InsertAtomicAddrCleanNode(ComputeGraphPtr &graph) { OpDescPtr op_desc = MakeShared(NODE_NAME_ATOMIC_ADDR_CLEAN, ATOMICADDRCLEAN); if (op_desc == nullptr) { @@ -172,12 +198,14 @@ NodePtr AtomicAddrCleanPass::InsertAtomicAddrCleanNode(ComputeGraphPtr &graph) { if (!session_graph_id.empty()) { (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); } + string node_name = op_desc->GetName(); // Only flush subgraph name - string node_name = (graph->GetParentGraph() != nullptr) - ? (graph->GetName() + "_" + op_desc->GetName() + session_graph_id) - : (op_desc->GetName() + session_graph_id); + if (graph->GetParentGraph() != nullptr) { + node_name = graph->GetName() + "_" + node_name; + } - op_desc->SetName(node_name); + string name = node_name + session_graph_id; + op_desc->SetName(name); GELOGI("Create cleanAddr op:%s.", op_desc->GetName().c_str()); // To avoid same name between graphs, set session graph id to this node NodePtr clean_addr_node = graph->AddNodeFront(op_desc); @@ -203,7 +231,7 @@ Status AtomicAddrCleanPass::LinkToAtomicNode(const NodePtr &atomic_node, NodePtr } GELOGD("Graph add cleanAddrNode op out ctrl edge, dst node: %s.", atomic_node->GetName().c_str()); std::string stream_label; - if (is_loop_graph && AttrUtils::GetStr(atomic_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { + if (is_loop_graph_ && AttrUtils::GetStr(atomic_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { if (!AttrUtils::SetStr(atomic_clean_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label)) { GELOGW("LinkToAtomicNode: SetStr failed"); return INTERNAL_ERROR; @@ -262,11 +290,56 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { return true; } /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status AtomicAddrCleanPass::ClearStatus() { hcom_node_vec_.clear(); return SUCCESS; } + +Status AtomicAddrCleanPass::CompileUnknownGraphOp(const vector &atomic_node_vec) { + GE_TIMESTAMP_CALLNUM_START(UnknownGraphCompileOp); + std::unordered_map> node_vector_map; + std::shared_ptr instance = ge::GELib::GetInstance(); + if ((instance == nullptr) || !instance->InitFlag()) { + GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "CompileSingleOp failed."); + return ge::GE_CLI_GE_NOT_INITIALIZED; + } + + for (auto &atomic_node : atomic_node_vec) { + auto op_desc = atomic_node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGW("op desc is nullptr."); + continue; + } + string kernel_lib_name = op_desc->GetOpKernelLibName(); + if (kernel_lib_name.empty()) { + GELOGE(ge::INTERNAL_ERROR, "Get atomic node:%s(%s) kernel lib failed.", atomic_node->GetName().c_str(), + atomic_node->GetType().c_str()); + return ge::INTERNAL_ERROR; + } + + OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); + GE_CHECK_NOTNULL(kernel_info); + node_vector_map[kernel_lib_name].emplace_back(atomic_node); + } + + for (auto &it : node_vector_map) { + auto &kernel_lib_name = it.first; + auto &node_vector = it.second; + OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); + GE_CHECK_NOTNULL(kernel_info); + GE_TIMESTAMP_RESTART(UnknownGraphCompileOp); + auto ret = kernel_info->CompileOp(node_vector); + GELOGI("The atomic node size of compile op of %s is %zu", kernel_lib_name.c_str(), node_vector.size()); + GE_TIMESTAMP_ADD(UnknownGraphCompileOp); + if (ret != ge::SUCCESS) { + GELOGE(ret, "Compile atomic op failed, kernel lib name is %s", kernel_lib_name.c_str()); + return ret; + } + } + GE_TIMESTAMP_CALLNUM_END(UnknownGraphCompileOp, "AtomicAddrCleanPass::CompileUnknownGraphOp"); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/atomic_addr_clean_pass.h b/src/ge/graph/passes/atomic_addr_clean_pass.h index d2d8f2ce..e22c1792 100644 --- a/src/ge/graph/passes/atomic_addr_clean_pass.h +++ b/src/ge/graph/passes/atomic_addr_clean_pass.h @@ -74,7 +74,18 @@ class AtomicAddrCleanPass : public GraphPass { */ bool IsAtomicOp(const NodePtr &node); + /** + * Handle atomic node in unknown graph + * @param atomic_node_vec: atomic node vector in unknown graph + * @return + */ + Status CompileUnknownGraphOp(const vector &atomic_node_vec); + + Status HandleDispersedAtomicNodes(ComputeGraphPtr &graph, const std::vector &atomic_node_vec, + std::vector &common_atomic_nodes); + vector hcom_node_vec_; + bool is_loop_graph_ = false; }; } // namespace ge diff --git a/src/ge/graph/passes/attach_stream_label_pass.cc b/src/ge/graph/passes/attach_stream_label_pass.cc new file mode 100644 index 00000000..9962821b --- /dev/null +++ b/src/ge/graph/passes/attach_stream_label_pass.cc @@ -0,0 +1,291 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/attach_stream_label_pass.h" +#include "ge/ge_api_types.h" +#include "graph/common/omg_util.h" + +namespace ge { +Status AttachStreamLabelPass::Run(ComputeGraphPtr graph) { + GELOGD("AttachStreamLabelPass Enter."); + + FindNodes(graph); + for (const auto &node : need_label_nodes_) { + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { + GE_CHK_STATUS_RET(UpdateCondBranch(node), "Update cond branch failed, start node:%s.", node->GetName().c_str()); + } + } + GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode failed."); + + GELOGD("AttachStreamLabelPass Leave."); + return SUCCESS; +} + +/// +/// @brief Clear Status, used for subgraph pass +/// @return +/// +Status AttachStreamLabelPass::ClearStatus() { + stream_switch_nodes_.clear(); + need_label_nodes_.clear(); + enter_nodes_.clear(); + branch_head_nodes_.clear(); + return SUCCESS; +} + +/// +/// @brief Find StreamSwitch / StreamMerge / Enter node +/// @param [in] graph +/// @return void +/// +void AttachStreamLabelPass::FindNodes(const ComputeGraphPtr &graph) { + for (const NodePtr &node : graph->GetDirectNode()) { + const std::string &type = node->GetType(); + if (type == STREAMSWITCH) { + stream_switch_nodes_.emplace_back(node); + } else if (type == STREAMMERGE) { + if ((node->GetOpDesc() != nullptr) && !node->GetOpDesc()->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + need_label_nodes_.emplace_back(node); + } + } else if ((type == ENTER) || (type == REFENTER)) { + enter_nodes_.emplace_back(node); + } + } + + for (const auto &node : stream_switch_nodes_) { + for (const auto &out_ctrl_node : node->GetOutControlNodes()) { + GELOGD("branch_head_node %s of stream_switch %s.", out_ctrl_node->GetName().c_str(), node->GetName().c_str()); + branch_head_nodes_[out_ctrl_node] = node; + } + need_label_nodes_.emplace_back(node); + } +} + +/// +/// @brief update cond branch +/// @param [in] node +/// @return Status +/// +Status AttachStreamLabelPass::UpdateCondBranch(const NodePtr &node) { + std::string stream_label; + std::unordered_set branch_nodes; + std::unordered_set visited; + std::stack nodes; + nodes.push(node); + + static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; + bool merge_flag = false; + bool exit_flag = false; + bool net_output_flag = false; + while (!nodes.empty()) { + NodePtr cur_node = nodes.top(); + nodes.pop(); + if (visited.count(cur_node) > 0) { + continue; + } + if (AttachFlag(cur_node, stream_label, merge_flag, exit_flag, net_output_flag) != SUCCESS) { + GELOGE(FAILED, "Attach flag for node %s failed.", cur_node->GetName().c_str()); + return FAILED; + } + + const std::string &type = cur_node->GetType(); + for (const auto &out_node : cur_node->GetOutAllNodes()) { + const std::string &out_type = out_node->GetType(); + bool stop_flag = (end_type_set.count(out_type) > 0) || + ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || + (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); + if (!stop_flag) { + nodes.push(out_node); + GELOGD("Insert branch node %s.", out_node->GetName().c_str()); + branch_nodes.insert(out_node); + } + } + visited.insert(cur_node); + } + + if (node->GetType() == STREAMSWITCH) { + GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed."); + } + + bool attach_flag = (merge_flag || exit_flag) && net_output_flag; + if (attach_flag) { + GELOGI("No need to keep on attaching label."); + return SUCCESS; + } + + for (const NodePtr &tmp_node : branch_nodes) { + GELOGD("Attach label %s to node: %s.", stream_label.c_str(), tmp_node->GetName().c_str()); + GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "Set stream label failed."); + } + + return SUCCESS; +} + +/// +/// @brief attach flag +/// @param [in] node +/// @param [out] stream_label +/// @param [out] merge_flag +/// @param [out] exit_flag +/// @param [out] net_output_flag +/// @return Status +/// +Status AttachStreamLabelPass::AttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, + bool &exit_flag, bool &net_output_flag) { + const std::string &type = node->GetType(); + if (type == STREAMSWITCH) { + if (node->GetInDataNodes().empty()) { + GELOGE(INTERNAL_ERROR, "node %s has no input_data_node.", node->GetName().c_str()); + return INTERNAL_ERROR; + } + stream_label = node->GetInDataNodes().at(0)->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); + bool value = false; + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, + "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); + stream_label += (value ? "_t" : "_f"); + } else if (type == STREAMMERGE) { + stream_label = node->GetName(); + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); + merge_flag = true; + } else if ((type == EXIT) || (type == REFEXIT)) { + GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "Set stream label failed."); + exit_flag = true; + } else if (type == NETOUTPUT) { + net_output_flag = true; + } + + return SUCCESS; +} + +/// +/// @brief Update stream_label start with enter nodes +/// @return Status +/// +Status AttachStreamLabelPass::UpdateEnterNode() { + std::unordered_map> enter_active_map; + for (const auto &enter_node : enter_nodes_) { + for (const auto &out_ctrl_node : enter_node->GetOutControlNodes()) { + if (out_ctrl_node->GetType() != STREAMACTIVE) { + continue; + } + auto iter = enter_active_map.find(out_ctrl_node); + if (iter == enter_active_map.end()) { + enter_active_map[out_ctrl_node] = {enter_node}; + } else { + iter->second.emplace_back(enter_node); + } + } + } + + for (const auto &pair : enter_active_map) { + if (SetEnterLabel(pair.second, pair.first) != SUCCESS) { + GELOGE(FAILED, "Set stream_label for enter_nodes failed."); + return FAILED; + } + + NodePtr active_node = pair.first; + GE_CHECK_NOTNULL(active_node); + std::vector active_label_list; + if (!AttrUtils::GetListStr(active_node->GetOpDesc(), ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) || + (active_label_list.size() != 1) || active_label_list[0].empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST failed, node: %s.", active_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + std::stack enter_nodes; + for (const auto &enter_node : pair.second) { + enter_nodes.emplace(enter_node); + } + if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { + GELOGE(FAILED, "Update stream_label for loop_branch failed."); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @brief Set stream_label for enter_nodes +/// @param [in] enter_nodes +/// @param [in] active_node +/// @return Status +/// +Status AttachStreamLabelPass::SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node) { + std::string stream_label; + GE_CHECK_NOTNULL(active_node); + (void)AttrUtils::GetStr(active_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); + + bool same_flag = true; + for (const auto &enter_node : enter_nodes) { + std::string tmp_label; + (void)AttrUtils::GetStr(enter_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, tmp_label); + if (tmp_label.empty() || (stream_label == tmp_label)) { + continue; + } + same_flag = false; + break; + } + + if (stream_label.empty()) { + if (same_flag) { + stream_label = active_node->GetName(); + } else { + GELOGW("stream_label of enter_active is empty while stream_label of some enter_node is not."); + return SUCCESS; + } + } + + for (const auto &enter_node : enter_nodes) { + GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "Set stream label failed."); + } + GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "Set stream label failed."); + return SUCCESS; +} + +/// +/// @brief Update stream_label for loop_branch +/// @param [in] enter_nodes +/// @param [in] stream_label +/// @return Status +/// +Status AttachStreamLabelPass::UpdateLoopBranch(const std::stack &enter_nodes, + const std::string &stream_label) { + std::stack nodes(enter_nodes); + NodePtr cur_node = nullptr; + while (!nodes.empty()) { + cur_node = nodes.top(); + nodes.pop(); + for (const NodePtr &out_node : cur_node->GetOutAllNodes()) { + OpDescPtr out_desc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(out_desc); + std::string out_type = out_desc->GetType(); + if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL) || (out_type == ENTER) || (out_type == REFENTER)) { + continue; + } + GELOGD("Attach label %s to node: %s.", stream_label.c_str(), out_node->GetName().c_str()); + GE_CHK_STATUS_RET(SetStreamLabel(out_node, stream_label), "Set stream label failed."); + nodes.push(out_node); + } + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/attach_stream_label_pass.h b/src/ge/graph/passes/attach_stream_label_pass.h new file mode 100644 index 00000000..fc6abd30 --- /dev/null +++ b/src/ge/graph/passes/attach_stream_label_pass.h @@ -0,0 +1,89 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ +#define GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ + +#include +#include "inc/graph_pass.h" + +namespace ge { +class AttachStreamLabelPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + /// + /// @brief Clear Status, used for subgraph pass + /// @return + /// + Status ClearStatus() override; + + private: + /// + /// @brief Find StreamSwitch / StreamMerge / Enter node + /// @param [in] graph + /// @return void + /// + void FindNodes(const ComputeGraphPtr &graph); + + /// + /// @brief update cond branch + /// @param [in] node + /// @return Status + /// + Status UpdateCondBranch(const NodePtr &node); + + /// + /// @brief attach flag + /// @param [in] node + /// @param [out] stream_label + /// @param [out] merge_flag + /// @param [out] exit_flag + /// @param [out] net_output_flag + /// @return Status + /// + static Status AttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, + bool &net_output_flag); + + /// + /// @brief Update stream_label for loop_branch + /// @param [in] enter_nodes + /// @param [in] stream_label + /// @return Status + /// + static Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); + + /// + /// @brief Update stream_label start with enter nodes + /// @return Status + /// + Status UpdateEnterNode(); + + /// + /// @brief Set stream_label for enter_nodes + /// @param [in] enter_nodes + /// @param [in] active_node + /// @return Status + /// + static Status SetEnterLabel(const std::vector &enter_nodes, const NodePtr &active_node); + + std::vector stream_switch_nodes_; + std::vector need_label_nodes_; + std::vector enter_nodes_; + std::unordered_map branch_head_nodes_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_ATTACH_STREAM_LABEL_PASS_H_ diff --git a/src/ge/graph/passes/base_pass.cc b/src/ge/graph/passes/base_pass.cc index 629b08ba..4da51ab0 100644 --- a/src/ge/graph/passes/base_pass.cc +++ b/src/ge/graph/passes/base_pass.cc @@ -66,7 +66,7 @@ void AddNextIterNodes(const Node::Vistor &nodes, std::queue &n } Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unordered_set &nodes_re_pass, - std::unordered_set &nodes_deleted, std::unordered_set &nodes_seen) { + std::unordered_set &nodes_deleted, std::unordered_set &nodes_seen) { if (node == nullptr) { GELOGE(FAILED, "parameter is null."); return FAILED; @@ -106,7 +106,7 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, std::unorder auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); - if (nodes_deleted_by_pass.count(node.get()) > 0) { + if (nodes_deleted_by_pass.count(node) > 0) { GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), name_to_pass.first.c_str()); break; @@ -153,7 +153,7 @@ Status BaseNodePass::IsolateAndDeleteNode(NodePtr &node, const std::vector return FAILED; } - AddNodeDeleted(node.get()); + AddNodeDeleted(node); return SUCCESS; } @@ -182,7 +182,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { GELOGD("Begin to run pass on graph, passes count %zu", names_to_passes.size()); std::queue nodes; std::unordered_set nodes_seen; - std::unordered_set nodes_deleted; + std::unordered_set nodes_deleted; std::unordered_set nodes_re_pass; std::unordered_set nodes_last; GetAllNodesNoInputEdge(graph_, nodes, nodes_seen, nodes_last); @@ -202,7 +202,7 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { (void)nodes_re_pass.erase(node); GE_IF_BOOL_EXEC(node == nullptr, GELOGW("node is null"); continue); - if (nodes_deleted.count(node.get()) > 0) { + if (nodes_deleted.count(node) > 0) { GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); continue; } diff --git a/src/ge/graph/passes/base_pass.h b/src/ge/graph/passes/base_pass.h index dfba581e..6e7b292e 100644 --- a/src/ge/graph/passes/base_pass.h +++ b/src/ge/graph/passes/base_pass.h @@ -53,7 +53,7 @@ class BaseNodePass { std::unordered_set GetNodesNeedRePass() { return nodes_need_re_pass_; } - std::unordered_set GetNodesDeleted() { return nodes_deleted_; } + std::unordered_set GetNodesDeleted() { return nodes_deleted_; } void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } @@ -103,13 +103,13 @@ class BaseNodePass { /// next iterations. /// @param node /// - void AddNodeDeleted(Node *node) { nodes_deleted_.insert(node); } + void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } private: std::unordered_set nodes_need_re_pass_; - std::unordered_set nodes_deleted_; + std::unordered_set nodes_deleted_; std::map options_; }; diff --git a/src/ge/graph/passes/bitcast_pass.cc b/src/ge/graph/passes/bitcast_pass.cc new file mode 100644 index 00000000..e8e1f84f --- /dev/null +++ b/src/ge/graph/passes/bitcast_pass.cc @@ -0,0 +1,148 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/bitcast_pass.h" + +#include +#include +#include "common/ge/ge_util.h" +#include "graph/utils/type_utils.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" + +namespace ge { +namespace { +const char *const kAttrNameType = "type"; +} // namespace + +Status BitcastPass::Run(NodePtr &node) { + GELOGD("Bitcast running"); + if (node == nullptr) { + GELOGE(PARAM_INVALID, "Param [node] must not be null."); + return PARAM_INVALID; + } + + if (node->GetType() != BITCAST) { + return SUCCESS; + } + + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + return PARAM_INVALID; + } + ge::DataType dst_data_type; + if (CheckDstDataType(op_desc, dst_data_type) != SUCCESS) { + return PARAM_INVALID; + } + + if (CheckOutputShape(op_desc, dst_data_type) != SUCCESS) { + return PARAM_INVALID; + } + + return IsolateAndDeleteNode(node, {0}); +} + +Status BitcastPass::CheckDstDataType(const OpDescPtr op_desc, ge::DataType &dst_data_type) { + if (!ge::AttrUtils::GetDataType(op_desc, kAttrNameType, dst_data_type)) { + GELOGE(PARAM_INVALID, "Node failed to get attribute type."); + return PARAM_INVALID; + } + if (dst_data_type >= ge::DT_UNDEFINED) { + GELOGE(PARAM_INVALID, "dst_data_type[%s] is not valid.", TypeUtils::DataTypeToSerialString(dst_data_type).c_str()); + return PARAM_INVALID; + } + + if (op_desc->GetOutputDescPtr(0) == nullptr) { + GELOGE(PARAM_INVALID, "Bitcast node outputDesc is null."); + return PARAM_INVALID; + } + if (op_desc->GetOutputDescPtr(0)->GetDataType() != dst_data_type) { + GELOGE(PARAM_INVALID, "dst_data_type[%s] is not equal to output_data_type[%s].", + TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), + TypeUtils::DataTypeToSerialString(op_desc->GetOutputDescPtr(0)->GetDataType()).c_str()); + return PARAM_INVALID; + } + return SUCCESS; +} + +Status BitcastPass::CheckOutputShape(const OpDescPtr op_desc, const ge::DataType dst_data_type) { + const GeTensorDescPtr &input_tensor_desc = op_desc->MutableInputDesc(0); + const GeTensorDescPtr &output_tensor_desc = op_desc->MutableOutputDesc(0); + if (input_tensor_desc == nullptr) { + GELOGE(PARAM_INVALID, "input_tensor_desc must not be null."); + return PARAM_INVALID; + } + + // get origin data_type and shape + ge::DataType ori_data_type = input_tensor_desc->GetDataType(); + if (ori_data_type >= ge::DT_UNDEFINED) { + GELOGE(PARAM_INVALID, "ori_data_type[%s] is not valid.", TypeUtils::DataTypeToSerialString(ori_data_type).c_str()); + return PARAM_INVALID; + } + + if (ori_data_type == dst_data_type) { + GELOGW("Origin data type is equal to dest data type."); + return SUCCESS; + } + + BitcastPass::kVecInt64 dim_vec(input_tensor_desc->GetShape().GetDims()); + if (CalcAndUpdateShape(dim_vec, ori_data_type, dst_data_type) != SUCCESS) { + GELOGE(PARAM_INVALID, "CalcAndUpdateShape failed."); + return PARAM_INVALID; + } + + if (dim_vec != output_tensor_desc->GetShape().GetDims()) { + GELOGE(PARAM_INVALID, "out_put_shape is different from expectations."); + return PARAM_INVALID; + } + + return SUCCESS; +} + +Status BitcastPass::CalcAndUpdateShape(BitcastPass::kVecInt64 &dim_vec, ge::DataType ori_data_type, + ge::DataType dst_data_type) { + if (dim_vec.size() == 0) { + GELOGE(PARAM_INVALID, "Pre node shape size is zero."); + return PARAM_INVALID; + } + int64_t ori_data_size = GetSizeByDataType(ori_data_type); + int64_t dst_data_size = GetSizeByDataType(dst_data_type); + + if (ori_data_size == dst_data_size) { + return SUCCESS; + } else if (ori_data_size > dst_data_size) { + if (ori_data_size % dst_data_size != 0) { + GELOGE(PARAM_INVALID, "ori_data_size is not divisible by dst_data_size."); + return PARAM_INVALID; + } + dim_vec.push_back(ori_data_size / dst_data_size); + return SUCCESS; + } else { + if (dst_data_size % ori_data_size != 0) { + GELOGE(PARAM_INVALID, "dst_data_size is not divisible by ori_data_size."); + return PARAM_INVALID; + } + + if (dim_vec[dim_vec.size() - 1] != (dst_data_size / ori_data_size)) { + GELOGE(PARAM_INVALID, "The last dim is not equal to dst_data_size / ori_data_size."); + return PARAM_INVALID; + } + dim_vec.pop_back(); + } + return SUCCESS; +} + +} // namespace ge diff --git a/src/ge/graph/passes/bitcast_pass.h b/src/ge/graph/passes/bitcast_pass.h new file mode 100644 index 00000000..4a9e2e1b --- /dev/null +++ b/src/ge/graph/passes/bitcast_pass.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_BITCAST_PASS_H_ +#define GE_GRAPH_PASSES_BITCAST_PASS_H_ + +#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/types.h" +#include "graph/graph.h" +#include "graph/op_desc.h" +#include "graph/passes/base_pass.h" +#include "graph/passes/pass_utils.h" + +namespace ge { +class BitcastPass : public BaseNodePass { + public: + Status Run(ge::NodePtr &node) override; + typedef std::vector kVecInt64; + + private: + Status CheckDstDataType(const OpDescPtr op_desc, ge::DataType &dst_data_type); + Status CheckOutputShape(const OpDescPtr op_desc, const ge::DataType dst_data_type); + Status CalcAndUpdateShape(BitcastPass::kVecInt64 &dim_vec, ge::DataType ori_data_type, ge::DataType dst_data_type); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_BITCAST_PASS_H_ diff --git a/src/ge/graph/passes/cast_remove_pass.cc b/src/ge/graph/passes/cast_remove_pass.cc index d18c4b4e..f7ff941c 100644 --- a/src/ge/graph/passes/cast_remove_pass.cc +++ b/src/ge/graph/passes/cast_remove_pass.cc @@ -69,7 +69,6 @@ bool CastRemovePass::HasSameDataType(OpDescPtr &begin_op_desc, OpDescPtr &end_op auto begin_out_desc = begin_op_desc->MutableOutputDesc(0); DataType begin_out_datatype = begin_out_desc->GetDataType(); - if (begin_out_datatype == end_out_datatype && (begin_out_datatype == DT_FLOAT16 || begin_out_datatype == DT_FLOAT)) { type = begin_out_datatype; return true; diff --git a/src/ge/graph/passes/cast_translate_pass.cc b/src/ge/graph/passes/cast_translate_pass.cc index 2d67b0a8..ee67e93d 100644 --- a/src/ge/graph/passes/cast_translate_pass.cc +++ b/src/ge/graph/passes/cast_translate_pass.cc @@ -264,7 +264,7 @@ Status CastTranslatePass::FuseDstNTranslates(NodePtr &node) { GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", out_data_node->GetName().c_str()); return FAILED; } - AddNodeDeleted(out_data_node.get()); + AddNodeDeleted(out_data_node); } return SUCCESS; diff --git a/src/ge/graph/passes/common_subexpression_elimination_pass.cc b/src/ge/graph/passes/common_subexpression_elimination_pass.cc index a52535c1..18f2e857 100644 --- a/src/ge/graph/passes/common_subexpression_elimination_pass.cc +++ b/src/ge/graph/passes/common_subexpression_elimination_pass.cc @@ -83,6 +83,7 @@ Status CommonSubexpressionEliminationPass::Run(ComputeGraphPtr graph) { continue; } auto key = GetCseKey(node); + GELOGD("The node %s cse key %s", node->GetName().c_str(), key.c_str()); auto iter = keys_to_node.find(key); if (iter == keys_to_node.end()) { keys_to_node[key] = node; diff --git a/src/ge/graph/passes/compile_nodes_pass.cc b/src/ge/graph/passes/compile_nodes_pass.cc index def7655e..330569a2 100644 --- a/src/ge/graph/passes/compile_nodes_pass.cc +++ b/src/ge/graph/passes/compile_nodes_pass.cc @@ -23,6 +23,7 @@ #include "common/ge_inner_error_codes.h" #include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" +#include "graph/common/ge_call_wrapper.h" #include "graph/op_desc.h" using domi::ImplyType; @@ -78,7 +79,7 @@ graphStatus CompileNodesPass::Run(ComputeGraphPtr graph) { return result; } GELOGI("[CompileNodesPass]: Optimize success."); - GE_TIMESTAMP_END(CompileNodesPass, "GraphManager::CompileNodesPass"); + GE_TIMESTAMP_EVENT_END(CompileNodesPass, "OptimizeStage2::ControlAttrOptimize::CompileNodesPass"); return GRAPH_SUCCESS; } @@ -101,7 +102,6 @@ graphStatus CompileNodesPass::GetSupportedKernel(const NodePtr &node, const std: } } OpsKernelInfoStorePtr kernel_info = instance->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); - if (kernel_info == nullptr) { GELOGE(ge::GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node->GetName().c_str()); return ge::GE_GRAPH_PARAM_NULLPTR; diff --git a/src/ge/graph/passes/cond_pass.cc b/src/ge/graph/passes/cond_pass.cc index 651cf98b..2f3f9333 100644 --- a/src/ge/graph/passes/cond_pass.cc +++ b/src/ge/graph/passes/cond_pass.cc @@ -226,7 +226,7 @@ Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnc return FAILED; } - if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(out_anchor, {in_anchor}, cast_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", cast_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; @@ -271,7 +271,7 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr } AddRePassNode(new_node); - if (GraphUtils::InsertNodeBefore(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(out_anchor, {in_anchor}, new_node) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), new_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), in_anchor->GetOwnerNode()->GetName().c_str()); return FAILED; diff --git a/src/ge/graph/passes/cond_remove_pass.cc b/src/ge/graph/passes/cond_remove_pass.cc index 8bc34fbc..1650be92 100644 --- a/src/ge/graph/passes/cond_remove_pass.cc +++ b/src/ge/graph/passes/cond_remove_pass.cc @@ -225,41 +225,40 @@ bool CondRemovePass::CheckIfCondConstInput(const OutDataAnchorPtr &cond_out_anch Status CondRemovePass::ReplaceIfCaseNodeWithPartitioncall(const NodePtr &node, const ComputeGraphPtr &save_branch) { // Add compute graph to new node - const auto &input_anchors = node->GetAllInAnchors(); - const auto &output_anchors = node->GetAllOutAnchors(); + const auto &input_desc_size = node->GetOpDesc()->GetInputsSize(); + const auto &output_desc_size = node->GetOpDesc()->GetOutputsSize(); // Create subgraph opdesc & node auto partitioncall_opdesc = - CreateSubgraphOpDesc(save_branch->GetName(), input_anchors.size() - kConditionIndexNum, output_anchors.size()); + CreateSubgraphOpDesc(save_branch->GetName(), input_desc_size - kConditionIndexNum, output_desc_size); auto partitioncall_node = node->GetOwnerComputeGraph()->AddNode(partitioncall_opdesc); // Link node's peerout anchors to new node's inanchors - for (const auto &input_anchor : input_anchors) { + for (const auto &input_anchor : node->GetAllInAnchors()) { for (const auto &peerout_anchor : input_anchor->GetPeerAnchors()) { if (GraphUtils::AddEdge(peerout_anchor, partitioncall_node->GetInAnchor( input_anchor->GetIdx() - kConditionIndexNum)) != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d", peerout_anchor->GetOwnerNode()->GetName().c_str(), peerout_anchor->GetIdx(), - partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_anchors.size(), - output_anchors.size()); + partitioncall_node->GetName().c_str(), input_anchor->GetIdx(), input_desc_size, output_desc_size); return FAILED; } } } // Remove If / Case anchor and peer in anchor // Link new node's out anchors to node's peer inanchors - for (const auto &output_anchor : output_anchors) { + for (const auto &output_anchor : node->GetAllOutAnchors()) { for (const auto &peerin_anchor : output_anchor->GetPeerAnchors()) { if (GraphUtils::RemoveEdge(node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Remove edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d", node->GetName().c_str(), output_anchor->GetIdx(), peerin_anchor->GetOwnerNode()->GetName().c_str(), - peerin_anchor->GetIdx(), input_anchors.size(), output_anchors.size()); + peerin_anchor->GetIdx(), input_desc_size, output_desc_size); return FAILED; } if (GraphUtils::AddEdge(partitioncall_node->GetOutAnchor(output_anchor->GetIdx()), peerin_anchor) != ge::GRAPH_SUCCESS) { GELOGE(FAILED, "Add edge failed, from node:%s idx:%d to node:%s idx:%d, input num:%d, output num:%d", partitioncall_node->GetName().c_str(), output_anchor->GetIdx(), - peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_anchors.size(), - output_anchors.size()); + peerin_anchor->GetOwnerNode()->GetName().c_str(), peerin_anchor->GetIdx(), input_desc_size, + output_desc_size); return FAILED; } } diff --git a/src/ge/graph/passes/constant_folding_pass.cc b/src/ge/graph/passes/constant_folding_pass.cc index 3ac7feb6..80bf7867 100644 --- a/src/ge/graph/passes/constant_folding_pass.cc +++ b/src/ge/graph/passes/constant_folding_pass.cc @@ -29,6 +29,18 @@ #include "inc/kernel.h" namespace ge { +const int64_t kStartCallNum = 1; + +const std::unordered_map> + &ConstantFoldingPass::GetGeConstantFoldingPerfStatistic() const { + return statistic_of_ge_constant_folding_; +} + +const std::unordered_map> + &ConstantFoldingPass::GetOpConstantFoldingPerfStatistic() const { + return statistic_of_op_constant_folding_; +} + Status ConstantFoldingPass::Run(ge::NodePtr &node) { GE_CHECK_NOTNULL(node); GELOGD("Begin to run constant folding on node %s", node->GetName().c_str()); @@ -50,6 +62,8 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { auto inputs = OpDescUtils::GetInputData(input_nodes); vector outputs; + // Statistic of ge constant folding kernel + uint64_t start_time = GetCurrentTimestap(); auto ret = RunOpKernel(node, inputs, outputs); if (ret != SUCCESS) { auto op_kernel = folding_pass::GetKernelByType(node); @@ -59,7 +73,18 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { return SUCCESS; } + // Statistic of op and fe constant folding kernel + start_time = GetCurrentTimestap(); ret = op_kernel->Compute(node_desc, inputs, outputs); + uint64_t cost_time = GetCurrentTimestap() - start_time; + if (statistic_of_ge_constant_folding_.find(node->GetType()) != statistic_of_ge_constant_folding_.end()) { + uint64_t &cnt = statistic_of_ge_constant_folding_[node->GetType()].first; + uint64_t &cur_cost_time = statistic_of_ge_constant_folding_[node->GetType()].second; + cnt++; + cur_cost_time += cost_time; + } else { + statistic_of_ge_constant_folding_[node->GetType()] = std::pair(kStartCallNum, cost_time); + } if (ret != SUCCESS) { if (ret == NOT_CHANGED) { GELOGD("Node %s type %s, compute terminates and exits the constant folding.", node->GetName().c_str(), @@ -70,6 +95,16 @@ Status ConstantFoldingPass::Run(ge::NodePtr &node) { return ret; } GELOGI("Node %s type %s, constant folding compute success.", node->GetName().c_str(), node->GetType().c_str()); + } else { + if (statistic_of_op_constant_folding_.find(node->GetType()) != statistic_of_op_constant_folding_.end()) { + uint64_t &cnt = statistic_of_op_constant_folding_[node->GetType()].first; + uint64_t &cost_time = statistic_of_op_constant_folding_[node->GetType()].second; + cnt++; + cost_time += GetCurrentTimestap() - start_time; + } else { + statistic_of_op_constant_folding_[node->GetType()] = + std::pair(kStartCallNum, GetCurrentTimestap() - start_time); + } } if (outputs.empty()) { diff --git a/src/ge/graph/passes/constant_folding_pass.h b/src/ge/graph/passes/constant_folding_pass.h index 1dcbcdc3..683b66f1 100644 --- a/src/ge/graph/passes/constant_folding_pass.h +++ b/src/ge/graph/passes/constant_folding_pass.h @@ -26,6 +26,12 @@ namespace ge { class ConstantFoldingPass : public FoldingPass { public: Status Run(ge::NodePtr &node) override; + const std::unordered_map> &GetGeConstantFoldingPerfStatistic() const; + const std::unordered_map> &GetOpConstantFoldingPerfStatistic() const; + + private: + std::unordered_map> statistic_of_op_constant_folding_; + std::unordered_map> statistic_of_ge_constant_folding_; }; } // namespace ge diff --git a/src/ge/graph/passes/control_trigger_pass.cc b/src/ge/graph/passes/control_trigger_pass.cc index 77fcbd69..0c00d553 100644 --- a/src/ge/graph/passes/control_trigger_pass.cc +++ b/src/ge/graph/passes/control_trigger_pass.cc @@ -15,16 +15,9 @@ */ #include "graph/passes/control_trigger_pass.h" - #include - #include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" #include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" namespace ge { @@ -444,7 +437,7 @@ Status ControlTriggerPass::FindPredInput(const NodePtr &switch_node) { return SUCCESS; } /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status ControlTriggerPass::ClearStatus() { diff --git a/src/ge/graph/passes/end_of_sequence_add_control_pass.cc b/src/ge/graph/passes/end_of_sequence_add_control_pass.cc new file mode 100644 index 00000000..a3928835 --- /dev/null +++ b/src/ge/graph/passes/end_of_sequence_add_control_pass.cc @@ -0,0 +1,139 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/end_of_sequence_add_control_pass.h" + +#include +#include + +#include "init/gelib.h" +#include "graph/node.h" + +namespace ge { + +Status EndOfSequenceAddControlPass::Run(ComputeGraphPtr graph) { + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "param [graph] must not be null."); + return PARAM_INVALID; + } + if (graph->GetParentGraph() != nullptr) { + return SUCCESS; + } + NodePtr end_of_sequence = GetEndOfSequence(graph); + if (end_of_sequence == nullptr) { + return SUCCESS; + } + GELOGI("EndOfSequenceAddControlPass begin."); + + std::vector target_nodes; + for (NodePtr &node : graph->GetDirectNode()) { + if (node == nullptr) { + GELOGW("node is nullptr."); + continue; + } + string stream_label; + (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, stream_label); + if (!stream_label.empty() || IsDataLikeNode(node)) { + continue; + } + // Save the nodes whose pre-nodes are all data-like node + auto in_data_nodes = node->GetInDataNodes(); + bool flag = false; + for (auto in_node : in_data_nodes) { + if (!IsDataLikeNode(in_node)) { + flag = true; + break; + } + } + if (flag) { + continue; + } + target_nodes.push_back(node); + } + // Insert control edge + Status status = AddControlEdge(end_of_sequence, target_nodes); + if (status != SUCCESS) { + GELOGE(FAILED, "Graph add EndOfSequence op out ctrl edge fail."); + return FAILED; + } + GELOGI("EndOfSequenceAddControlPass end."); + return SUCCESS; +} + +Status EndOfSequenceAddControlPass::AddControlEdge(NodePtr &end_of_sequence, std::vector &target_nodes) { + auto out_ctrl_anchor = end_of_sequence->GetOutControlAnchor(); + for (NodePtr &node : target_nodes) { + auto in_ctrl_anchor = node->GetInControlAnchor(); + if (in_ctrl_anchor == nullptr) { + continue; + } + Status status = GraphUtils::AddEdge(out_ctrl_anchor, in_ctrl_anchor); + if (status != GRAPH_SUCCESS) { + GELOGE(FAILED, "Graph add EndOfSequence op out ctrl edge fail, dst node: %s.", node->GetName().c_str()); + return FAILED; + } + GELOGI("Graph add EndOfSequence op out ctrl edge, dst node: %s.", node->GetName().c_str()); + } + return SUCCESS; +} + +inline NodePtr EndOfSequenceAddControlPass::GetEndOfSequence(const ComputeGraphPtr &graph) const { + // Internal function, guaranteeing graph non-null + for (NodePtr &node : graph->GetDirectNode()) { + if (node->GetType() == ENDOFSEQUENCE) { + return node; + } + } + return nullptr; +} + +bool EndOfSequenceAddControlPass::IsDataLikeNode(const NodePtr &node) { + std::shared_ptr instance_ptr = GELib::GetInstance(); + if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { + GELOGW("GELib not initialized"); + return false; + } + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + return false; + } + string engine_name = op_desc->GetOpEngineName(); + if (engine_name.empty()) { + engine_name = instance_ptr->DNNEngineManagerObj().GetDNNEngineName(node->GetOpDesc()); + } + const map schedulers = instance_ptr->DNNEngineManagerObj().GetSchedulers(); + // Only one scheduler has been supported by now + for (auto schedulers_iter = schedulers.begin(); schedulers_iter != schedulers.end(); ++schedulers_iter) { + const map cal_engines = schedulers_iter->second.cal_engines; + auto cal_engines_iter = cal_engines.find(engine_name); + if (cal_engines_iter == cal_engines.end()) { + GELOGW("No cal_engines found within engine %s, node name %s", engine_name.c_str(), node->GetName().c_str()); + continue; + } + EngineConfPtr engine_conf_ptr = cal_engines_iter->second; + if (engine_conf_ptr == nullptr) { + GELOGW("engine_conf_ptr within engine %s, node name %s is null", engine_name.c_str(), node->GetName().c_str()); + continue; + } + bool skip_assign_stream = engine_conf_ptr->skip_assign_stream; + if (skip_assign_stream) { + return true; + } + return false; + } + return false; +} +} // namespace ge diff --git a/src/ge/graph/passes/end_of_sequence_add_control_pass.h b/src/ge/graph/passes/end_of_sequence_add_control_pass.h new file mode 100644 index 00000000..2540a988 --- /dev/null +++ b/src/ge/graph/passes/end_of_sequence_add_control_pass.h @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_END_OF_SEQUENCE_ADD_CONTROL_EDGE_PASS_H_ +#define GE_GRAPH_PASSES_END_OF_SEQUENCE_ADD_CONTROL_EDGE_PASS_H_ + +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class EndOfSequenceAddControlPass : public GraphPass { + public: + EndOfSequenceAddControlPass() {} + EndOfSequenceAddControlPass(const EndOfSequenceAddControlPass &eos_pass) = delete; + EndOfSequenceAddControlPass &operator=(const EndOfSequenceAddControlPass &eos_pass) = delete; + + ~EndOfSequenceAddControlPass() override {} + + Status Run(ComputeGraphPtr graph) override; + + private: + /** + * Get EndOfSequence node in graph, nullptr if not exist. + * @param graph + * @return EndOfSequence node + */ + inline NodePtr GetEndOfSequence(const ComputeGraphPtr &graph) const; + /** + * Check whether this node is a data-like node. + * @param node + * @return + */ + bool IsDataLikeNode(const NodePtr &node); + /** + * Check whether this node is a data-like node. + * @param node + * @return + */ + Status AddControlEdge(NodePtr &end_of_sequence, std::vector &target_nodes); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_END_OF_SEQUENCE_ADD_CONTROL_EDGE_PASS_H_ diff --git a/src/ge/graph/passes/folding_pass.cc b/src/ge/graph/passes/folding_pass.cc index 4e51f1ca..44dbc182 100644 --- a/src/ge/graph/passes/folding_pass.cc +++ b/src/ge/graph/passes/folding_pass.cc @@ -291,7 +291,7 @@ Status FoldingPass::RemoveNodeKeepingCtrlEdges(NodePtr &node) { GELOGE(INTERNAL_ERROR, "Failed to remove node %s from graph", node->GetName().c_str()); return INTERNAL_ERROR; } - AddNodeDeleted(node.get()); + AddNodeDeleted(node); return SUCCESS; } diff --git a/src/ge/graph/passes/hccl_memcpy_pass.cc b/src/ge/graph/passes/hccl_memcpy_pass.cc index 5325f56e..a9b3484b 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.cc +++ b/src/ge/graph/passes/hccl_memcpy_pass.cc @@ -28,6 +28,7 @@ namespace { const int32_t kAnchorSize = 1; const int kAnchorNum = 0; +const char *const kInputMutable = "_input_mutable"; } // namespace namespace ge { Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { @@ -35,7 +36,16 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { for (const auto &node : graph->GetDirectNode()) { auto op_desc = node->GetOpDesc(); GE_IF_BOOL_EXEC(op_desc == nullptr, continue); - if (!NeedInsertMemcpyOp(op_desc)) { + + bool node_input_mutable = false; + if (!AttrUtils::HasAttr(op_desc, kInputMutable)) { + continue; + } + + GE_IF_BOOL_EXEC(!AttrUtils::GetBool(op_desc, kInputMutable, node_input_mutable), + GELOGE(INTERNAL_ERROR, "node:%s get attr:_input_mutable failed.", node->GetName().c_str()); + return FAILED); + if (!node_input_mutable) { continue; } @@ -53,7 +63,7 @@ Status HcclMemcpyPass::Run(ge::ComputeGraphPtr graph) { NodePtr src_node = src_out_anchor->GetOwnerNode(); std::string src_type = src_node->GetType(); bool check_src_type = (src_type == CONSTANTOP) || (src_type == DATA) || (src_type == CONSTANT); - if (check_src_type && node->GetType() == HCOMALLREDUCE) { + if (check_src_type) { Status ret = ModifyEdgeConnection(graph, src_out_anchor, hccl_in_anchor); if (ret != SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to modify the connection."); @@ -135,16 +145,6 @@ std::string HcclMemcpyPass::CheckDuplicateName(const std::string &node_name) { return tmp_name; } -/// -/// @brief Check hcom op -/// @param [in] ge::ConstOpDescPtr op_desc -/// @return bool -/// -bool HcclMemcpyPass::NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const { - return (op_desc->GetType() == HCOMALLGATHER || op_desc->GetType() == HCOMALLREDUCE || - op_desc->GetType() == HCOMREDUCESCATTER); -} - /// /// @brief Modify edge connection /// @param [in] ComputeGraphPtr graph @@ -182,7 +182,7 @@ Status HcclMemcpyPass::ModifyEdgeConnection(const ComputeGraphPtr &graph, const return SUCCESS; } /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status HcclMemcpyPass::ClearStatus() { diff --git a/src/ge/graph/passes/hccl_memcpy_pass.h b/src/ge/graph/passes/hccl_memcpy_pass.h index 9de96fbf..13863bd6 100644 --- a/src/ge/graph/passes/hccl_memcpy_pass.h +++ b/src/ge/graph/passes/hccl_memcpy_pass.h @@ -34,8 +34,6 @@ class HcclMemcpyPass : public GraphPass { std::string CheckDuplicateName(const std::string &node_name); - bool NeedInsertMemcpyOp(const ge::ConstOpDescPtr &op_desc) const; - Status ModifyEdgeConnection(const ComputeGraphPtr &graph, const OutDataAnchorPtr &src_out_anchor, const InDataAnchorPtr &hccl_in_anchor); diff --git a/src/ge/graph/passes/identify_reference_pass.cc b/src/ge/graph/passes/identify_reference_pass.cc deleted file mode 100644 index 92f7e7b6..00000000 --- a/src/ge/graph/passes/identify_reference_pass.cc +++ /dev/null @@ -1,52 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/identify_reference_pass.h" - -#include -#include "framework/common/debug/ge_log.h" -#include "graph/debug/ge_attr_define.h" - -namespace ge { -Status IdentifyReferencePass::Run(NodePtr &node) { - if (node == nullptr) { - GELOGE(PARAM_INVALID, "param [node] must not be null."); - return PARAM_INVALID; - } - auto op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(PARAM_INVALID, "OpDesc of param [node] must not be null."); - return PARAM_INVALID; - } - - auto input_names = op_desc->GetAllInputNames(); - auto outputs = op_desc->GetAllOutputName(); - for (auto &output : outputs) { - for (auto &input_name : input_names) { - if (input_name == output.first) { - bool is_ref = true; - if (AttrUtils::SetBool(op_desc, ATTR_NAME_REFERENCE, is_ref)) { - GELOGI("param [node] %s is reference node, set attribute %s to be true.", node->GetName().c_str(), - ATTR_NAME_REFERENCE.c_str()); - return SUCCESS; - } - } - } - } - - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/identity_pass.cc b/src/ge/graph/passes/identity_pass.cc index 9b15f77a..1f4725bf 100644 --- a/src/ge/graph/passes/identity_pass.cc +++ b/src/ge/graph/passes/identity_pass.cc @@ -18,26 +18,35 @@ #include #include - #include "framework/common/debug/ge_log.h" -#include "framework/common/ge_inner_error_codes.h" #include "graph/common/omg_util.h" +#include "graph/utils/node_utils.h" namespace ge { namespace { /// -/// A `Identity` node may after a `Switch` node and has control-dependency-out nodes. +/// 1. A `Identity` node may after a `Switch` node and has control-dependency-out nodes. /// Or a `Identity` node may before a `Merge` node and has control-dependency-in nodes. /// The identity nodes are used to represent control dependencies in condition branch, and can not be deleted. -/// +/// 2. Check identity is near subgraph. +/// Eg. As output of Data node in subgraph +/// or as input of Netoutput of subgraph +/// or as input of one node with subgraph +/// or as output of one node with subgraph Status CheckIdentityUsable(const NodePtr &node, bool &usable) { std::string node_type; for (auto &in_node : node->GetInDataNodes()) { - auto ret = GetOriginalType(in_node, node_type); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to get node type from node %s", node->GetName().c_str()); - return ret; + auto in_node_opdesc = in_node->GetOpDesc(); + GE_CHECK_NOTNULL(in_node_opdesc); + // near entrance of subgraph || near subgraph + if ((in_node->GetType() == DATA && NodeUtils::IsSubgraphInput(in_node)) || + !in_node_opdesc->GetSubgraphInstanceNames().empty()) { + usable = true; + return SUCCESS; } + + GE_CHK_STATUS_RET(GetOriginalType(in_node, node_type), "Failed to get node type from node %s", + node->GetName().c_str()); if ((node_type != SWITCH) && (node_type != REFSWITCH)) { GELOGD("skip identity %s connected to switch", node->GetName().c_str()); break; @@ -49,11 +58,15 @@ Status CheckIdentityUsable(const NodePtr &node, bool &usable) { } } for (auto &out_node : node->GetOutDataNodes()) { - auto ret = GetOriginalType(out_node, node_type); - if (ret != SUCCESS) { - GELOGE(ret, "Failed to get node type from node %s", node->GetName().c_str()); - return ret; + auto out_node_opdesc = out_node->GetOpDesc(); + GE_CHECK_NOTNULL(out_node_opdesc); + // near output of subgraph || near subgraph + if (NodeUtils::IsSubgraphOutput(out_node) || !out_node_opdesc->GetSubgraphInstanceNames().empty()) { + usable = true; + return SUCCESS; } + GE_CHK_STATUS_RET(GetOriginalType(out_node, node_type), "Failed to get node type from node %s", + node->GetName().c_str()); if ((node_type != MERGE) && (node_type != REFMERGE)) { GELOGD("skip identity %s connected to merge", node->GetName().c_str()); break; @@ -79,7 +92,7 @@ Status IdentityPass::Run(NodePtr &node) { GELOGE(status_ret, "Identity pass get original type fail."); return status_ret; } - if ((type != IDENTITY) && (type != IDENTITYN)) { + if ((type != IDENTITY) && (type != IDENTITYN) && (type != READVARIABLEOP)) { return SUCCESS; } diff --git a/src/ge/graph/passes/infershape_pass.cc b/src/ge/graph/passes/infershape_pass.cc index 18767cea..8b44d31b 100644 --- a/src/ge/graph/passes/infershape_pass.cc +++ b/src/ge/graph/passes/infershape_pass.cc @@ -15,7 +15,7 @@ */ #include "graph/passes/infershape_pass.h" - +#include "common/util/error_manager/error_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" #include "graph/shape_refiner.h" @@ -24,6 +24,8 @@ namespace ge { Status InferShapePass::Run(NodePtr &node) { auto ret = ShapeRefiner::InferShapeAndType(node, !OptionExists(kOptimizeAfterSubGraph)); if (ret != GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E35003", {"opname", "err_msg"}, + {node->GetName(), "check your model!"}); GELOGE(GE_GRAPH_INFERSHAPE_FAILED, "infershape failed. node: %s", node->GetName().c_str()); return GE_GRAPH_INFERSHAPE_FAILED; } diff --git a/src/ge/graph/passes/input_output_connection_identify_pass.cc b/src/ge/graph/passes/input_output_connection_identify_pass.cc new file mode 100644 index 00000000..45560bf5 --- /dev/null +++ b/src/ge/graph/passes/input_output_connection_identify_pass.cc @@ -0,0 +1,193 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/input_output_connection_identify_pass.h" + +#include +#include +#include +#include +#include + +#include "common/ge/ge_util.h" +#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" + +using std::map; +using std::string; +using std::vector; + +namespace ge { +namespace { +inline bool IsDataOp(const std::string &node_type) { + return (node_type == DATA_TYPE) || (node_type == AIPP_DATA_TYPE) || (node_type == ANN_DATA_TYPE); +} +} // namespace + +Status InputOutputConnectionIdentifyPass::Run(ComputeGraphPtr graph) { + if (graph == nullptr) { + GELOGE(PARAM_INVALID, "Input param graph is null, skip identification of nodes that connect to input and output."); + return PARAM_INVALID; + } + + if (graph->GetParentGraph() != nullptr) { + GELOGD("Current graph %s is a subgraph, skip identification of nodes that connect to input and output.", + graph->GetName().c_str()); + return SUCCESS; + } + + GELOGD("Start to identify nodes that connect to input and output."); + if (graph->TopologicalSorting() != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Graph topological sort failed."); + return INTERNAL_ERROR; + } + + if (GraphUtils::GetRefMapping(graph, symbol_to_anchors_, anchor_to_symbol_) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get ref-mapping for graph %s failed.", graph->GetName().c_str()); + return INTERNAL_ERROR; + } + + map> connect_input_node_idx_map; + map> connect_output_node_idx_map; + Status status = SUCCESS; + for (const NodePtr &node : graph->GetDirectNode()) { + // Not only node type Data is determined. + if (IsDataOp(node->GetType())) { + GELOGD("Find nodes that connect to root graph input node: %s.", node->GetName().c_str()); + status = ProcessInputNode(node, connect_input_node_idx_map, connect_output_node_idx_map); + if (status != SUCCESS) { + GELOGE(status, "Failed to process nodes that connect to input node: %s.", node->GetName().c_str()); + return status; + } + } + + if (node->GetType() == NETOUTPUT) { + GELOGD("Find nodes that connect to root graph output node: %s.", node->GetName().c_str()); + status = ProcessOutputNode(node, connect_input_node_idx_map, connect_output_node_idx_map); + if (status != SUCCESS) { + GELOGE(status, "Failed to process nodes that connect to output node: %s.", node->GetName().c_str()); + return status; + } + } + } + + status = SetNodeAttrOfConnectingInputOutput(connect_input_node_idx_map, connect_output_node_idx_map); + if (status != SUCCESS) { + GELOGE(status, "Failed to set attr for nodes that connect to input and output."); + return status; + } + + GELOGD("Success to identify nodes that connect to input and output."); + return SUCCESS; +} + +Status InputOutputConnectionIdentifyPass::ProcessInputNode(const NodePtr &node, + map> &connect_input_node_idx, + map> &connect_output_node_idx) { + GE_CHECK_NOTNULL(node); + for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { + // The return ptr of GetAllOutDataAnchors is always valid. + auto anchor_iter = anchor_to_symbol_.find(NodeIndexIO(node, out_data_anchor->GetIdx(), kOut).ToString()); + if (anchor_iter == anchor_to_symbol_.end()) { + GELOGW("Current node: %s out_data_anchor: %d is invalid, can not find related symbol.", node->GetName().c_str(), + out_data_anchor->GetIdx()); + continue; + } + + const string &symbol = anchor_iter->second; + auto status = UpdateNodeIdxMap(symbol, connect_input_node_idx, connect_output_node_idx); + if (status != SUCCESS) { + GELOGE(status, "Failed to update node anchor_index map."); + return status; + } + } + return SUCCESS; +} + +Status InputOutputConnectionIdentifyPass::UpdateNodeIdxMap(const string &symbol_string, + map> &connect_input_node_idx, + map> &connect_output_node_idx) { + auto symbol_iter = symbol_to_anchors_.find(symbol_string); + if (symbol_iter == symbol_to_anchors_.end()) { + GELOGE(PARAM_INVALID, "Input param symbol string: %s is invalid.", symbol_string.c_str()); + return PARAM_INVALID; + } + const auto &node_index_io_list = symbol_iter->second; + for (const auto &node_index_io : node_index_io_list) { + if (node_index_io.io_type_ == kOut) { + // find node that has shared output memory. + connect_output_node_idx[node_index_io.node_].emplace_back(node_index_io.index_); + } else { + // find node that has shared input memory. + connect_input_node_idx[node_index_io.node_].emplace_back(node_index_io.index_); + } + } + return SUCCESS; +} + +Status InputOutputConnectionIdentifyPass::ProcessOutputNode(const NodePtr &node, + map> &connect_input_node_idx, + map> &connect_output_node_idx) { + GE_CHECK_NOTNULL(node); + for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { + // The return ptr of GetAllInDataAnchors is always valid. + auto anchor_iter = anchor_to_symbol_.find(NodeIndexIO(node, in_data_anchor->GetIdx(), kIn).ToString()); + if (anchor_iter == anchor_to_symbol_.end()) { + GELOGW("Current node: %s in_data_anchor: %d is invalid, can not find related symbol.", node->GetName().c_str(), + in_data_anchor->GetIdx()); + continue; + } + + const string &symbol = anchor_iter->second; + auto status = UpdateNodeIdxMap(symbol, connect_input_node_idx, connect_output_node_idx); + if (status != SUCCESS) { + GELOGE(status, "Failed to update node anchor_index map."); + return status; + } + } + return SUCCESS; +} + +Status InputOutputConnectionIdentifyPass::SetNodeAttrOfConnectingInputOutput( + const map> &connect_input_node_idx, + const map> &connect_output_node_idx) { + for (const auto &iter : connect_input_node_idx) { + GE_CHECK_NOTNULL(iter.first); + if (iter.first->GetOpDesc() != nullptr) { + if (!AttrUtils::SetListInt(iter.first->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, iter.second)) { + GELOGE(INTERNAL_ERROR, "Failed to set attr %s for node %s.", ATTR_NAME_NODE_CONNECT_INPUT.c_str(), + iter.first->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + + for (const auto &iter : connect_output_node_idx) { + GE_CHECK_NOTNULL(iter.first); + if (iter.first->GetOpDesc() != nullptr) { + if (!AttrUtils::SetListInt(iter.first->GetOpDesc(), ATTR_NAME_NODE_CONNECT_OUTPUT, iter.second)) { + GELOGE(INTERNAL_ERROR, "Failed to set attr %s for node %s.", ATTR_NAME_NODE_CONNECT_OUTPUT.c_str(), + iter.first->GetName().c_str()); + return INTERNAL_ERROR; + } + } + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/input_output_connection_identify_pass.h b/src/ge/graph/passes/input_output_connection_identify_pass.h new file mode 100644 index 00000000..0dd32102 --- /dev/null +++ b/src/ge/graph/passes/input_output_connection_identify_pass.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_INPUT_OUTPUT_CONNECTION_IDENTIFY_PASS_H_ +#define GE_GRAPH_PASSES_INPUT_OUTPUT_CONNECTION_IDENTIFY_PASS_H_ + +#include +#include +#include "graph/graph.h" +#include "inc/graph_pass.h" + +namespace ge { +class InputOutputConnectionIdentifyPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + + private: + /// + /// Find all nodes that connect to input node. + /// @param [in] input node + /// @param [out] map of nodes and anchor index that connect to input + /// @param [out] map of nodes and anchor index that connect to output + /// @return Status + /// + Status ProcessInputNode(const NodePtr &node, std::map> &connect_input_node_idx, + std::map> &connect_output_node_idx); + + /// + /// Find all nodes that connect to output node. + /// @param [in] output node + /// @param [out] map of nodes and anchor index that connect to input + /// @param [out] map of nodes and anchor index that connect to output + /// @return Status + /// + Status ProcessOutputNode(const NodePtr &node, std::map> &connect_input_node_idx, + std::map> &connect_output_node_idx); + + /// + /// Update all nodes that have shared memory. + /// @param [in] symbol string + /// @param [out] map of nodes and anchor index that connect to input + /// @param [out] map of nodes and anchor index that connect to output + /// @return Status + /// + Status UpdateNodeIdxMap(const string &symbol_string, std::map> &connect_input_node_idx, + std::map> &connect_output_node_idx); + + /// + /// Set attr for all nodes that connect to input and output. + /// @param [in] map of nodes and anchor index that connect to input + /// @param [in] map of nodes and anchor index that connect to output + /// @return Status + /// + Status SetNodeAttrOfConnectingInputOutput(const std::map> &connect_input_node_idx, + const std::map> &connect_output_node_idx); + + // Members for ref mapping + std::map> symbol_to_anchors_; + std::map anchor_to_symbol_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_INPUT_OUTPUT_CONNECTION_IDENTIFY_PASS_H_ \ No newline at end of file diff --git a/src/ge/graph/passes/iterator_op_pass.cc b/src/ge/graph/passes/iterator_op_pass.cc index e1d452b1..1d11004d 100644 --- a/src/ge/graph/passes/iterator_op_pass.cc +++ b/src/ge/graph/passes/iterator_op_pass.cc @@ -73,14 +73,14 @@ Status IteratorOpPass::Run(ge::ComputeGraphPtr graph) { GE_IF_BOOL_EXEC(status != SUCCESS, GELOGW("Fail to Get var_desc of NODE_NAME_FLOWCTRL_LOOP_PER_ITER failed."); continue); Status ret; - ret = SetRtContext(rtContext_t(), RT_CTX_NORMAL_MODE); + ret = SetRtContext(graph->GetSessionID(), rtContext_t(), RT_CTX_NORMAL_MODE); // EOS will not be considered if ret is not SUCCESS. - GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGW("Set rt context RT_CTX_GEN_MODE failed."); continue); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGW("Set rt context RT_CTX_NORMAL_MODE failed."); continue); status = GetVariableValue(graph->GetSessionID(), ge_tensor_desc, NODE_NAME_FLOWCTRL_LOOP_PER_ITER, &loop_per_iter); - ret = SetRtContext(rtContext_t(), RT_CTX_GEN_MODE); + ret = SetRtContext(graph->GetSessionID(), rtContext_t(), RT_CTX_GEN_MODE); // The following process will be affected if ret is not SUCCESS. GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Set rt context RT_CTX_GEN_MODE failed."); return ret); @@ -108,7 +108,7 @@ Status IteratorOpPass::GetVariableValue(uint64_t session_id, const ge::GeTensorD // base_addr uint8_t *var_mem_base = VarManager::Instance(session_id)->GetVarMemoryBase(RT_MEMORY_HBM); GE_CHECK_NOTNULL(var_mem_base); - // offset + // offset + logic_base uint8_t *dev_ptr = nullptr; GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &dev_ptr), "Get variable %s address failed.", var_name.c_str()); @@ -279,11 +279,11 @@ ge::OpDescPtr IteratorOpPass::CreateMemcpyAsyncOp(const ge::NodePtr &pre_node) { return op_desc; } -Status IteratorOpPass::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { +Status IteratorOpPass::SetRtContext(uint64_t session_id, rtContext_t rt_context, rtCtxMode_t mode) { GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddrtContext(rt_context); + RtContextUtil::GetInstance().AddRtContext(session_id, rt_context); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/passes/iterator_op_pass.h b/src/ge/graph/passes/iterator_op_pass.h index e403020c..78b951e6 100644 --- a/src/ge/graph/passes/iterator_op_pass.h +++ b/src/ge/graph/passes/iterator_op_pass.h @@ -64,7 +64,7 @@ class IteratorOpPass : public GraphPass { /// ge::OpDescPtr CreateMemcpyAsyncOp(const ge::NodePtr &pre_node); - Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); + Status SetRtContext(uint64_t session_id, rtContext_t rt_context, rtCtxMode_t mode); }; } // namespace ge #endif // GE_GRAPH_PASSES_ITERATOR_OP_PASS_H_ diff --git a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc index ff150a54..63ca68a2 100644 --- a/src/ge/graph/passes/link_gen_mask_nodes_pass.cc +++ b/src/ge/graph/passes/link_gen_mask_nodes_pass.cc @@ -97,9 +97,16 @@ void LinkGenMaskNodesPass::GetAllGenMaskNodes(ComputeGraphPtr graph, vectorGetOpDesc() == nullptr) || (node->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { + continue; + } + auto in_data_nodes = node->GetInDataNodes(); if (in_data_nodes.size() > kGenMaskInputIndex) { NodePtr &gen_mask = in_data_nodes.at(kGenMaskInputIndex); + if ((gen_mask->GetOpDesc() == nullptr) || (gen_mask->GetOpDesc()->HasAttr(ATTR_NAME_STREAM_LABEL))) { + continue; + } if (AreAllInputsConst(gen_mask) && nodes_set.count(gen_mask) == 0) { gen_mask_nodes.emplace_back(gen_mask); nodes_set.emplace(gen_mask); diff --git a/src/ge/graph/passes/mark_graph_unknown_status_pass.cc b/src/ge/graph/passes/mark_graph_unknown_status_pass.cc new file mode 100644 index 00000000..7106e58c --- /dev/null +++ b/src/ge/graph/passes/mark_graph_unknown_status_pass.cc @@ -0,0 +1,35 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/mark_graph_unknown_status_pass.h" +#include "graph/utils/node_utils.h" + +namespace ge { +Status MarkGraphUnknownStatusPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + bool is_unknown_shape = false; + for (const auto &node : graph->GetDirectNode()) { + GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), + "Get node[%s] shape status failed!", node->GetName().c_str()); + if (is_unknown_shape) { + break; + } + } + graph->SetGraphUnknownFlag(is_unknown_shape); + GELOGD("mark graph [%s] unknown status success! value is %d", graph->GetName().c_str(), is_unknown_shape); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/identify_reference_pass.h b/src/ge/graph/passes/mark_graph_unknown_status_pass.h similarity index 67% rename from src/ge/graph/passes/identify_reference_pass.h rename to src/ge/graph/passes/mark_graph_unknown_status_pass.h index 5f284b4c..662e321c 100644 --- a/src/ge/graph/passes/identify_reference_pass.h +++ b/src/ge/graph/passes/mark_graph_unknown_status_pass.h @@ -14,16 +14,15 @@ * limitations under the License. */ -#ifndef GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ -#define GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ - -#include "graph/passes/base_pass.h" +#ifndef GE_GRAPH_PASSES_MARK_GRAPH_UNKNOWN_STATUS_PASS_H_ +#define GE_GRAPH_PASSES_MARK_GRAPH_UNKNOWN_STATUS_PASS_H_ +#include "graph/graph.h" +#include "inc/graph_pass.h" namespace ge { -class IdentifyReferencePass : public BaseNodePass { +class MarkGraphUnknownStatusPass : public GraphPass { public: - Status Run(NodePtr &node) override; + Status Run(ComputeGraphPtr graph); }; } // namespace ge - -#endif // GE_GRAPH_PASSES_IDENTIFY_REFERENCE_PASS_H_ +#endif // GE_GRAPH_PASSES_MARK_GRAPH_UNKNOWN_STATUS_PASS_H_ diff --git a/src/ge/graph/passes/mark_same_addr_pass.cc b/src/ge/graph/passes/mark_same_addr_pass.cc new file mode 100644 index 00000000..0ed151d3 --- /dev/null +++ b/src/ge/graph/passes/mark_same_addr_pass.cc @@ -0,0 +1,66 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/mark_same_addr_pass.h" + +namespace ge { +bool MarkSameAddrPass::IsNextNodeExpected(const ge::NodePtr &cur_node, const vector &next_nodes, + int &out_anchor_idx) { + for (auto out_anchor : cur_node->GetAllOutDataAnchors()) { + if (out_anchor == nullptr) { + continue; + } + for (auto in_anchor : out_anchor->GetPeerInDataAnchors()) { + if (in_anchor == nullptr) { + continue; + } + auto dst_node = in_anchor->GetOwnerNode(); + if (dst_node == nullptr) { + continue; + } + if (std::count(next_nodes.begin(), next_nodes.end(), dst_node->GetType()) > 0) { + out_anchor_idx = out_anchor->GetIdx(); + GELOGD("Current node is %s, next node is %s.", cur_node->GetName().c_str(), dst_node->GetName().c_str()); + return true; + } + } + } + return false; +} + +Status MarkSameAddrPass::Run(ComputeGraphPtr graph) { + GELOGD("MarkSameAddrPass begin."); + GE_CHECK_NOTNULL(graph); + if (graph->GetGraphUnknownFlag()) { + GELOGD("Graph[%s] is unknown shape, do not need to set fixed addr attr.", graph->GetName().c_str()); + return SUCCESS; + } + + int out_anchor_idx = 0; + for (const ge::NodePtr &node : graph->GetDirectNode()) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + vector next_nodes = {STREAMSWITCH, STREAMSWITCHN, LABELSWITCHBYINDEX}; + if (IsNextNodeExpected(node, next_nodes, out_anchor_idx)) { + string tensor_name = op_desc->GetOutputNameByIndex(out_anchor_idx); + (void)ge::AttrUtils::SetStr(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR, tensor_name); + (void)ge::AttrUtils::SetInt(node->GetOpDesc(), ATTR_DYNAMIC_SHAPE_FIXED_ADDR_INDEX, out_anchor_idx); + } + } + GELOGD("MarkSameAddrPass end."); + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/mark_same_addr_pass.h b/src/ge/graph/passes/mark_same_addr_pass.h new file mode 100644 index 00000000..ebfcf6b2 --- /dev/null +++ b/src/ge/graph/passes/mark_same_addr_pass.h @@ -0,0 +1,32 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/graph.h" +#include "inc/graph_pass.h" + +#ifndef GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ +#define GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ + +namespace ge { +class MarkSameAddrPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + bool IsNextNodeExpected(const ge::NodePtr &cur_node, const vector &next_nodes, int &out_anchor_idx); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_MARK_SAME_ADDR_PASS_H_ diff --git a/src/ge/graph/passes/memcpy_addr_async_pass.cc b/src/ge/graph/passes/memcpy_addr_async_pass.cc new file mode 100644 index 00000000..7cbacc23 --- /dev/null +++ b/src/ge/graph/passes/memcpy_addr_async_pass.cc @@ -0,0 +1,245 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/memcpy_addr_async_pass.h" + +#include "common/ge/ge_util.h" +#include "framework/common/debug/log.h" +#include "graph/utils/node_utils.h" + +namespace ge { +Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + for (auto &node : graph->GetAllNodes()) { + auto op_desc = node->GetOpDesc(); + GE_IF_BOOL_EXEC(op_desc == nullptr, continue); + + if (op_desc->GetType() == STREAMSWITCHN || op_desc->GetType() == STREAMMERGE) { + Status ret = AddMemcpyAddrAsyncNode(graph, node); + if (ret != SUCCESS) { + GELOGE(ret, "AddMemcpyAddrAsyncNode failed."); + return ret; + } + } + } + return SUCCESS; +} + +Status MemcpyAddrAsyncPass::AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node) { + GELOGI("Start AddMemcpyAddrAsyncNode for %s.", node->GetName().c_str()); + for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + + if (in_node->GetType() == DATA) { + ComputeGraphPtr owner_graph = in_node->GetOwnerComputeGraph(); + GE_CHECK_NOTNULL(owner_graph); + // Data is in parent_graph + if (owner_graph->GetParentGraph() == nullptr) { + GELOGI("Need to insert MemcpyAddrAsync directly when data in parent graph."); + NodePtr memcpy_addr_async_node = CreateMemcpyAddrAsyncNode(graph, peer_out_anchor, node); + GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr, GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode failed."); + return INTERNAL_ERROR); + + Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor, in_data_anchor, memcpy_addr_async_node); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode failed."); return ret); + } else { + uint32_t parent_index = 0; + if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, "Failed to get parent index of %s", in_node->GetName().c_str()); + return INTERNAL_ERROR; + } + // Data is in sub_graph + GELOGI("Need to find data in parent graph, then insert MemcpyAddrAsync."); + NodePtr parent_node = owner_graph->GetParentNode(); + user_data_for_known_ = in_node; + out_of_user_data_for_known_ = node; + peer_out_anchor_for_known_ = peer_out_anchor; + in_anchor_for_known_ = in_data_anchor; + FindUserData(parent_node, parent_index); + if (find_user_data_) { + GELOGI("Insert memcpy_addr_async for non_dynamic."); + GE_CHECK_NOTNULL(peer_out_anchor_); + NodePtr memcpy_addr_async_node = CreateMemcpyAddrAsyncNode(graph, peer_out_anchor_, out_of_user_data_); + GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr, + GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode failed."); + return INTERNAL_ERROR); + + Status ret = InsertMemcpyAddrAsyncNode(peer_out_anchor_, in_anchor_, memcpy_addr_async_node); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode failed."); return ret); + } + if (find_user_data_for_known_) { + GELOGI("Insert memcpy_addr_async for known graph."); + auto sub_graph = user_data_for_known_->GetOwnerComputeGraph(); + NodePtr memcpy_addr_async_node = + CreateMemcpyAddrAsyncNode(sub_graph, peer_out_anchor_for_known_, out_of_user_data_for_known_); + GE_IF_BOOL_EXEC(memcpy_addr_async_node == nullptr, + GELOGE(INTERNAL_ERROR, "CreateMemcpyAddrAsyncNode for known failed."); + return INTERNAL_ERROR); + + Status ret = + InsertMemcpyAddrAsyncNode(peer_out_anchor_for_known_, in_anchor_for_known_, memcpy_addr_async_node); + GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "InsertMemcpyAddrAsyncNode for known failed."); return ret); + } + } + } + } + return SUCCESS; +} + +void MemcpyAddrAsyncPass::FindUserDataForKnown(const NodePtr &parent_node, uint32_t &parent_index) { + GELOGI("Start FindUserDataForKnown of %s.", parent_node->GetName().c_str()); + if (user_data_for_known_->GetOpDesc() == nullptr) { + GELOGI("Cannot get op_desc of %s.", user_data_for_known_->GetName().c_str()); + return; + } + string src_var_name; + if (ge::AttrUtils::GetStr(user_data_for_known_->GetOpDesc(), REF_VAR_SRC_VAR_NAME, src_var_name)) { + GELOGI("The data in known graph is variable, no need to insert memcpy_addr_async."); + find_user_data_for_known_ = false; + return; + } else { + find_user_data_for_known_ = true; + } +} + +void MemcpyAddrAsyncPass::FindUserDataForNonDynamic(const ge::NodePtr &parent_node, uint32_t &parent_index) { + GELOGI("Start to FindUserDataForNonDynamic of %s.", parent_node->GetName().c_str()); + InDataAnchorPtr in_data_anchor = parent_node->GetInDataAnchor(parent_index); + OutDataAnchorPtr out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(out_anchor == nullptr, + GELOGE(INTERNAL_ERROR, "Cannot find out_anchor of %s.", parent_node->GetName().c_str()); + return ); + NodePtr in_node = out_anchor->GetOwnerNode(); + GELOGI("in_node of parent_node is %s.", in_node->GetName().c_str()); + if (in_node->GetType() == DATA) { + if (in_node->GetOwnerComputeGraph()->GetParentGraph() != nullptr) { + // DATA is in sub graph again, update user_data of known firstly + user_data_for_known_ = in_node; + out_of_user_data_for_known_ = parent_node; + peer_out_anchor_for_known_ = out_anchor; + in_anchor_for_known_ = in_data_anchor; + NodePtr pre_in_node = in_node->GetOwnerComputeGraph()->GetParentNode(); + if (!AttrUtils::GetInt(in_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, "Failed to refresh parent index of %s", in_node->GetName().c_str()); + return; + } + FindUserData(pre_in_node, parent_index); + } else { + // DATA is in parent graph and not has input + user_data_ = in_node; + out_of_user_data_ = parent_node; + peer_out_anchor_ = out_anchor; + in_anchor_ = in_data_anchor; + find_user_data_ = true; + GELOGI("%s connect with %s, will insert memcpyaddr.", user_data_->GetName().c_str(), + out_of_user_data_->GetName().c_str()); + } + } else if (in_node->GetType() == IF || in_node->GetType() == WHILE || in_node->GetType() == CASE) { + if (!AttrUtils::GetInt(parent_node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { + GELOGE(INTERNAL_ERROR, "Failed to refresh parent index of %s", in_node->GetName().c_str()); + return; + } + FindUserData(in_node, parent_index); + } else { + GELOGI("%s connect with %s, which is not user_data.", parent_node->GetName().c_str(), in_node->GetName().c_str()); + find_user_data_ = false; + } +} + +void MemcpyAddrAsyncPass::FindUserData(const NodePtr &parent_node, uint32_t &parent_index) { + auto parent_op_desc = parent_node->GetOpDesc(); + if (parent_op_desc == nullptr) { + GELOGI("Cannot get op_desc of %s.", parent_node->GetName().c_str()); + return; + } + bool is_unknown_shape = false; + if (parent_node->GetType() == PARTITIONEDCALL && + AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape) && !is_unknown_shape) { + FindUserDataForKnown(parent_node, parent_index); + } else { + FindUserDataForNonDynamic(parent_node, parent_index); + } +} + +NodePtr MemcpyAddrAsyncPass::CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, + const OutDataAnchorPtr &out_data_anchor, + const NodePtr &out_of_user_data) { + GELOGI("Start CreateMemcpyAddrAsyncNode."); + OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "Op_desc of pre node is invalid."); + std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYADDRASYNC; + + OpDescPtr op_desc = MakeShared(node_name, MEMCPYADDRASYNC); + GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); + + if (op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add memcpy_addr_async input desc failed."); + return nullptr; + } + + if (op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add memcpy_addr_async output desc failed."); + return nullptr; + } + + int64_t stream_id = out_of_user_data->GetOpDesc()->GetStreamId(); + op_desc->SetStreamId(stream_id); + GELOGI("SetStreamId: Node %s assign stream is %ld.", op_desc->GetName().c_str(), stream_id); + bool labeled_input = false; + (void)ge::AttrUtils::GetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, labeled_input); + if (labeled_input) { + if (!ge::AttrUtils::SetBool(out_of_user_data->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, false)) { + GELOGE(FAILED, "Failed to unset attr %s for node %s.", ATTR_NAME_NODE_CONNECT_INPUT.c_str(), + out_of_user_data->GetName().c_str()); + return nullptr; + } + if (!ge::AttrUtils::SetBool(op_desc, ATTR_NAME_NODE_CONNECT_INPUT, true)) { + GELOGE(FAILED, "Failed to set attr %s for node %s.", ATTR_NAME_NODE_CONNECT_INPUT.c_str(), + op_desc->GetName().c_str()); + return nullptr; + } + } + + NodePtr memcpy_addr_async_node = graph->AddNodeAfter(op_desc, out_data_anchor->GetOwnerNode()); + GE_CHECK_NOTNULL_EXEC(memcpy_addr_async_node, return nullptr); + + return memcpy_addr_async_node; +} + +Status MemcpyAddrAsyncPass::InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor, + const InDataAnchorPtr &in_anchor, const NodePtr &node) { + // insert memcpy_addr of each user_data and out_of_user_data + if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Remove edge of %s and %s failed.", out_anchor->GetOwnerNode()->GetName().c_str(), + in_anchor->GetOwnerNode()->GetName().c_str()); + return INTERNAL_ERROR; + } + if (GraphUtils::AddEdge(out_anchor, node->GetInDataAnchor(0)) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add edge of %s and %s failed.", out_anchor->GetOwnerNode()->GetName().c_str(), + node->GetName().c_str()); + return INTERNAL_ERROR; + } + if (GraphUtils::AddEdge(node->GetOutDataAnchor(0), in_anchor) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add edge of %s and %s failed.", node->GetName().c_str(), + in_anchor->GetOwnerNode()->GetName().c_str()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + +} // namespace ge diff --git a/src/ge/graph/passes/memcpy_addr_async_pass.h b/src/ge/graph/passes/memcpy_addr_async_pass.h new file mode 100644 index 00000000..9d99e505 --- /dev/null +++ b/src/ge/graph/passes/memcpy_addr_async_pass.h @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_MEMCPY_ADDR_ASYNC_PASS_H_ +#define GE_GRAPH_PASSES_MEMCPY_ADDR_ASYNC_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { + +class MemcpyAddrAsyncPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + Status AddMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const NodePtr &node); + void FindUserData(const NodePtr &node, uint32_t &parent_index); + void FindUserDataForKnown(const NodePtr &parent_node, uint32_t &parent_index); + void FindUserDataForNonDynamic(const ge::NodePtr &parent_node, uint32_t &parent_index); + + NodePtr CreateMemcpyAddrAsyncNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, + const NodePtr &out_of_user_data); + Status InsertMemcpyAddrAsyncNode(const OutDataAnchorPtr &out_anchor, const InDataAnchorPtr &in_anchor, + const NodePtr &node); + + NodePtr user_data_; + NodePtr out_of_user_data_; + OutDataAnchorPtr peer_out_anchor_; + InDataAnchorPtr in_anchor_; + bool find_user_data_ = false; + NodePtr user_data_for_known_; + NodePtr out_of_user_data_for_known_; + OutDataAnchorPtr peer_out_anchor_for_known_; + InDataAnchorPtr in_anchor_for_known_; + bool find_user_data_for_known_ = false; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_MEMCPY_ADDR_ASYNC_PASS_H_ diff --git a/src/ge/graph/passes/merge_pass.cc b/src/ge/graph/passes/merge_pass.cc index f4114474..8e691518 100644 --- a/src/ge/graph/passes/merge_pass.cc +++ b/src/ge/graph/passes/merge_pass.cc @@ -66,7 +66,7 @@ Status MergePass::Run(NodePtr &node) { AddRePassNode(end_node); } for (const auto &delete_node : del_nodes) { - AddNodeDeleted(delete_node.get()); + AddNodeDeleted(delete_node); } return ret; } diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.cc b/src/ge/graph/passes/merge_to_stream_merge_pass.cc new file mode 100644 index 00000000..b785ddfa --- /dev/null +++ b/src/ge/graph/passes/merge_to_stream_merge_pass.cc @@ -0,0 +1,234 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/merge_to_stream_merge_pass.h" +#include "common/ge/ge_util.h" +#include "ge/ge_api_types.h" +#include "graph/common/omg_util.h" + +namespace ge { +Status MergeToStreamMergePass::Run(ComputeGraphPtr graph) { + GELOGD("MergeToStreamMergePass Enter"); + + bypass_nodes_.clear(); + for (const auto &node : graph->GetDirectNode()) { + if ((node->GetType() != MERGE) && (node->GetType() != REFMERGE)) { + continue; + } + + OpDescPtr merge_op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { + GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, node, true), "Merge add memcpy node failed."); + GE_CHK_STATUS_RET(SetStreamLabel(node, node->GetName()), "Set stream label failed"); + } else { + GE_CHK_STATUS_RET(ReplaceMergeNode(graph, node), "Add StreamMerge node failed."); + } + } + + for (const auto &node : bypass_nodes_) { + GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS, return FAILED, + "Remove merge node failed."); + } + + GELOGD("MergeToStreamMergePass Leave"); + return SUCCESS; +} + +/// +/// @brief Replace Merge Op +/// @param [in] graph +/// @param [in] merge_node +/// @return Status +/// +Status MergeToStreamMergePass::ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node) { + OpDescPtr merge_op_desc = merge_node->GetOpDesc(); + GE_CHECK_NOTNULL(merge_op_desc); + + const std::string &node_name = merge_node->GetName(); + GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMMERGE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, StreamMerge:%s.", node_name.c_str()); + return FAILED; + } + + for (const InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) { + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS, + return FAILED, "Create StreamMerge op: add input desc failed."); + } + + for (const OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) { + GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS, + return FAILED, "Create StreamMerge op: add output desc failed."); + } + + NodePtr stream_merge = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(stream_merge != nullptr, return FAILED, "Insert StreamMerge node failed."); + GE_CHK_STATUS_RET(MoveEdges(merge_node, stream_merge), "Move edges failed."); + bypass_nodes_.insert(merge_node); + + if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { + std::string next_iteration_name; + GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name), + GELOGE(INTERNAL_ERROR, "Get ATTR_NAME_NEXT_ITERATION failed"); + return INTERNAL_ERROR); + GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "Set next iteration failed"); + } + + return AddMemcpyAsyncNodes(graph, stream_merge, false); +} + +/// +/// @brief Add MemcpyAsync Op as StreamMerge in_node +/// @param [in] graph +/// @param [in] node +/// @param [in] multi_batch_flag +/// @return Status +/// +Status MergeToStreamMergePass::AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, + bool multi_batch_flag) { + GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); + for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + const std::string &type = in_node->GetType(); + // For WhileLoop no need memcpy & active for merge. + GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), + continue); + + const std::string &memcpy_name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()); + NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, memcpy_name, peer_out_anchor, multi_batch_flag); + GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create MemcpyAsync node failed."); + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, memcpy_node->GetInDataAnchor(0)), + "MemcpyAsync node add edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), in_data_anchor), + "MemcpyAsync node add edge failed."); + + NodePtr active_node = CreateActiveNode(graph, memcpy_node); + GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "StreamActive add ctrl edge failed."); + if (SetActiveLabelList(active_node, {node->GetName()}) != SUCCESS) { + GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); + return FAILED; + } + } + + return SUCCESS; +} + +/// +/// @brief Add MemcpyAsync Node +/// @param [in] graph +/// @param [in] name +/// @param [in] out_data_anchor +/// @param [in] multi_batch_flag +/// @return ge::NodePtr +/// +NodePtr MergeToStreamMergePass::CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, + const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag) { + GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); + OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); + + const std::string &memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC; + const std::string &node_name = name + "_" + memcpy_type; + GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, memcpy_type); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, MemcpyAsync:%s.", node_name.c_str()); + return nullptr; + } + + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, + return nullptr, "Create MemcpyAsync op: add input desc failed."); + GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, + return nullptr, "Create MemcpyAsync op: add output desc failed."); + + return graph->AddNode(op_desc); +} + +/// +/// @brief Create Active Op +/// @param [in] graph +/// @param [in] node +/// @return ge::NodePtr +/// +NodePtr MergeToStreamMergePass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) { + const std::string &node_name = node->GetName() + "_" + STREAMACTIVE; + GELOGI("Create StreamActive op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, StreamActive:%s.", node_name.c_str()); + return nullptr; + } + + NodePtr active_node = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node failed."); + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, + GELOGE(INTERNAL_ERROR, "add edge failed"); + return nullptr); + GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, + GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); + return nullptr); + + return active_node; +} + +/// +/// @brief move edges from old_node to new_node +/// @param [in] old_node +/// @param [in] new_node +/// @return Status +/// +Status MergeToStreamMergePass::MoveEdges(const NodePtr &old_node, const NodePtr &new_node) { + for (const InDataAnchorPtr &in_data_anchor : old_node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Merge remove in data edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, new_node->GetInDataAnchor(in_data_anchor->GetIdx())), + "StreamMerge add in data edge failed."); + } + + for (const OutDataAnchorPtr &out_data_anchor : old_node->GetAllOutDataAnchors()) { + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Merge remove out data edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), + "StreamMerge add out data edge failed."); + } + } + + for (const NodePtr &in_ctrl_node : old_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctrl_node->GetOutControlAnchor(), old_node->GetInControlAnchor()), + "Merge remove in ctrl edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctrl_node->GetOutControlAnchor(), new_node->GetInControlAnchor()), + "StreamMerge add in ctrl edge failed."); + } + + for (const NodePtr &out_ctrl_node : old_node->GetOutControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), + "Merge remove out ctrl edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_ctrl_node->GetInControlAnchor()), + "StreamMerge add out ctrl edge failed."); + } + + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/merge_to_stream_merge_pass.h b/src/ge/graph/passes/merge_to_stream_merge_pass.h new file mode 100644 index 00000000..9f713989 --- /dev/null +++ b/src/ge/graph/passes/merge_to_stream_merge_pass.h @@ -0,0 +1,75 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ +#define GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { +class MergeToStreamMergePass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + /// + /// @brief Replace Merge Op + /// @param [in] graph + /// @param [in] merge_node + /// @return Status + /// + Status ReplaceMergeNode(const ComputeGraphPtr &graph, const NodePtr &merge_node); + + /// + /// @brief Add MemcpyAsync Op as StreamMerge in_node + /// @param [in] graph + /// @param [in] node + /// @param [in] multi_batch_flag + /// @return Status + /// + Status AddMemcpyAsyncNodes(const ComputeGraphPtr &graph, const NodePtr &node, bool multi_batch_flag); + + /// + /// @brief Add MemcpyAsync Node + /// @param [in] graph + /// @param [in] name + /// @param [in] out_data_anchor + /// @param [in] multi_batch_flag + /// @return ge::NodePtr + /// + NodePtr CreateMemcpyAsyncNode(const ComputeGraphPtr &graph, const std::string &name, + const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); + + /// + /// @brief Create Active Op + /// @param [in] graph + /// @param [in] node + /// @return ge::NodePtr + /// + NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /// + /// @brief move edges from old_node to new_node + /// @param [in] old_node + /// @param [in] new_node + /// @return Status + /// + Status MoveEdges(const NodePtr &old_node, const NodePtr &new_node); + + std::set bypass_nodes_; +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_MERGE_TO_STREAM_MERGE_PASS_H_ diff --git a/src/ge/graph/passes/multi_batch_pass.cc b/src/ge/graph/passes/multi_batch_pass.cc index bb0050be..7d484a25 100644 --- a/src/ge/graph/passes/multi_batch_pass.cc +++ b/src/ge/graph/passes/multi_batch_pass.cc @@ -29,10 +29,13 @@ #include "graph/debug/ge_attr_define.h" #include "graph/utils/type_utils.h" +using std::string; +using std::vector; + namespace ge { Status MultiBatchPass::Run(ComputeGraphPtr graph) { GELOGD("MultiBatchPass Enter"); - GE_CHECK_NOTNULL(graph); + if (graph->GetParentGraph() != nullptr) { GELOGI("Subgraph %s skip the MultiBatchPass.", graph->GetName().c_str()); return SUCCESS; @@ -44,26 +47,32 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { return SUCCESS; } if (ret != SUCCESS) { - GELOGE(FAILED, "FindPredValue fail."); + GELOGE(FAILED, "FindPredValue failed."); + return FAILED; + } + + if (GetDynamicType() != SUCCESS) { + GELOGE(FAILED, "Get dynamic type failed."); return FAILED; } std::vector> batch_shape; - if (!CheckSwitchN(batch_shape)) { - GELOGE(FAILED, "CheckSwitchN fail."); + vector> combined_batch; + if (!CheckSwitchN(batch_shape, combined_batch)) { + GELOGE(FAILED, "CheckSwitchN failed."); return FAILED; } FindSwitchOutNodes(batch_shape.size()); - if (ReplaceSwitchN(graph, pred_value, batch_shape) != SUCCESS) { - GELOGE(FAILED, "Replace SwitchN nodes fail."); + if (ReplaceSwitchN(graph, pred_value, batch_shape, combined_batch) != SUCCESS) { + GELOGE(FAILED, "Replace SwitchN nodes failed."); return FAILED; } - for (NodePtr &node : bypass_nodes_) { - if (graph->RemoveNode(node) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove SwitchN nodes %s fail.", node->GetName().c_str()); + for (const NodePtr &node : bypass_nodes_) { + if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove SwitchN nodes %s failed.", node->GetName().c_str()); return FAILED; } } @@ -79,19 +88,19 @@ Status MultiBatchPass::Run(ComputeGraphPtr graph) { /// @return Status /// Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value) { - for (NodePtr &node : graph->GetDirectNode()) { + for (const NodePtr &node : graph->GetDirectNode()) { if (node->GetType() != SWITCHN) { continue; } InDataAnchorPtr in_data_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); if (in_data_anchor == nullptr) { - GELOGE(FAILED, "FindPredInput fail, in_data_anchor is null, node:%s.", node->GetName().c_str()); + GELOGE(FAILED, "FindPredInput failed, in_data_anchor is null, node:%s.", node->GetName().c_str()); return FAILED; } OutDataAnchorPtr pred_input = in_data_anchor->GetPeerOutAnchor(); if (pred_input == nullptr) { - GELOGE(FAILED, "FindPredInput fail, pred_input is null, node:%s.", node->GetName().c_str()); + GELOGE(FAILED, "FindPredInput failed, pred_input is null, node:%s.", node->GetName().c_str()); return FAILED; } @@ -110,7 +119,7 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor } if (pred_value == nullptr) { - GELOGE(FAILED, "FindPredInput fail, pred_value is null."); + GELOGE(FAILED, "FindPredInput failed, pred_value is null."); return FAILED; } @@ -118,15 +127,48 @@ Status MultiBatchPass::FindPredValue(const ComputeGraphPtr &graph, OutDataAnchor return SUCCESS; } +/// +/// @brief Get dynamic type: dynamic batch size: 1, dynamic image size: 2, dynamic dims: 3 +/// @return Status +/// +Status MultiBatchPass::GetDynamicType() { + for (const auto &switchn : switch_n_nodes_) { + auto switchn_desc = switchn->GetOpDesc(); + GE_CHECK_NOTNULL(switchn_desc); + int32_t dynamic_type = static_cast(FIXED); + if (!AttrUtils::GetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { + GELOGE(FAILED, "Get attr ATTR_DYNAMIC_TYPE of node: %s failed.", switchn->GetName().c_str()); + return FAILED; + } + if (dynamic_type == static_cast(FIXED)) { + GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0."); + return FAILED; + } + if (dynamic_type_ != static_cast(FIXED) && dynamic_type_ != dynamic_type) { + GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE of all switchn node should be same, while one is %d and another is %d.", + dynamic_type, dynamic_type_); + return FAILED; + } + dynamic_type_ = dynamic_type; + } + if (dynamic_type_ == static_cast(FIXED)) { + GELOGE(FAILED, "Attr ATTR_DYNAMIC_TYPE shouldn't be 0."); + return FAILED; + } + + return SUCCESS; +} + /// /// @brief Check SwitchN nodes /// @param [out] batch_shape +/// @param [out] combined_batch /// @return bool /// -bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape) { +bool MultiBatchPass::CheckSwitchN(vector> &batch_shape, vector> &combined_batch) { // Check if output_num of different SwitchN is same uint32_t batch_num = 0; - for (NodePtr &node : switch_n_nodes_) { + for (const NodePtr &node : switch_n_nodes_) { uint32_t tmp_num = node->GetAllOutDataAnchorsSize(); if (batch_num == 0) { batch_num = tmp_num; @@ -136,45 +178,79 @@ bool MultiBatchPass::CheckSwitchN(std::vector> &batch_shape } } + if (!GetBatchInfo(batch_num, batch_shape, combined_batch)) { + GELOGE(FAILED, "Get batch info failed."); + return false; + } + + if (batch_shape.empty()) { + GELOGE(FAILED, "batch_shape is empty."); + return false; + } + if (combined_batch.empty()) { + GELOGE(FAILED, "combined_batch is empty."); + return false; + } + size_t dim_num = batch_shape[0].size(); + size_t combined_dim_num = combined_batch[0].size(); + for (uint32_t i = 1; i < batch_num; i++) { + size_t tmp_dim_num = batch_shape[i].size(); + if (dim_num != tmp_dim_num) { + GELOGE(FAILED, "Dim num of batch_shape not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); + return false; + } + size_t tmp_combined_dim_num = combined_batch[i].size(); + if (combined_dim_num != tmp_combined_dim_num) { + GELOGE(FAILED, "Dim num of combined_batch not equal, batch_0:%zu, batch_%u:%zu.", dim_num, i, tmp_dim_num); + return false; + } + } + + return true; +} + +/// +/// @brief Check SwitchN nodes +/// @param [in] batch_num +/// @param [out] batch_shape +/// @param [out] combined_batch +/// @return bool +/// +bool MultiBatchPass::GetBatchInfo(uint32_t batch_num, vector> &batch_shape, + vector> &combined_batch) { // Check if output_shape of different SwitchN is same - std::vector> idx_batch_shape; + vector> idx_batch_shape; + vector> idx_combined_batch; for (uint32_t i = 0; i < batch_num; i++) { idx_batch_shape.clear(); - for (NodePtr &node : switch_n_nodes_) { - std::vector output_dims; + idx_combined_batch.clear(); + for (const NodePtr &node : switch_n_nodes_) { OpDescPtr op_desc = node->GetOpDesc(); if (op_desc == nullptr) { - GELOGE(FAILED, "CheckDims fail, get op_desc fail, node: %s.", node->GetName().c_str()); + GELOGE(FAILED, "CheckDims failed, get op_desc failed, node: %s.", node->GetName().c_str()); return false; } + vector output_dims; if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_SWITCHN_PRED_VALUE, output_dims)) { - GELOGE(FAILED, "CheckDims fail, get attr ATTR_NAME_SWITCHN_PRED_VALUE fail, batch_index=%u.", i); + GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_SWITCHN_PRED_VALUE failed, batch_index=%u.", i); return false; } idx_batch_shape.emplace_back(output_dims); + output_dims.clear(); + if (!AttrUtils::GetListInt(op_desc->GetOutputDesc(i), ATTR_NAME_COMBINED_DYNAMIC_DIMS, output_dims)) { + GELOGE(FAILED, "CheckDims failed, get attr ATTR_NAME_COMBINED_DYNAMIC_DIMS failed, batch_index=%u.", i); + return false; + } + idx_combined_batch.emplace_back(output_dims); } if (!CheckDims(idx_batch_shape)) { - GELOGE(FAILED, "CheckDims fail, batch_index=%u.", i); + GELOGE(FAILED, "CheckDims failed, batch_index=%u.", i); return false; } batch_shape.emplace_back(idx_batch_shape[0]); + combined_batch.emplace_back(idx_combined_batch[0]); } - - // Check if dim_num of different batch is same - if (batch_shape.empty()) { - GELOGE(FAILED, "batch_shape is empty."); - return false; - } - uint32_t dim_num = batch_shape[0].size(); - for (uint32_t i = 1; i < batch_num; i++) { - uint32_t tmp_dim_num = batch_shape[i].size(); - if (dim_num != tmp_dim_num) { - GELOGE(FAILED, "dim_num not equal, batch_0:%u, batch_%u:%u.", dim_num, i, tmp_dim_num); - return false; - } - } - return true; } @@ -187,11 +263,11 @@ void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { std::vector output_nodes; for (uint32_t i = 0; i < batch_num; i++) { output_nodes.clear(); - for (NodePtr &node : switch_n_nodes_) { + for (const NodePtr &node : switch_n_nodes_) { // idx is promised to be valid OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(i); GE_CHECK_NOTNULL_JUST_RETURN(out_data_anchor); - for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { output_nodes.emplace_back(peer_in_anchor->GetOwnerNode()); } } @@ -206,35 +282,37 @@ void MultiBatchPass::FindSwitchOutNodes(uint32_t batch_num) { /// @param [in] graph /// @param [in] pred_value /// @param [in] batch_shape +/// @param [in] combined_batch /// @return Status /// -Status MultiBatchPass::ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value, - const std::vector> &batch_shape) { +Status MultiBatchPass::ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, + const vector> &batch_shape, + const vector> &combined_batch) { NodePtr pred_value_node = pred_value->GetOwnerNode(); // Create SwitchCase node - const std::string switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; - NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape); + const std::string &switch_case_name = pred_value_node->GetName() + "_" + STREAMSWITCHN; + NodePtr switch_case = CreateSwitchCaseNode(graph, switch_case_name, pred_value, batch_shape, combined_batch); if (switch_case == nullptr) { - GELOGE(FAILED, "CreateSwitchCaseNode %s fail.", switch_case_name.c_str()); + GELOGE(FAILED, "CreateSwitchCaseNode %s failed.", switch_case_name.c_str()); return FAILED; } - for (NodePtr &switch_n_node : switch_n_nodes_) { + for (const NodePtr &switch_n_node : switch_n_nodes_) { if (BypassSwitchN(switch_n_node, switch_case) != SUCCESS) { - GELOGE(FAILED, "Bypass SwitchN %s fail.", switch_case_name.c_str()); + GELOGE(FAILED, "Bypass SwitchN %s failed.", switch_case_name.c_str()); return FAILED; } } // Add switchCase input edge if (GraphUtils::AddEdge(pred_value, switch_case->GetInDataAnchor(0)) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add SwitchCase in_data_edge fail, %s->%s.", pred_value_node->GetName().c_str(), + GELOGE(FAILED, "Add SwitchCase in_data_edge failed, %s->%s.", pred_value_node->GetName().c_str(), switch_case->GetName().c_str()); return FAILED; } if (AttachLabel(switch_case) != SUCCESS) { - GELOGE(FAILED, "AttachLabel fail."); + GELOGE(FAILED, "AttachLabel failed."); return FAILED; } @@ -248,7 +326,7 @@ Status MultiBatchPass::ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr & /// bool MultiBatchPass::CheckDims(const std::vector> &output_shape) const { if (output_shape.empty()) { - GELOGE(FAILED, "CheckDims fail: output_shape is empty."); + GELOGE(FAILED, "CheckDims failed: output_shape is empty."); return false; } @@ -257,7 +335,7 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s for (size_t i = 1; i < num; i++) { size_t tmp_dim_num = output_shape[i].size(); if (dim_num != tmp_dim_num) { - GELOGE(FAILED, "CheckDims fail: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); + GELOGE(FAILED, "CheckDims failed: dim_num not equal, output_0:%zu, output_%zu:%zu.", dim_num, i, tmp_dim_num); return false; } } @@ -271,7 +349,7 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s for (size_t j = 1; j < num; j++) { int64_t tmp_dim_value = output_shape[j][i]; if (dim_value != tmp_dim_value) { - GELOGE(FAILED, "CheckDims fail: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, + GELOGE(FAILED, "CheckDims failed: dim_value not equal, dim_index=%zu, dim_value_0:%ld, dim_value_%zu:%ld.", i, dim_value, j, tmp_dim_value); return false; } @@ -287,43 +365,54 @@ bool MultiBatchPass::CheckDims(const std::vector> &output_s /// @param [in] name /// @param [in] pred_value /// @param [in] batch_shape +/// @param [in] combined_batch /// @return ge::NodePtr /// -NodePtr MultiBatchPass::CreateSwitchCaseNode(ComputeGraphPtr &graph, const std::string &name, +NodePtr MultiBatchPass::CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, - const std::vector> &batch_shape) { + const vector> &batch_shape, + const vector> &combined_batch) { OpDescPtr op_desc = MakeShared(name, STREAMSWITCHN); if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Create op_desc failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } GELOGI("Create StreamSwitchN op:%s.", name.c_str()); OpDescPtr pred_desc = pred_value->GetOwnerNode()->GetOpDesc(); if (pred_desc == nullptr) { - GELOGE(FAILED, "Get pred_desc fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Get pred_desc failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } if (op_desc->AddInputDesc(pred_desc->GetOutputDesc(pred_value->GetIdx())) != GRAPH_SUCCESS) { - GELOGE(FAILED, "AddInputDesc fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "AddInputDesc failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } NodePtr switch_case_node = graph->AddNode(op_desc); if (switch_case_node == nullptr) { - GELOGE(FAILED, "Create node fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "Create node failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } uint32_t batch_num = static_cast(batch_shape.size()); if (!AttrUtils::SetInt(op_desc, ATTR_NAME_BATCH_NUM, batch_num)) { - GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_NUM failed, StreamSwitchN:%s.", name.c_str()); + return nullptr; + } + if (!AttrUtils::SetInt(op_desc, ATTR_DYNAMIC_TYPE, dynamic_type_)) { + GELOGE(FAILED, "Set attr ATTR_DYNAMIC_TYPE failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } for (uint32_t i = 0; i < batch_num; i++) { - const std::string attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); + const std::string &attr_name = ATTR_NAME_PRED_VALUE + "_" + std::to_string(i); if (!AttrUtils::SetListInt(op_desc, attr_name, batch_shape[i])) { - GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE fail, StreamSwitchN:%s.", name.c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_PRED_VALUE failed, StreamSwitchN:%s.", name.c_str()); + return nullptr; + } + const string &attr_combined_batch = ATTR_NAME_COMBINED_BATCH + "_" + std::to_string(i); + if (!AttrUtils::SetListInt(op_desc, attr_combined_batch, combined_batch[i])) { + GELOGE(FAILED, "set attr ATTR_NAME_COMBINED_BATCH failed, StreamSwitchN:%s.", name.c_str()); return nullptr; } } @@ -337,43 +426,43 @@ NodePtr MultiBatchPass::CreateSwitchCaseNode(ComputeGraphPtr &graph, const std:: /// @param [in] switch_case /// @return Status /// -Status MultiBatchPass::BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_case) { +Status MultiBatchPass::BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case) { InDataAnchorPtr in_data_anchor = switch_n_node->GetInDataAnchor(SWITCH_DATA_INPUT); if (in_data_anchor == nullptr) { - GELOGE(FAILED, "Check in_data_anchor fail, SwitchN:%s.", switch_n_node->GetName().c_str()); + GELOGE(FAILED, "Check in_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); return FAILED; } OutDataAnchorPtr peer_data_anchor = in_data_anchor->GetPeerOutAnchor(); if (peer_data_anchor == nullptr) { - GELOGE(FAILED, "Check peer_data_anchor fail, SwitchN:%s.", switch_n_node->GetName().c_str()); + GELOGE(FAILED, "Check peer_data_anchor failed, SwitchN:%s.", switch_n_node->GetName().c_str()); return FAILED; } NodePtr data_input = peer_data_anchor->GetOwnerNode(); // Remove SwitchN data input if (GraphUtils::RemoveEdge(peer_data_anchor, in_data_anchor) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Remove SwitchN in_data_edge fail, %s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Remove SwitchN in_data_edge failed, %s->%s.", data_input->GetName().c_str(), switch_n_node->GetName().c_str()); return FAILED; } if (GraphUtils::AddEdge(data_input->GetOutControlAnchor(), switch_case->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add StreamSwitchN in_control_edge fail, %s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Add StreamSwitchN in_control_edge failed, %s->%s.", data_input->GetName().c_str(), switch_case->GetName().c_str()); return FAILED; } // Add SwitchCase control output - for (OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) { - for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { + for (const OutDataAnchorPtr &out_data_anchor : switch_n_node->GetAllOutDataAnchors()) { + for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { NodePtr data_output = peer_in_anchor->GetOwnerNode(); if ((GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor) != GRAPH_SUCCESS) || (GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor) != GRAPH_SUCCESS)) { - GELOGE(FAILED, "Bypass SwitchN data_edge fail, %s->%s->%s.", data_input->GetName().c_str(), + GELOGE(FAILED, "Bypass SwitchN data_edge failed, %s->%s->%s.", data_input->GetName().c_str(), switch_n_node->GetName().c_str(), data_output->GetName().c_str()); return FAILED; } if (GraphUtils::AddEdge(switch_case->GetOutControlAnchor(), data_output->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add SwitchCase out_control_edge fail, %s->%s.", switch_case->GetName().c_str(), + GELOGE(FAILED, "Add SwitchCase out_control_edge failed, %s->%s.", switch_case->GetName().c_str(), data_output->GetName().c_str()); return FAILED; } @@ -390,17 +479,17 @@ Status MultiBatchPass::BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_cas /// @param [in] switch_case_node /// @return Status /// -Status MultiBatchPass::AttachLabel(NodePtr &switch_case_node) { +Status MultiBatchPass::AttachLabel(const NodePtr &switch_case_node) { std::vector stream_label_list; for (uint32_t i = 0; i < static_cast(batch_head_nodes_.size()); i++) { if (AttachBatchLabel(i) != SUCCESS) { - GELOGE(FAILED, "AttachBatchLabel fail, batch_idx=%u", i); + GELOGE(FAILED, "AttachBatchLabel failed, batch_idx=%u", i); return FAILED; } - const std::string stream_label = "stream_label_batch_" + std::to_string(i); + const std::string &stream_label = "stream_label_batch_" + std::to_string(i); if (AttachStreamLabel(i, stream_label) != SUCCESS) { - GELOGE(FAILED, "AttachStreamLabel fail, stream_label=%s", stream_label.c_str()); + GELOGE(FAILED, "AttachStreamLabel failed, stream_label=%s", stream_label.c_str()); return FAILED; } stream_label_list.emplace_back(stream_label); @@ -416,11 +505,11 @@ Status MultiBatchPass::AttachLabel(NodePtr &switch_case_node) { /// Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { std::stack nodes; - for (auto &node : batch_head_nodes_[batch_idx]) { + for (const auto &node : batch_head_nodes_[batch_idx]) { nodes.push(node); } - const std::string batch_label = "Batch_" + std::to_string(batch_idx); + const std::string &batch_label = "Batch_" + std::to_string(batch_idx); std::unordered_set handled_nodes; while (!nodes.empty()) { NodePtr cur_node = nodes.top(); @@ -434,7 +523,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { if (cur_desc->HasAttr(ATTR_NAME_BATCH_LABEL)) { std::string tmp_label; if (!AttrUtils::GetStr(cur_desc, ATTR_NAME_BATCH_LABEL, tmp_label)) { - GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL fail, node: %s.", cur_desc->GetName().c_str()); + GELOGE(FAILED, "get attr ATTR_NAME_BATCH_LABEL failed, node: %s.", cur_desc->GetName().c_str()); return FAILED; } if (tmp_label != batch_label) { @@ -445,14 +534,14 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { } GELOGD("Attach batch_label %s to node %s.", batch_label.c_str(), cur_desc->GetName().c_str()); if (!AttrUtils::SetStr(cur_desc, ATTR_NAME_BATCH_LABEL, batch_label)) { - GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL fail, node:%s.", cur_desc->GetName().c_str()); + GELOGE(FAILED, "set attr ATTR_NAME_BATCH_LABEL failed, node:%s.", cur_desc->GetName().c_str()); return FAILED; } - for (auto &out_node : cur_node->GetOutAllNodes()) { + for (const auto &out_node : cur_node->GetOutAllNodes()) { OpDescPtr op_desc = out_node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - const std::string type = op_desc->GetType(); + const std::string &type = op_desc->GetType(); if ((type == MERGE) && (op_desc->HasAttr(ATTR_INSERT_BY_MBATCH))) { continue; } @@ -476,7 +565,7 @@ Status MultiBatchPass::AttachBatchLabel(uint32_t batch_idx) { /// Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label) { std::stack nodes; - for (auto &node : batch_head_nodes_[batch_idx]) { + for (const auto &node : batch_head_nodes_[batch_idx]) { nodes.push(node); } @@ -493,11 +582,11 @@ Status MultiBatchPass::AttachStreamLabel(uint32_t batch_idx, const std::string & GELOGD("Attach stream_label %s to node %s.", stream_label.c_str(), cur_desc->GetName().c_str()); if (SetStreamLabel(cur_node, stream_label) != SUCCESS) { - GELOGE(FAILED, "SetStreamLabel fail, node:%s.", cur_node->GetName().c_str()); + GELOGE(FAILED, "Set stream_label failed, node:%s.", cur_node->GetName().c_str()); return FAILED; } - for (auto &out_node : cur_node->GetOutAllNodes()) { + for (const auto &out_node : cur_node->GetOutAllNodes()) { nodes.push(out_node); } diff --git a/src/ge/graph/passes/multi_batch_pass.h b/src/ge/graph/passes/multi_batch_pass.h index 6e3f5e46..8f14ec0a 100644 --- a/src/ge/graph/passes/multi_batch_pass.h +++ b/src/ge/graph/passes/multi_batch_pass.h @@ -29,22 +29,28 @@ class MultiBatchPass : public GraphPass { private: Status FindPredValue(const ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value); - bool CheckSwitchN(std::vector> &batch_shape); + Status GetDynamicType(); + bool CheckSwitchN(std::vector> &batch_shape, std::vector> &combined_batch); + bool GetBatchInfo(uint32_t batch_num, std::vector> &batch_shape, + std::vector> &combined_batch); void FindSwitchOutNodes(uint32_t batch_num); - Status ReplaceSwitchN(ComputeGraphPtr &graph, OutDataAnchorPtr &pred_value, - const std::vector> &batch_shape); + Status ReplaceSwitchN(const ComputeGraphPtr &graph, const OutDataAnchorPtr &pred_value, + const std::vector> &batch_shape, + const std::vector> &combined_batch); bool CheckDims(const std::vector> &output_shape) const; - NodePtr CreateSwitchCaseNode(ComputeGraphPtr &graph, const std::string &name, const OutDataAnchorPtr &pred_value, - const std::vector> &batch_shape); - Status BypassSwitchN(NodePtr &switch_n_node, NodePtr &switch_case_node); - Status AttachLabel(NodePtr &switch_case_node); + NodePtr CreateSwitchCaseNode(const ComputeGraphPtr &graph, const std::string &name, + const OutDataAnchorPtr &pred_value, const std::vector> &batch_shape, + const std::vector> &combined_batch); + Status BypassSwitchN(const NodePtr &switch_n_node, const NodePtr &switch_case_node); + Status AttachLabel(const NodePtr &switch_case_node); Status AttachBatchLabel(uint32_t batch_idx); Status AttachStreamLabel(uint32_t batch_idx, const std::string &stream_label); std::vector switch_n_nodes_; std::vector bypass_nodes_; std::vector> batch_head_nodes_; + int32_t dynamic_type_ = 0; }; } // namespace ge #endif // GE_GRAPH_PASSES_MULTI_BATCH_PASS_H_ diff --git a/src/ge/graph/passes/net_output_pass.cc b/src/ge/graph/passes/net_output_pass.cc index 3c83d8ac..dd17f99c 100644 --- a/src/ge/graph/passes/net_output_pass.cc +++ b/src/ge/graph/passes/net_output_pass.cc @@ -22,15 +22,21 @@ #include #include +#include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" -#include "common/ge/ge_util.h" +#include "framework/omg/omg_inner_types.h" +#include "graph/debug/ge_attr_define.h" #include "graph/passes/pass_utils.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/type_utils.h" -#include "graph/debug/ge_attr_define.h" namespace ge { +static std::map output_type_str_to_datatype = { + {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"INT8", ge::DT_INT8}, {"INT16", ge::DT_INT16}, + {"UINT16", ge::DT_UINT16}, {"UINT8", ge::DT_UINT8}, {"INT32", ge::DT_INT32}, {"INT64", ge::DT_INT64}, + {"UINT32", ge::DT_UINT32}, {"UINT64", ge::DT_UINT64}, {"DOUBLE", ge::DT_DOUBLE}}; + Status NetOutputPass::GetRetvalOutputInfo(const ge::NodePtr &node, std::map &retval_node_index_map) { GE_CHECK_NOTNULL(node); @@ -135,18 +141,6 @@ Status NetOutputPass::CheckOutputNodeInfo(const ComputeGraphPtr &graph, const st return SUCCESS; } -void NetOutputPass::AddInOutForNetOutputOp(const ge::ComputeGraphPtr &graph, const ge::OpDescPtr &net_output_desc, - const ge::NodePtr &src_node, int32_t src_index) { - /// Get the output attribute of src_node, - /// and set to the input/output of net_out_node. - if (src_node == nullptr || src_node->GetOpDesc() == nullptr || net_output_desc == nullptr) { - GELOGE(INTERNAL_ERROR, "src node or net output desc is null."); - return; - } - ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); - GE_IF_BOOL_EXEC(net_output_desc->AddInputDesc(out_desc) != SUCCESS, GELOGW("add input desc failed"); return ); -} - Status NetOutputPass::RemoveUnusedNode(const ge::ComputeGraphPtr &graph) { std::vector node_to_delete; // Delete _Retval operator. @@ -401,6 +395,7 @@ Status NetOutputPass::ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, con GELOGE(INTERNAL_ERROR, "Update net output desc failed."); return INTERNAL_ERROR; } + if (UnLink(graph, output_node) != SUCCESS) { GELOGE(INTERNAL_ERROR, "UnLink connection between netoutput node and user set target node"); return INTERNAL_ERROR; @@ -415,6 +410,10 @@ Status NetOutputPass::ProcessWithNetoutput(const ge::ComputeGraphPtr &graph, con Status NetOutputPass::AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraphPtr &graph, const ge::NodePtr &net_out_node) { GE_CHECK_NOTNULL(net_out_node); + if (!domi::GetContext().user_out_nodes.empty()) { + GELOGI("No need to add ctrl edge to netoutput because user out nodes have been set."); + return SUCCESS; + } for (const auto &node : graph->GetDirectNode()) { if (node == nullptr || node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() == NETOUTPUT) { continue; @@ -430,7 +429,7 @@ Status NetOutputPass::AddCtrlEdgesBetweenLeafAndNetOutput(const ge::ComputeGraph return SUCCESS; } -Status NetOutputPass::CreateNetOutputNode(OpDescPtr &net_output_desc, ge::ComputeGraphPtr &graph) { +Status NetOutputPass::CreateNetOutputNode(OpDescPtr &net_output_desc, const ge::ComputeGraphPtr &graph) { // Only flush subgraph name string node_name = (graph->GetParentGraph() != nullptr) ? (graph->GetName() + "_" + NODE_NAME_NET_OUTPUT) : NODE_NAME_NET_OUTPUT; @@ -451,83 +450,185 @@ Status NetOutputPass::Run(ge::ComputeGraphPtr graph) { } GELOGI("NetOutputPass Run."); NodePtr output_node = graph->FindFirstNodeMatchType(NETOUTPUT); - OpDescPtr net_output_desc = nullptr; - std::vector output_nodes_info; - // save user targets node SaveAndRemoveTargets(graph); // If graph already has a netoutput node, doesn't need to create it again. if (output_node != nullptr) { (void)AttrUtils::SetListStr(output_node->GetOpDesc(), ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, std::move(std::vector())); - return ProcessWithNetoutput(graph, output_node); - } else { - if (CreateNetOutputNode(net_output_desc, graph) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get net output nodes failed."); + if (ProcessWithNetoutput(graph, output_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Process with netoutput node failed."); return INTERNAL_ERROR; } - Status ret = GetOutputNode(graph, output_nodes_info); - if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get net output nodes failed."); + } else { + if (AddNetOutputNodeToGraph(graph, output_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Set user define dtype and format for netoutput failed."); return INTERNAL_ERROR; } - GELOGI("[NETOUTPUT PASS] OutNodesInfo size:%zu, Targets Size:%zu, is_include_special_node_:%d", - graph->GetGraphOutNodesInfo().size(), graph->GetGraphTargetNodesInfo().size(), is_include_special_node_); - // If user does not set out nodes and targets and no retval node, return false - bool is_valid = (graph->GetGraphOutNodesInfo().size() == 0) && (graph->GetGraphTargetNodesInfo().size() == 0) && - (is_include_special_node_ == false); - if (is_valid) { - GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!It means no need netoutput!"); - return SUCCESS; - } - GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size()); - if (output_nodes_info.empty()) { - // because retval node is contained by output_nodes_info, here means targets is non-empty - auto net_output_node = graph->AddNode(net_output_desc); - if (net_output_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Add output node failed."); - return INTERNAL_ERROR; - } - GE_CHK_STATUS_RET(AddCtrlEdgeForTargets(net_output_node), "add ctrl edge for targets failed"); - // Add true stream, netoutput is 0 - GE_IF_BOOL_EXEC(!ge::AttrUtils::SetInt(net_output_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, 0), - GELOGE(INTERNAL_ERROR, "set ATTR_NAME_TRUE_BRANCH_STREAM failed"); - return INTERNAL_ERROR); - return SUCCESS; - } - std::vector is_input_const; - for (auto iter = output_nodes_info.begin(); iter != output_nodes_info.end();) { - ge::NodePtr src_node = iter->output_node; - if (src_node == nullptr) { - continue; - } - int32_t src_index = iter->node_output_index; - // if src_node is in targets_, no need to Add in and out for netoutput - auto it = targets_.find(src_node); - if (it != targets_.end()) { - iter = output_nodes_info.erase(iter); - GELOGD("node [%s] is in processed targets, do not add inout for netoutput!", src_node->GetName().c_str()); - continue; - } - AddInOutForNetOutputOp(graph, net_output_desc, src_node, src_index); - is_input_const.push_back(PassUtils::IsConstant(src_node)); - ++iter; - } - net_output_desc->SetIsInputConst(is_input_const); + } + // Add userdef attrs to netoutput node + return SetUserDefDTypeAndFormatFromAtcParams(output_node); +} + +Status NetOutputPass::AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, NodePtr &output_node) { + OpDescPtr net_output_desc = nullptr; + if (CreateNetOutputNode(net_output_desc, graph) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get net output nodes failed."); + return INTERNAL_ERROR; + } + std::vector output_nodes_info; + if (GetOutputNode(graph, output_nodes_info) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get net output nodes failed."); + return INTERNAL_ERROR; + } + GELOGI("[NETOUTPUT PASS] OutNodesInfo size:%zu, Targets Size:%zu, is_include_special_node_:%d", + graph->GetGraphOutNodesInfo().size(), graph->GetGraphTargetNodesInfo().size(), is_include_special_node_); + // If user does not set out nodes and targets and no retval node, return false + if ((graph->GetGraphOutNodesInfo().empty()) && (graph->GetGraphTargetNodesInfo().empty()) && + !is_include_special_node_) { + GELOGI("[NETOUTPUT PASS] output_nodes and target_nodes and special nodes is empty!It means no need netoutput!"); + return SUCCESS; + } + GELOGI("[NETOUTPUT PASS] Output node size:%lu.", output_nodes_info.size()); + if (output_nodes_info.empty()) { + // because retval node is contained by output_nodes_info, here means targets is non-empty output_node = graph->AddNode(net_output_desc); if (output_node == nullptr) { GELOGE(INTERNAL_ERROR, "Add output node failed."); return INTERNAL_ERROR; } - if (AddEdgesForNetOutput(graph, output_node, output_nodes_info) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add edges for net output node failed."); - return INTERNAL_ERROR; + GE_CHK_STATUS_RET(AddCtrlEdgeForTargets(output_node), "add ctrl edge for targets failed"); + // Add true stream, netoutput is 0 + GE_IF_BOOL_EXEC(!ge::AttrUtils::SetInt(output_node->GetOpDesc(), ATTR_NAME_TRUE_BRANCH_STREAM, 0), + GELOGE(INTERNAL_ERROR, "set ATTR_NAME_TRUE_BRANCH_STREAM failed"); + return INTERNAL_ERROR); + return SUCCESS; + } + + AddInOutForNetOutputOp(graph, net_output_desc, output_nodes_info); + output_node = graph->AddNode(net_output_desc); + if (output_node == nullptr) { + GELOGE(INTERNAL_ERROR, "Add output node failed."); + return INTERNAL_ERROR; + } + if (AddEdgesForNetOutput(graph, output_node, output_nodes_info) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add edges for net output node failed."); + return INTERNAL_ERROR; + } + if (AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Add control edges between leaf and netoutput failed."); + return INTERNAL_ERROR; + } + GELOGI("Add NetOutput node success."); + return SUCCESS; +} +void NetOutputPass::AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescPtr &net_output_desc, + vector &output_nodes_info) { + std::vector is_input_const; + for (auto iter = output_nodes_info.begin(); iter != output_nodes_info.end();) { + NodePtr src_node = iter->output_node; + if (src_node == nullptr) { + continue; } - if (AddCtrlEdgesBetweenLeafAndNetOutput(graph, output_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edges between leaf and netoutput failed."); - return INTERNAL_ERROR; + int32_t src_index = iter->node_output_index; + // if src_node is in targets_, no need to Add in and out for netoutput + auto it = targets_.find(src_node); + if (it != targets_.end()) { + iter = output_nodes_info.erase(iter); + GELOGD("node [%s] is in processed targets, do not add inout for netoutput!", src_node->GetName().c_str()); + continue; + } + /// Get the output attribute of src_node, + /// and set to the input/output of net_out_node. + if (src_node == nullptr || src_node->GetOpDesc() == nullptr || net_output_desc == nullptr) { + GELOGE(INTERNAL_ERROR, "src node or net output desc is null."); + return; + } + ge::GeTensorDesc out_desc = src_node->GetOpDesc()->GetOutputDesc(src_index); + GE_IF_BOOL_EXEC(net_output_desc->AddInputDesc(out_desc) != SUCCESS, GELOGW("add input desc failed"); return ); + is_input_const.push_back(PassUtils::IsConstant(src_node)); + ++iter; + } + net_output_desc->SetIsInputConst(is_input_const); +} + +bool NeedUpdateOutputByOutputTypeParm(std::string &output_type, NodePtr &src_node, uint32_t src_index, + ge::DataType &dt) { + if (output_type_str_to_datatype.find(output_type) != output_type_str_to_datatype.end()) { + dt = output_type_str_to_datatype[output_type]; + return true; + } + + auto op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + vector output_data_type_vec; + vector index_vec; + if ((ge::AttrUtils::GetListDataType(op_desc, "_output_dt_list", output_data_type_vec)) && + (ge::AttrUtils::GetListInt(op_desc, "_output_dt_index", index_vec))) { + if (output_data_type_vec.size() != index_vec.size()) { + GELOGW("output_dt_list size is not match output_dt_index size"); + return false; + } + for (uint32_t i = 0; i < index_vec.size(); ++i) { + if (index_vec[i] == src_index) { + dt = output_data_type_vec[i]; + return true; + } } - GELOGI("Add NetOutput node success."); + } + return false; +} + +Status NetOutputPass::SetUserDefDTypeAndFormatFromAtcParams(const NodePtr &output_node) { + if (output_node == nullptr) { + GELOGI("[NETOUTPUT PASS] The graph no need netoutput node!"); + return SUCCESS; + } + auto output_type = domi::GetContext().output_type; + auto op_desc = output_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::vector userdef_dtypes; + std::vector userdef_formats; + + ge::DataType output_data_type = ge::DT_FLOAT; + for (const auto &in_anchor : output_node->GetAllInDataAnchors()) { + auto index = static_cast(in_anchor->GetIdx()); + auto peer_out = in_anchor->GetPeerOutAnchor(); + if (peer_out == nullptr) { + // If user set target, peer_out anchor will be unlinked. + continue; + } + auto src_index = static_cast(peer_out->GetIdx()); + auto src_node = peer_out->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + + // Update datatype + if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { + GELOGD("Add user-define datatype:%s to netoutput node.", + TypeUtils::DataTypeToSerialString(output_data_type).c_str()); + userdef_dtypes.push_back( + std::to_string(index).append(":").append(TypeUtils::DataTypeToSerialString(output_data_type))); + continue; + } + // Output_node is not set,check if is_output_adjust_hw_layout is set + OpDescPtr src_op_desc = src_node->GetOpDesc(); + GE_CHECK_NOTNULL(src_op_desc); + bool set_fp16_nc1hwc0 = false; + (void)AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); + if (set_fp16_nc1hwc0) { + // Set DT_FLOAT16 & FORMAT_NC1HWC0 + userdef_dtypes.push_back(std::to_string(index).append(":").append(TypeUtils::DataTypeToSerialString(DT_FLOAT16))); + userdef_formats.push_back( + std::to_string(index).append(":").append(TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))); + } + } + if (!userdef_dtypes.empty() && !ge::AttrUtils::SetListStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, userdef_dtypes)) { + GELOGE(INTERNAL_ERROR, "Set user_define_dtype attr list for netoutput failed."); + return INTERNAL_ERROR; + } + if (!userdef_formats.empty() && !ge::AttrUtils::SetListStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, userdef_formats)) { + GELOGE(INTERNAL_ERROR, "Set user_define_format attr list for netoutput failed."); + return INTERNAL_ERROR; } return SUCCESS; } diff --git a/src/ge/graph/passes/net_output_pass.h b/src/ge/graph/passes/net_output_pass.h index 5edf24fc..567d1246 100644 --- a/src/ge/graph/passes/net_output_pass.h +++ b/src/ge/graph/passes/net_output_pass.h @@ -73,7 +73,7 @@ class NetOutputPass : public GraphPass { /// @return OTHERS: Execution failed /// @author /// - Status CreateNetOutputNode(OpDescPtr &net_output_desc, ge::ComputeGraphPtr &graph); + Status CreateNetOutputNode(OpDescPtr &net_output_desc, const ge::ComputeGraphPtr &graph); /// /// Check if the network output node is legal @@ -89,13 +89,12 @@ class NetOutputPass : public GraphPass { /// Set input and output for the NetOutput node /// @param [in] graph: Input ComputeGraph /// @param [in] net_output_desc: OpDesc of the NetOutput node - /// @param [in] src_node: Source node of the NetOutput - /// @param [in] src_index: Output index of the Source node + /// @param [in] output_nodes_info: RetvalInfos of the NetOutput /// @return void /// @author /// - void AddInOutForNetOutputOp(const ge::ComputeGraphPtr &graph, const ge::OpDescPtr &net_output_desc, - const ge::NodePtr &src_node, int32_t src_index); + void AddInOutForNetOutputOp(const ComputeGraphPtr &graph, OpDescPtr &net_output_desc, + vector &output_nodes_info); /// /// Delete unwanted _Retval/Save/Summary nodes @@ -199,6 +198,25 @@ class NetOutputPass : public GraphPass { /// bool CheckNodeIsInOutputNodes(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node); + /// + /// Add netoutput node to graph with output node infos + /// @param [in] graph: ComputeGraph + /// @param [in] output_node: shared_ptr to netoutput node + /// @return SUCCESS: Execution succeed + /// @return OTHERS: Execution failed + /// @author + /// + Status AddNetOutputNodeToGraph(const ge::ComputeGraphPtr &graph, NodePtr &output_node); + + /// + /// Add user_def_dtype & format for netoutput node + /// @param [in] output_node: The netOutput node + /// @return SUCCESS: Execution succeed + /// @return OTHERS: Execution failed + /// @author + /// + Status SetUserDefDTypeAndFormatFromAtcParams(const ge::NodePtr &output_node); + bool is_include_special_node_ = false; std::set targets_; friend class ReUpdateNetOutputPass; diff --git a/src/ge/graph/passes/next_iteration_pass.cc b/src/ge/graph/passes/next_iteration_pass.cc index 138ad86b..12cde11e 100644 --- a/src/ge/graph/passes/next_iteration_pass.cc +++ b/src/ge/graph/passes/next_iteration_pass.cc @@ -16,19 +16,8 @@ #include "graph/passes/next_iteration_pass.h" -#include -#include -#include -#include -#include - #include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" #include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" namespace ge { Status NextIterationPass::Run(ComputeGraphPtr graph) { @@ -41,24 +30,24 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { if ((type != ENTER) && (type != REFENTER)) { continue; } - if (HandleEnterNode(node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "HandleEnterNode for node %s fail.", node->GetName().c_str()); + if (GroupEnterNode(node) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Group enter_node %s failed.", node->GetName().c_str()); return INTERNAL_ERROR; } } if (FindWhileGroups() != SUCCESS) { - GELOGE(INTERNAL_ERROR, "FindWhileGroups fail"); + GELOGE(INTERNAL_ERROR, "Find while groups failed."); return INTERNAL_ERROR; } if (!VerifyWhileGroup()) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail"); + GELOGE(INTERNAL_ERROR, "Verify while groups failed."); return INTERNAL_ERROR; } if (HandleWhileGroup(graph) != SUCCESS) { - GELOGE(FAILED, "HandleWhileGroup fail"); + GELOGE(FAILED, "Handle while groups failed."); return FAILED; } @@ -67,16 +56,16 @@ Status NextIterationPass::Run(ComputeGraphPtr graph) { } /// -/// @brief Handle Enter node +/// @brief Group Enter node /// @param [in] enter_node /// @return Status /// -Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { +Status NextIterationPass::GroupEnterNode(const NodePtr &enter_node) { OpDescPtr enter_desc = enter_node->GetOpDesc(); GE_CHECK_NOTNULL(enter_desc); std::string frame_name; if (!ge::AttrUtils::GetStr(enter_desc, ENTER_ATTR_FRAME_NAME, frame_name) || frame_name.empty()) { - GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME fail, node: %s", enter_desc->GetName().c_str()); + GELOGE(FAILED, "Get attr ENTER_ATTR_FRAME_NAME failed, node: %s", enter_desc->GetName().c_str()); return FAILED; } @@ -84,7 +73,7 @@ Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { if (iter == loop_group_map_.end()) { LoopCondGroupPtr loop_group = MakeShared(); if (loop_group == nullptr) { - GELOGE(FAILED, "MakeShared for LoopCondGroup fail."); + GELOGE(FAILED, "MakeShared for LoopCondGroup failed."); return FAILED; } loop_group->enter_nodes.emplace_back(enter_node); @@ -101,40 +90,42 @@ Status NextIterationPass::HandleEnterNode(const NodePtr &enter_node) { /// @return Status /// Status NextIterationPass::FindWhileGroups() { - for (auto &loop_group_iter : loop_group_map_) { - const std::string frame_name = loop_group_iter.first; - for (auto &enter_node : loop_group_iter.second->enter_nodes) { - for (auto &out_node : enter_node->GetOutAllNodes()) { - const std::string type = out_node->GetType(); + for (const auto &loop_group_iter : loop_group_map_) { + const std::string &frame_name = loop_group_iter.first; + for (const auto &enter_node : loop_group_iter.second->enter_nodes) { + for (const auto &out_node : enter_node->GetOutAllNodes()) { + const std::string &type = out_node->GetType(); if ((type != MERGE) && (type != REFMERGE)) { continue; } NodePtr next_node = nullptr; if (FindTargetNode(out_node, NEXTITERATION, true, next_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get NextIteration node fail, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get NextIteration node failed, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } + loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); NodePtr switch_node = nullptr; if (FindTargetNode(out_node, SWITCH, false, switch_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get Switch node fail, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get Switch node failed, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } + if (switch_node == nullptr) { + continue; + } NodePtr loop_cond = nullptr; if (FindTargetNode(switch_node, LOOPCOND, true, loop_cond) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Get LoopCond node fail, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Get LoopCond node failed, frame_name: %s.", frame_name.c_str()); return INTERNAL_ERROR; } - if (loop_group_iter.second->loop_cond == nullptr) { loop_group_iter.second->loop_cond = loop_cond; } else if (loop_group_iter.second->loop_cond != loop_cond) { GELOGE(FAILED, "Multi LoopCond nodes exist, frame_name: %s.", frame_name.c_str()); return FAILED; } - loop_group_iter.second->merge_next_pairs.emplace_back(std::make_pair(out_node, next_node)); } } } @@ -148,21 +139,21 @@ Status NextIterationPass::FindWhileGroups() { /// bool NextIterationPass::VerifyWhileGroup() { // map - for (auto &loop_group_iter : loop_group_map_) { - const std::string frame_name = loop_group_iter.first; + for (const auto &loop_group_iter : loop_group_map_) { + const std::string &frame_name = loop_group_iter.first; if (frame_name.empty()) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, frame_name is empty."); + GELOGE(INTERNAL_ERROR, "Verify while group failed, frame_name is empty."); return false; } if (loop_group_iter.second->loop_cond == nullptr) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, LoopCond is null, frame_name: %s.", frame_name.c_str()); + GELOGE(INTERNAL_ERROR, "Verify while group failed, LoopCond is null, frame_name: %s.", frame_name.c_str()); return false; } - for (auto &pair_iter : loop_group_iter.second->merge_next_pairs) { + for (const auto &pair_iter : loop_group_iter.second->merge_next_pairs) { if ((pair_iter.first == nullptr) || (pair_iter.second == nullptr)) { - GELOGE(INTERNAL_ERROR, "VerifyWhileGroup fail, merge_node/next_node is null, frame_name: %s.", + GELOGE(INTERNAL_ERROR, "Verify while group failed, merge_node/next_node is null, frame_name: %s.", frame_name.c_str()); return false; } @@ -178,51 +169,51 @@ bool NextIterationPass::VerifyWhileGroup() { /// @return Status /// Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { - for (auto &loop_cond_iter : loop_group_map_) { - std::string cond_name = loop_cond_iter.second->loop_cond->GetName(); - GELOGI("HandleWhileGroup, LoopCond node: %s.", cond_name.c_str()); + for (const auto &loop_cond_iter : loop_group_map_) { + const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); + GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); - // Create Active node, Enter->Active->Merge, NextItaration->Active->Merge + // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge NodePtr enter_active = CreateActiveNode(graph, cond_name + "_Enter_" + STREAMACTIVE); NodePtr next_active = CreateActiveNode(graph, cond_name + "_Next_" + STREAMACTIVE); if ((enter_active == nullptr) || (next_active == nullptr)) { - GELOGE(INTERNAL_ERROR, "CreateActiveNode fail, cond_name: %s.", cond_name.c_str()); + GELOGE(INTERNAL_ERROR, "Create active node failed, cond_name: %s.", cond_name.c_str()); return INTERNAL_ERROR; } - for (auto &enter_node : loop_cond_iter.second->enter_nodes) { + for (const auto &enter_node : loop_cond_iter.second->enter_nodes) { // Enter --> Active if (GraphUtils::AddEdge(enter_node->GetOutControlAnchor(), enter_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge fail"); + GELOGE(INTERNAL_ERROR, "Add control edge failed."); return INTERNAL_ERROR; } } - for (auto &pair : loop_cond_iter.second->merge_next_pairs) { + for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { NodePtr merge_node = pair.first; NodePtr next_node = pair.second; // Active --> Merge if (GraphUtils::AddEdge(enter_active->GetOutControlAnchor(), merge_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge fail"); + GELOGE(INTERNAL_ERROR, "Add control edge failed."); return INTERNAL_ERROR; } // NextIteration --> Active if (GraphUtils::AddEdge(next_node->GetOutControlAnchor(), next_active->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "Add control edge fail"); + GELOGE(INTERNAL_ERROR, "Add control edge failed."); return INTERNAL_ERROR; } // break link between NextIteration and Merge if (BreakNextIteration(next_node, merge_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "BreakNextIteration failed"); + GELOGE(INTERNAL_ERROR, "Break NextIteration failed"); return INTERNAL_ERROR; } } if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || (SetActiveLabelList(next_active, {cond_name}) != SUCCESS)) { - GELOGE(INTERNAL_ERROR, "SetActiveLabelList failed"); + GELOGE(INTERNAL_ERROR, "Set attr ACTIVE_LABEL_LIST failed."); return INTERNAL_ERROR; } } @@ -245,12 +236,12 @@ NodePtr NextIterationPass::CreateActiveNode(ComputeGraphPtr &graph, const std::s GELOGI("Create StreamActive op:%s.", op_desc->GetName().c_str()); NodePtr active_node = graph->AddNode(op_desc); if (active_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Create node[%s] fail.", name.c_str()); + GELOGE(INTERNAL_ERROR, "Create node[%s] failed.", name.c_str()); return nullptr; } if (SetSwitchBranchNodeLabel(active_node, name) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "SetSwitchBranchNodeLabel for node: %s failed.", active_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "Set attr SWITCH_BRANCH_NODE_LABEL for node: %s failed.", active_node->GetName().c_str()); return nullptr; } @@ -268,18 +259,18 @@ Status NextIterationPass::BreakNextIteration(const NodePtr &next_node, NodePtr & GELOGE(PARAM_INVALID, "merge node or next node is null."); return PARAM_INVALID; } - for (auto &in_anchor : merge_node->GetAllInDataAnchors()) { + for (const auto &in_anchor : merge_node->GetAllInDataAnchors()) { OutDataAnchorPtr out_anchor = in_anchor->GetPeerOutAnchor(); if ((out_anchor == nullptr) || (out_anchor->GetOwnerNode() != next_node)) { continue; } if (GraphUtils::RemoveEdge(out_anchor, in_anchor) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Remove data edge fail, %s->%s.", next_node->GetName().c_str(), + GELOGE(INTERNAL_ERROR, "Remove data edge failed, %s->%s.", next_node->GetName().c_str(), merge_node->GetName().c_str()); return INTERNAL_ERROR; } if (SetNextIteration(merge_node, next_node->GetName()) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "SetNextIteration for node %s fail.", merge_node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "Set attr NEXT_ITERATION for node %s failed.", merge_node->GetName().c_str()); return INTERNAL_ERROR; } } @@ -302,16 +293,16 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } std::vector nodes; if (is_input) { - for (auto &tmp_node : node->GetInDataNodes()) { + for (const auto &tmp_node : node->GetInDataNodes()) { nodes.emplace_back(tmp_node); } } else { - for (auto &tmp_node : node->GetOutDataNodes()) { + for (const auto &tmp_node : node->GetOutDataNodes()) { nodes.emplace_back(tmp_node); } } - for (auto &tmp_node : nodes) { + for (const auto &tmp_node : nodes) { const std::string type = tmp_node->GetType(); if ((target_type == LOOPCOND) && (type == target_type)) { target_node = tmp_node; @@ -322,14 +313,15 @@ Status NextIterationPass::FindTargetNode(const NodePtr &node, const std::string } } - if (target_node == nullptr) { - GELOGE(INTERNAL_ERROR, "Find node %s fail", target_type.c_str()); + if ((target_type != SWITCH) && (target_node == nullptr)) { + GELOGE(INTERNAL_ERROR, "Find node %s failed.", target_type.c_str()); return INTERNAL_ERROR; } return SUCCESS; } + /// -/// @brief Clear Status, uesd for subgraph pass +/// @brief Clear Status, used for subgraph pass /// @return SUCCESS /// Status NextIterationPass::ClearStatus() { diff --git a/src/ge/graph/passes/next_iteration_pass.h b/src/ge/graph/passes/next_iteration_pass.h index 4bbced4f..4cdf4b51 100644 --- a/src/ge/graph/passes/next_iteration_pass.h +++ b/src/ge/graph/passes/next_iteration_pass.h @@ -17,12 +17,6 @@ #ifndef GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ #define GE_GRAPH_PASSES_NEXT_ITERATION_PASS_H_ -#include -#include -#include -#include -#include - #include "inc/graph_pass.h" struct LoopCondGroup { @@ -37,15 +31,64 @@ namespace ge { class NextIterationPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); + + /// + /// @brief Clear Status, used for subgraph pass + /// @return SUCCESS + /// Status ClearStatus() override; private: - Status HandleEnterNode(const NodePtr &enter_node); + /// + /// @brief Group Enter node + /// @param [in] enter_node + /// @return Status + /// + Status GroupEnterNode(const NodePtr &enter_node); + + /// + /// @brief Find while groups + /// @return Status + /// Status FindWhileGroups(); + + /// + /// @brief Verify if valid + /// @return bool + /// bool VerifyWhileGroup(); + + /// + /// @brief Handle while group + /// @param [in] graph + /// @return Status + /// Status HandleWhileGroup(ComputeGraphPtr &graph); + + /// + /// @brief Create Active Node + /// @param [in] graph + /// @param [in] name + /// @return ge::NodePtr + /// NodePtr CreateActiveNode(ComputeGraphPtr &graph, const std::string &name); + + /// + /// @brief Break NextIteration Link & add name to merge attr + /// @param [in] next_node + /// @param [in] merge_node + /// @return Status + /// Status BreakNextIteration(const NodePtr &next_node, NodePtr &merge_node); + + /// + /// @brief find target node + /// @param [in] node + /// @param [in] target_type + /// @param [in] is_input + /// @param [out] target_node + /// @return Status + /// Status FindTargetNode(const NodePtr &node, const std::string &target_type, bool is_input, NodePtr &target_node); // map diff --git a/src/ge/graph/passes/pass_manager.cc b/src/ge/graph/passes/pass_manager.cc index eec33eef..5be54f0a 100644 --- a/src/ge/graph/passes/pass_manager.cc +++ b/src/ge/graph/passes/pass_manager.cc @@ -19,6 +19,7 @@ #include "common/types.h" #include "common/util.h" #include "graph/utils/node_utils.h" +#include "graph/common/ge_call_wrapper.h" #include "omg/omg_inner_types.h" namespace ge { diff --git a/src/ge/graph/passes/permute_pass.cc b/src/ge/graph/passes/permute_pass.cc index f5fd9dc5..3c0dfd4e 100644 --- a/src/ge/graph/passes/permute_pass.cc +++ b/src/ge/graph/passes/permute_pass.cc @@ -33,7 +33,6 @@ using domi::TENSORFLOW; namespace ge { Status PermutePass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(PermutePass); GE_CHECK_NOTNULL(graph); std::vector isolate_nodes; for (NodePtr &node : graph->GetDirectNode()) { @@ -116,8 +115,6 @@ Status PermutePass::Run(ComputeGraphPtr graph) { GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node), "[%s]:remove permute node failed", node->GetOpDesc()->GetName().c_str()); }); - - GE_TIMESTAMP_END(PermutePass, "GraphManager::PermutePass"); return SUCCESS; } } // namespace ge diff --git a/src/ge/graph/passes/print_op_pass.h b/src/ge/graph/passes/print_op_pass.h index 64bf6573..15b0badc 100644 --- a/src/ge/graph/passes/print_op_pass.h +++ b/src/ge/graph/passes/print_op_pass.h @@ -31,6 +31,6 @@ class PrintOpPass : public BaseNodePass { public: Status Run(ge::NodePtr &node) override; }; -}; // namespace ge +} // namespace ge #endif // GE_GRAPH_PASSES_PRINT_OP_PASS_H_ diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.cc b/src/ge/graph/passes/ref_identity_delete_op_pass.cc new file mode 100644 index 00000000..5bc0fad6 --- /dev/null +++ b/src/ge/graph/passes/ref_identity_delete_op_pass.cc @@ -0,0 +1,225 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "ref_identity_delete_op_pass.h" +#include +#include +#include "graph/common/transop_util.h" + +namespace ge { +Status RefIdentityDeleteOpPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + for (auto &node : graph->GetAllNodes()) { + if (node->GetType() != REFIDENTITY) { + continue; + } + int input_index = 0; + NodePtr ref_node = GetRefNode(node, input_index); + CHECK_FALSE_EXEC(GetRefNode(node, input_index) != nullptr, + GELOGE(FAILED, "Ref node of RefIdentity[%s] not found", node->GetName().c_str()); + return FAILED); + CHECK_FALSE_EXEC(DealNoOutputRef(ref_node, node, input_index, graph) == SUCCESS, + GELOGE(FAILED, "Ref identity [%s] delete failed", node->GetName().c_str()); + return FAILED); + } + return SUCCESS; +} + +NodePtr RefIdentityDeleteOpPass::GetRefNode(const NodePtr &node, int &input_index) { + OutDataAnchorPtr out_anchor = node->GetOutDataAnchor(0); + CHECK_FALSE_EXEC(out_anchor != nullptr, return nullptr); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + CHECK_FALSE_EXEC(peer_in_anchor != nullptr, continue); + auto peer_node = peer_in_anchor->GetOwnerNode(); + CHECK_FALSE_EXEC(peer_node != nullptr, continue); + const auto &peer_op_desc = peer_node->GetOpDesc(); + CHECK_FALSE_EXEC(peer_op_desc != nullptr, return nullptr); + const auto &peer_input_desc = peer_op_desc->GetInputDescPtr(static_cast(peer_in_anchor->GetIdx())); + if (!peer_input_desc->GetRefPortIndex().empty()) { + input_index = peer_in_anchor->GetIdx(); + return peer_node; + } + } + return nullptr; +} + +Status RefIdentityDeleteOpPass::DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index, + const ComputeGraphPtr &graph) { + NodePtr first_node = nullptr; + NodePtr variable_ref = GetVariableRef(node, ref_identity, first_node); + if (variable_ref == nullptr) { + GELOGE(FAILED, "[RefIdentityDeleteOpPass]Can not find variable ref for %s:%d", node->GetName().c_str(), + input_index); + return FAILED; + } + if (first_node->GetName() != variable_ref->GetName()) { + // Remove the control edge between ref node and variable ref + // Add a control edge between ref node and trans node + // +-----------+ +-----------+ + // +---------+RefIdentity| +-----------+RefIdentity| + // | +-----+-----+ | +-----+-----+ + // | | | | + // | v | v + // +-----v-----+ +----+----+ +-----v-----+ +----+----+ + // | TransNode | | RefNode | ==> | TransNode +<--C--+ RefNode | + // +-----+-----+ +----+----+ +-----+-----+ +---------+ + // | | | + // v C v + // +-----+-----+ | +-----+-----+ + // |VariableRef+<--------+ |VariableRef| + // +-----------+ +-----------+ + auto ret = ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), first_node->GetInControlAnchor()); + if (ret != SUCCESS) { + GELOGE(FAILED, "Add control edge between ref node and trans node failed"); + return FAILED; + } + ret = ge::GraphUtils::RemoveEdge(node->GetOutControlAnchor(), variable_ref->GetInControlAnchor()); + if (ret != SUCCESS) { + GELOGE(FAILED, "Remove control edge between ref node and its peer node failed"); + return FAILED; + } + } else { + // +-----------+ +-----------+ + // +-----------+RefIdentity| +-----------+RefIdentity| + // | +-----+-----+ | +-----+-----+ + // | | | | + // | v | v + // +-----v-----+ +----+----+ +-----v-----+ +----+----+ + // |VariableRef+<--C--+ RefNode | ==> |VariableRef+<--C--+ RefNode | + // +-----+-----+ +----+----+ +-----------+ +----+----+ + // | | | + // | v v + // | +---+----+ +---+----+ + // +-----C------>+ | | | + // +--------+ +--------+ + auto ret = RemoveUselessControlEdge(node, variable_ref); + if (ret != SUCCESS) { + GELOGE(FAILED, "Remove useless control edge failed."); + return FAILED; + } + } + // remove ref identity + if (GraphUtils::IsolateNode(ref_identity, {0}) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", ref_identity->GetName().c_str(), + variable_ref->GetType().c_str()); + return FAILED; + } + if (GraphUtils::RemoveNodeWithoutRelink(graph, ref_identity) != GRAPH_SUCCESS) { + GELOGE(INTERNAL_ERROR, "Remove node: %s, type: %s without relink failed", ref_identity->GetName().c_str(), + ref_identity->GetType().c_str()); + return FAILED; + } + return SUCCESS; +} + +ge::NodePtr RefIdentityDeleteOpPass::GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity, + NodePtr &first_node) { + const auto &ref_identity_out_anchor = ref_identity->GetOutDataAnchor(0); + if (ref_identity_out_anchor == nullptr) { + return nullptr; + } + for (auto &peer_in_anchor : ref_identity_out_anchor->GetPeerInDataAnchors()) { + const auto &peer_node = peer_in_anchor->GetOwnerNode(); + if (peer_node == nullptr || peer_node->GetName() == ref->GetName()) { + continue; + } + // DFS to find variable ref node. + std::stack nodes_to_check; + nodes_to_check.push(peer_node); + GELOGI("[RefIdentityDeleteOpPass]Start to search variable ref node from %s.", peer_node->GetName().c_str()); + NodePtr cur_node = nullptr; + while (!nodes_to_check.empty()) { + cur_node = nodes_to_check.top(); + nodes_to_check.pop(); + const auto &type = cur_node->GetType(); + if (type == VARIABLE && CheckControlEdge(ref, cur_node)) { + // Target variable ref node found. + GELOGI("[RefIdentityDeleteOpPass]variable ref node[%s] found.", cur_node->GetName().c_str()); + first_node = peer_node; + return cur_node; + } + + int data_index = TransOpUtil::GetTransOpDataIndex(type); + if (data_index < 0) { + GELOGI("[RefIdentityDeleteOpPass]Find node[%s] that is not trans op[%s], stop to search its output.", + cur_node->GetName().c_str(), type.c_str()); + continue; + } + const auto &cur_out_anchor = cur_node->GetOutDataAnchor(0); + if (cur_out_anchor == nullptr) { + GELOGI("[RefIdentityDeleteOpPass]Get out anchor of [%s] failed, stop to search its output.", + cur_node->GetName().c_str()); + continue; + } + for (const auto &cur_peer_in_anchor : cur_out_anchor->GetPeerInDataAnchors()) { + const auto &cur_peer_node = cur_peer_in_anchor->GetOwnerNode(); + if (cur_peer_node == nullptr) { + continue; + } + nodes_to_check.push(cur_peer_node); + } + } + GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node from %s.", peer_node->GetName().c_str()); + } + GELOGI("[RefIdentityDeleteOpPass]Can not find variable ref node, return nullptr."); + return nullptr; +} + +bool RefIdentityDeleteOpPass::CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref) { + const auto &control_out_anchor = ref->GetOutControlAnchor(); + if (control_out_anchor == nullptr) { + return false; + } + const string &variable_ref_name = variable_ref->GetName(); + for (const auto &peer_in_control_anchor : control_out_anchor->GetPeerInControlAnchors()) { + const auto &node = peer_in_control_anchor->GetOwnerNode(); + if (node != nullptr && node->GetName() == variable_ref_name) { + return true; + } + } + return false; +} + +Status RefIdentityDeleteOpPass::RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref) { + map out_nodes_map; + for (const auto &out_anchor : ref->GetAllOutDataAnchors()) { + for (const auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { + const auto &peer_node = peer_in_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + out_nodes_map[peer_node->GetName()] = peer_node; + } + } + const auto &out_control_anchor = variable_ref->GetOutControlAnchor(); + GE_CHECK_NOTNULL(out_control_anchor); + for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { + const auto &peer_node = peer_in_control_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; + } + if (out_nodes_map.find(peer_node->GetName()) != out_nodes_map.end()) { + auto ret = ge::GraphUtils::RemoveEdge(out_control_anchor, peer_in_control_anchor); + if (ret != SUCCESS) { + GELOGE(FAILED, "Remove control edge between variable ref node[%s] and ref node's peer node[%s] failed", + variable_ref->GetName().c_str(), peer_node->GetName().c_str()); + return FAILED; + } + } + } + return SUCCESS; +} +} // namespace ge diff --git a/src/ge/graph/passes/ref_identity_delete_op_pass.h b/src/ge/graph/passes/ref_identity_delete_op_pass.h new file mode 100644 index 00000000..3e42def4 --- /dev/null +++ b/src/ge/graph/passes/ref_identity_delete_op_pass.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ +#define GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ + +#include +#include +#include "framework/common/ge_inner_error_codes.h" +#include "inc/graph_pass.h" + +namespace ge { +class RefIdentityDeleteOpPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph); + + private: + Status DealNoOutputRef(const NodePtr &node, const NodePtr &ref_identity, int input_index, + const ComputeGraphPtr &graph); + NodePtr GetVariableRef(const NodePtr &ref, const NodePtr &ref_identity, NodePtr &first_node); + bool CheckControlEdge(const NodePtr &ref, const NodePtr &variable_ref); + Status RemoveUselessControlEdge(const NodePtr &ref, const NodePtr &variable_ref); + NodePtr GetRefNode(const NodePtr &node, int &input_index); +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_REF_IDENTITY_DELETE_OP_PASS_H_ diff --git a/src/ge/graph/passes/reshape_recovery_pass.cc b/src/ge/graph/passes/reshape_recovery_pass.cc index 787c8d83..07b08de9 100644 --- a/src/ge/graph/passes/reshape_recovery_pass.cc +++ b/src/ge/graph/passes/reshape_recovery_pass.cc @@ -30,6 +30,10 @@ NodePtr CreateReshape(const ConstGeTensorDescPtr &src, const ConstGeTensorDescPt if (ret != GRAPH_SUCCESS) { return nullptr; } + ret = reshape->AddInputDesc("shape", GeTensorDesc(GeShape(), Format(), DT_INT32)); + if (ret != GRAPH_SUCCESS) { + return nullptr; + } ret = reshape->AddOutputDesc("y", *dst); if (ret != GRAPH_SUCCESS) { return nullptr; @@ -49,7 +53,10 @@ Status InsertReshapeIfNeed(const NodePtr &node) { GE_CHECK_NOTNULL(dst_node); GE_CHECK_NOTNULL(dst_node->GetOpDesc()); auto dst_tensor = dst_node->GetOpDesc()->GetInputDescPtr(dst_anchor->GetIdx()); - if (src_tensor->GetShape().GetDims() != dst_tensor->GetShape().GetDims()) { + bool is_need_insert_reshape = src_tensor->GetShape().GetDims() != UNKNOWN_RANK && + dst_tensor->GetShape().GetDims() != UNKNOWN_RANK && + src_tensor->GetShape().GetDims() != dst_tensor->GetShape().GetDims(); + if (is_need_insert_reshape) { auto reshape = CreateReshape(src_tensor, dst_tensor, node->GetOwnerComputeGraph()); GE_CHECK_NOTNULL(reshape); auto ret = GraphUtils::InsertNodeBetweenDataAnchors(src_anchor, dst_anchor, reshape); diff --git a/src/ge/graph/passes/resource_pair_add_control_pass.cc b/src/ge/graph/passes/resource_pair_add_control_pass.cc index c5be9600..bba8ee71 100644 --- a/src/ge/graph/passes/resource_pair_add_control_pass.cc +++ b/src/ge/graph/passes/resource_pair_add_control_pass.cc @@ -28,7 +28,6 @@ #include "graph/utils/tensor_adapter.h" namespace { -const char *const kSeparate = "/"; const std::map kResourcePairType = {{"StackPush", "StackPop"}}; const std::set kResourceTypes = {"StackPush", "StackPop"}; } // namespace @@ -41,15 +40,16 @@ Status ResourcePairAddControlPass::Run(ComputeGraphPtr graph) { // find all node of condition type, store with type and scope prefix key for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); - if (kResourceTypes.find(node->GetType()) != kResourceTypes.end()) { + auto node_type = node->GetType(); + if (kResourceTypes.find(node_type) != kResourceTypes.end()) { std::string node_name = node->GetName(); - std::string node_prefix; - size_t last_separate_index = node_name.find_last_of(kSeparate); - if (last_separate_index != std::string::npos) { - node_prefix = node_name.substr(0, last_separate_index); + std::string node_key(node_name); + std::size_t found = node_name.rfind(node_type); + if (found != std::string::npos) { + node_key.replace(found, node_type.size(), ""); } - prefix_2_node_per_type[node->GetType()][node_prefix] = node; - GELOGD("ResourcePairAddControlPass insert prefix:%s, op_name:%s, op_type:%s", node_prefix.c_str(), + prefix_2_node_per_type[node_type][node_key] = node; + GELOGD("ResourcePairAddControlPass insert node_key:%s, op_name:%s, op_type:%s", node_key.c_str(), node_name.c_str(), node->GetType().c_str()); } } diff --git a/src/ge/graph/passes/resource_pair_remove_control_pass.cc b/src/ge/graph/passes/resource_pair_remove_control_pass.cc index de3537f0..00d97798 100644 --- a/src/ge/graph/passes/resource_pair_remove_control_pass.cc +++ b/src/ge/graph/passes/resource_pair_remove_control_pass.cc @@ -28,7 +28,6 @@ #include "graph/utils/tensor_adapter.h" namespace { -const char *const kSeparate = "/"; const std::map kResourcePairType = {{"StackPush", "StackPop"}}; const std::set kResourceTypes = {"StackPush", "StackPop"}; } // namespace @@ -41,15 +40,16 @@ Status ResourcePairRemoveControlPass::Run(ComputeGraphPtr graph) { // find all node of condition type, store with type and scope prefix key for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node); - if (kResourceTypes.find(node->GetType()) != kResourceTypes.end()) { + auto node_type = node->GetType(); + if (kResourceTypes.find(node_type) != kResourceTypes.end()) { std::string node_name = node->GetName(); - std::string node_prefix; - size_t last_separate_index = node_name.find_last_of(kSeparate); - if (last_separate_index != std::string::npos) { - node_prefix = node_name.substr(0, last_separate_index); + std::string node_key(node_name); + std::size_t found = node_name.rfind(node_type); + if (found != std::string::npos) { + node_key.replace(found, node_type.size(), ""); } - prefix_2_node_per_type[node->GetType()][node_prefix] = node; - GELOGD("ResourcePairRemoveControlPass insert prefix:%s, op_name:%s, op_type:%s", node_prefix.c_str(), + prefix_2_node_per_type[node_type][node_key] = node; + GELOGD("ResourcePairRemoveControlPass insert node_key:%s, op_name:%s, op_type:%s", node_key.c_str(), node_name.c_str(), node->GetType().c_str()); } } diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc index 3b4e4c19..d51f52e1 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.cc @@ -22,7 +22,6 @@ #include #include "common/ge_inner_error_codes.h" #include "common/types.h" -#include "framework/common/debug/ge_log.h" #include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" #include "graph/utils/op_desc_utils.h" @@ -117,20 +116,44 @@ void SameTransdataBreadthFusionPass::InsertSameTransdataNodeIndex(int anchors_in same_transdata_nodes.push_back(anchors_index); } +std::set SameTransdataBreadthFusionPass::GetInControlIdentityNodes(const NodePtr &node, + int subgraph_index) { + std::set in_node_names; + for (const auto &in_node : node->GetInControlNodes()) { + if (in_node->GetType() == IDENTITY) { + in_node_names.insert(in_node->GetName()); + } + } + for (const auto &subgraph_node : before_transdata_nodes_[subgraph_index]) { + for (const auto &in_node : subgraph_node->GetInControlNodes()) { + if (in_node->GetType() == IDENTITY) { + in_node_names.insert(in_node->GetName()); + } + } + } + GELOGD("control in nodes for %s(%d): %zu", node->GetName().c_str(), subgraph_index, in_node_names.size()); + return in_node_names; +} + void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector &same_transdata_nodes) { auto iter = all_transdata_nodes_.begin(); same_transdata_nodes.push_back(iter->first); + auto node_for_compare_in_anchor = iter->second; GE_CHECK_NOTNULL_JUST_RETURN(node_for_compare_in_anchor); auto node_for_compare = node_for_compare_in_anchor->GetOwnerNode(); + + // Get op-desc, input/output desc, in-control-edges-from-identity, as the compare-key auto op_desc_for_compare = node_for_compare->GetOpDesc(); GE_CHECK_NOTNULL_JUST_RETURN(op_desc_for_compare); string op_compare_stream_label; (void)AttrUtils::GetStr(op_desc_for_compare, ATTR_NAME_STREAM_LABEL, op_compare_stream_label); + auto op_compare_in_ctrl_nodes = GetInControlIdentityNodes(node_for_compare, iter->first); auto input_desc_for_compare = op_desc_for_compare->GetInputDescPtr(node_for_compare_in_anchor->GetIdx()); GE_CHECK_NOTNULL_JUST_RETURN(input_desc_for_compare); auto output_desc_for_compare = op_desc_for_compare->GetOutputDescPtr(0); GE_CHECK_NOTNULL_JUST_RETURN(output_desc_for_compare); + iter = all_transdata_nodes_.erase(iter); while (iter != all_transdata_nodes_.end()) { auto in_anchor = iter->second; @@ -149,12 +172,14 @@ void SameTransdataBreadthFusionPass::GetSameTransdataNode(vector &same_tran auto output_desc_tmp = op_desc_tmp->GetOutputDescPtr(0); string op_tmp_stream_label; (void)AttrUtils::GetStr(op_desc_tmp, ATTR_NAME_STREAM_LABEL, op_tmp_stream_label); + auto op_tmp_in_ctrl_nodes = GetInControlIdentityNodes(node_tmp, iter->first); GE_CHECK_NOTNULL_JUST_RETURN(input_desc_tmp); GE_CHECK_NOTNULL_JUST_RETURN(output_desc_tmp); if ((op_compare_stream_label == op_tmp_stream_label) && (input_desc_tmp->GetFormat() == input_desc_for_compare->GetFormat()) && - (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat())) { + (output_desc_tmp->GetFormat() == output_desc_for_compare->GetFormat()) && + (op_compare_in_ctrl_nodes == op_tmp_in_ctrl_nodes)) { GELOGD("same transdata node:%s, src node:%s", node_tmp->GetName().c_str(), node_for_compare->GetName().c_str()); InsertSameTransdataNodeIndex(iter->first, same_transdata_nodes); iter = all_transdata_nodes_.erase(iter); @@ -339,14 +364,13 @@ graphStatus SameTransdataBreadthFusionPass::ReLinkTransdataControlOutput2PreNode } graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(SameTransdataBreadthFusionPass); GELOGI("[SameTransdataBreadthFusionPass]: optimize begin."); if (graph == nullptr) { return GRAPH_SUCCESS; } for (auto &node : graph->GetDirectNode()) { - if (IsTransOp(node) || node->GetOutDataNodes().size() <= 1) { + if (IsTransOp(node) || node->GetOutDataNodesSize() <= 1) { continue; } @@ -374,7 +398,6 @@ graphStatus SameTransdataBreadthFusionPass::Run(ComputeGraphPtr graph) { } GELOGI("[SameTransdataBreadthFusionPass]: Optimize success."); - GE_TIMESTAMP_END(SameTransdataBreadthFusionPass, "GraphManager::SameTransdataBreadthFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h index f4b44a59..a6a3bb26 100644 --- a/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h +++ b/src/ge/graph/passes/same_transdata_breadth_fusion_pass.h @@ -42,7 +42,7 @@ class SameTransdataBreadthFusionPass : public GraphPass { void GetSubGraphNodesInfo(); void EraseInvalidAnchorsPair(); - + std::set GetInControlIdentityNodes(const NodePtr &node, int subgraph_index); OpDescPtr GetCastOp(const GeTensorDesc &in_desc, const GeTensorDesc &out_desc); graphStatus AddCastNode(const ComputeGraphPtr &graph, int anchors_index, OutDataAnchorPtr &pre_out_anchor, diff --git a/src/ge/graph/passes/set_input_output_offset_pass.cc b/src/ge/graph/passes/set_input_output_offset_pass.cc new file mode 100644 index 00000000..58c3be85 --- /dev/null +++ b/src/ge/graph/passes/set_input_output_offset_pass.cc @@ -0,0 +1,285 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/set_input_output_offset_pass.h" + +#include "runtime/mem.h" + +namespace ge { +Status SetInputOutputOffsetPass::Run(ComputeGraphPtr graph) { + GE_CHECK_NOTNULL(graph); + for (auto &node : graph->GetDirectNode()) { + GE_CHECK_NOTNULL(node->GetOpDesc()); + vector connect_input; + (void)AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_NODE_CONNECT_INPUT, connect_input); + if (!connect_input.empty()) { + Status ret = SetInputOffset(node, connect_input); + if (ret != SUCCESS) { + GELOGE(ret, "SetInputOffset failed."); + return ret; + } + } + vector connect_output; + (void)AttrUtils::GetListInt(node->GetOpDesc(), ATTR_NAME_NODE_CONNECT_OUTPUT, connect_output); + if (!connect_output.empty()) { + Status ret = SetOutputOffset(node, connect_output); + if (ret != SUCCESS) { + GELOGE(ret, "SetOutputOffset failed."); + return ret; + } + } + } + return SUCCESS; +} + +Status SetInputOutputOffsetPass::SetInputOffsetForFusion(const std::vector &memory_type, + const ge::NodePtr &node) { + GELOGI("Start to SetInputOffsetForFusion for %s", node->GetName().c_str()); + auto op_desc = node->GetOpDesc(); + for (size_t i = 0; i < memory_type.size(); ++i) { + if (memory_type.at(i) != RT_MEMORY_L1) { + std::vector input_offset_of_node; + input_offset_of_node = op_desc->GetInputOffset(); + if (input_offset_of_node.size() < i) { + GELOGE(PARAM_INVALID, "not get input_offset of %zu", i); + return PARAM_INVALID; + } + int64_t input_offset = input_offset_of_node.at(i); + GELOGI("input_offset of %s is %ld.", node->GetName().c_str(), input_offset); + auto in_anchor = node->GetInDataAnchor(i); + GE_IF_BOOL_EXEC(in_anchor == nullptr, continue); + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + int out_index = peer_out_anchor->GetIdx(); + auto data_op_desc = peer_out_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(data_op_desc); + int64_t out_offset = data_op_desc->GetOutputOffset().at(out_index); + GELOGI("output_offset of %s is %ld.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), out_offset); + vector zero_copy_basic_offset; + vector zero_copy_relative_offset; + + (void)ge::AttrUtils::GetListInt(data_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset); + (void)ge::AttrUtils::GetListInt(data_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset); + zero_copy_basic_offset.emplace_back(out_offset); + int64_t relative_offset = input_offset - out_offset; + zero_copy_relative_offset.emplace_back(relative_offset); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(data_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset), + GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed."); + return FAILED); + GE_CHK_BOOL_EXEC( + ge::AttrUtils::SetListInt(data_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset), + GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed."); + return FAILED); + } + } + return SUCCESS; +} + +Status SetInputOutputOffsetPass::SetInputOffsetForHcom(const ge::NodePtr &node, const vector &connect_input) { + GELOGI("Start SetInputOffsetForHcom for %s.", node->GetName().c_str()); + + auto op_desc = node->GetOpDesc(); + vector input_offset_of_node; + input_offset_of_node = node->GetOpDesc()->GetInputOffset(); + for (size_t input_index = 0; input_index < connect_input.size(); ++input_index) { + int connect_input_index = connect_input.at(input_index); + int64_t input_offset = input_offset_of_node.at(connect_input_index); + NodePtr in_data = node->GetInDataNodes().at(connect_input_index); + auto in_op_desc = in_data->GetOpDesc(); + GE_CHECK_NOTNULL(in_op_desc); + if (in_op_desc->GetType() == DATA) { + int64_t output_offset = in_op_desc->GetOutputOffset().at(0); + if (output_offset == input_offset) { + continue; + } else { + vector zero_copy_basic_offset; + vector zero_copy_relative_offset; + (void)ge::AttrUtils::GetListInt(in_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset); + (void)ge::AttrUtils::GetListInt(in_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset); + GELOGI("input offset from %s to %s is %ld to %ld.", in_op_desc->GetName().c_str(), op_desc->GetName().c_str(), + output_offset, input_offset); + int64_t relative_offset = input_offset - output_offset; + zero_copy_basic_offset.emplace_back(output_offset); + zero_copy_relative_offset.emplace_back(relative_offset); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(in_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset), + GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed."); + return FAILED); + GE_CHK_BOOL_EXEC( + ge::AttrUtils::SetListInt(in_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset), + GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed."); + return FAILED); + } + } + } + return SUCCESS; +} + +Status SetInputOutputOffsetPass::SetInputOffset(const NodePtr &node, const vector &connect_input) { + GELOGI("Start to SetInputOffset for %s.", node->GetName().c_str()); + std::vector memory_type; + auto op_desc = node->GetOpDesc(); + (void)ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_INPUT_MEM_TYPE_LIST, memory_type); + if (!memory_type.empty()) { + Status ret = SetInputOffsetForFusion(memory_type, node); + if (ret != SUCCESS) { + GELOGE(ret, "SetInputOffsetForFusion failed."); + return ret; + } + } + // Data->Hcom + bool is_input_continuous = false; + (void)ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); + if (is_input_continuous) { + Status ret = SetInputOffsetForHcom(node, connect_input); + if (ret != SUCCESS) { + GELOGE(ret, "SetInputOffsetForHcom failed."); + return ret; + } + } + return SUCCESS; +} + +Status SetInputOutputOffsetPass::SetOutputOffsetForConcat(const NodePtr &node) { + GELOGI("Start SetOutputOffsetForConcat for %s.", node->GetName().c_str()); + auto op_desc = node->GetOpDesc(); + std::vector output_offset_of_concat; + output_offset_of_concat = op_desc->GetOutputOffset(); + // phony_concat has one output + GE_IF_BOOL_EXEC(output_offset_of_concat.size() != 1, + GELOGE(PARAM_INVALID, "%s should has one output.", node->GetName().c_str()); + return PARAM_INVALID); + NodePtr net_output = node->GetOutDataNodes().at(0); + auto out_op_desc = net_output->GetOpDesc(); + GE_CHECK_NOTNULL(out_op_desc); + vector zero_copy_basic_offset; + vector zero_copy_relative_offset; + (void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset); + (void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset); + + int64_t basic_offset = output_offset_of_concat.at(0); + GELOGI("output_offset of %s is %ld.", op_desc->GetName().c_str(), basic_offset); + for (InDataAnchorPtr &in_anchor : node->GetAllInDataAnchors()) { + OutDataAnchorPtr peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); + NodePtr in_node = peer_out_anchor->GetOwnerNode(); + auto out_index = peer_out_anchor->GetIdx(); + std::vector output_offset_of_in_node; + GE_CHECK_NOTNULL(in_node->GetOpDesc()); + output_offset_of_in_node = in_node->GetOpDesc()->GetOutputOffset(); + GELOGI("input offset from %s to %s is %ld.", in_node->GetName().c_str(), op_desc->GetName().c_str(), + output_offset_of_in_node.at(out_index)); + int64_t relative_offset = output_offset_of_in_node.at(out_index) - basic_offset; + zero_copy_basic_offset.emplace_back(basic_offset); + zero_copy_relative_offset.emplace_back(relative_offset); + } + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset), + GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed."); + return FAILED); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset), + GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed."); + return FAILED); + return SUCCESS; +} + +Status SetInputOutputOffsetPass::SetOutputOffsetForHcom(const NodePtr &node, const vector &connect_output) { + GELOGI("Start SetOutputOffsetForHcom, %s connect with %zu output.", node->GetName().c_str(), connect_output.size()); + vector output_offset_of_node; + output_offset_of_node = node->GetOpDesc()->GetOutputOffset(); + int connect_output_index = connect_output.at(0); + int64_t basic_offset = output_offset_of_node.at(connect_output_index); + GELOGI("basic_offset of %s is %ld.", node->GetName().c_str(), basic_offset); + + NodePtr net_output = node->GetOutDataNodes().at(connect_output_index); + auto out_op_desc = net_output->GetOpDesc(); + GE_CHECK_NOTNULL(out_op_desc); + vector zero_copy_basic_offset; + vector zero_copy_relative_offset; + (void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset); + (void)ge::AttrUtils::GetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset); + + for (auto &out_anchor : node->GetAllOutDataAnchors()) { + GE_IF_BOOL_EXEC(out_anchor == nullptr, continue); + for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_IF_BOOL_EXEC(in_anchor == nullptr, continue); + if (in_anchor->GetOwnerNode()->GetType() == NETOUTPUT && out_anchor->GetIdx() != connect_output_index) { + continue; + } else { + NodePtr out_node = in_anchor->GetOwnerNode(); + auto in_index = in_anchor->GetIdx(); + std::vector input_offset_of_out_node; + GE_CHECK_NOTNULL(out_node->GetOpDesc()); + input_offset_of_out_node = out_node->GetOpDesc()->GetInputOffset(); + GELOGI("input offset from %s to %s is %ld.", node->GetName().c_str(), out_node->GetName().c_str(), + input_offset_of_out_node.at(in_index)); + int64_t relative_offset = input_offset_of_out_node.at(in_index) - basic_offset; + zero_copy_basic_offset.emplace_back(basic_offset); + zero_copy_relative_offset.emplace_back(relative_offset); + } + } + } + + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset), + GELOGE(FAILED, "SetListInt of zero_copy_basic_offset failed."); + return FAILED); + GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(out_op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset), + GELOGE(FAILED, "SetListInt of zero_copy_relative_offset failed."); + return FAILED); + return SUCCESS; +} + +Status SetInputOutputOffsetPass::SetOutputOffset(const NodePtr &node, const vector &connect_output) { + GELOGI("Start SetOutputOffset of %s.", node->GetName().c_str()); + bool attr_no_task = false; + bool get_attr_no_task = ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_NOTASK, attr_no_task); + if (get_attr_no_task && attr_no_task) { + bool is_input_continuous = false; + (void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT, is_input_continuous); + bool buffer_fusion = CheckBufferFusion(node); + // A/B/C -> Phony_concat -> Netoutput : input_continuous + if (is_input_continuous || buffer_fusion) { + Status ret = SetOutputOffsetForConcat(node); + if (ret != SUCCESS) { + GELOGE(ret, "SetOutputOffsetForConcat failed."); + return ret; + } + } + } + // allreduce->netoutput : output_continuous + bool is_output_continuous = false; + (void)ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_OUTPUT, is_output_continuous); + if (is_output_continuous) { + Status ret = SetOutputOffsetForHcom(node, connect_output); + if (ret != SUCCESS) { + GELOGE(ret, "SetOutputOffsetForHcom failed."); + return ret; + } + } + return SUCCESS; +} + +bool SetInputOutputOffsetPass::CheckBufferFusion(const NodePtr &node) { + for (auto &in_node : node->GetInDataNodes()) { + GE_CHECK_NOTNULL(in_node); + auto op_desc = in_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (!op_desc->HasAttr(ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION)) { + GELOGI("The node: %s not have ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION.", node->GetName().c_str()); + return false; + } + } + return true; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/graph/passes/set_input_output_offset_pass.h b/src/ge/graph/passes/set_input_output_offset_pass.h new file mode 100644 index 00000000..24f9f6c4 --- /dev/null +++ b/src/ge/graph/passes/set_input_output_offset_pass.h @@ -0,0 +1,38 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_SET_INPUT_OUTPUT_OFFSET_PASS_H_ +#define GE_GRAPH_PASSES_SET_INPUT_OUTPUT_OFFSET_PASS_H_ + +#include "inc/graph_pass.h" + +namespace ge { +class SetInputOutputOffsetPass : public GraphPass { + public: + Status Run(ComputeGraphPtr graph) override; + + private: + Status SetInputOffset(const NodePtr &node, const vector &connect_input); + Status SetOutputOffset(const NodePtr &node, const vector &connect_output); + Status SetInputOffsetForFusion(const std::vector &memory_type, const ge::NodePtr &node); + Status SetInputOffsetForHcom(const NodePtr &node, const vector &connect_input); + Status SetOutputOffsetForConcat(const NodePtr &node); + Status SetOutputOffsetForHcom(const NodePtr &node, const vector &connect_output); + + bool CheckBufferFusion(const NodePtr &node); +}; +} // namespace ge +#endif // GE_GRAPH_PASSES_SET_INPUT_OUTPUT_OFFSET_PASS_H_ diff --git a/src/ge/graph/passes/subgraph_pass.cc b/src/ge/graph/passes/subgraph_pass.cc index d759aa12..80ce995a 100644 --- a/src/ge/graph/passes/subgraph_pass.cc +++ b/src/ge/graph/passes/subgraph_pass.cc @@ -15,7 +15,6 @@ */ #include "graph/passes/subgraph_pass.h" -#include #include "graph/utils/node_utils.h" #include "graph/utils/op_desc_utils.h" #include "graph/utils/tensor_utils.h" @@ -67,13 +66,13 @@ Status SubgraphPass::Run(ComputeGraphPtr graph) { /** * @ingroup ge - * @brief Check Subgraph NetOutput node + * @brief Check Subgraph Input node * @param [in] graph: ComputeGraph. - * @param [in] node: NetOutput node in Subgraph. + * @param [in] node: Data node in Subgraph. * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodePtr &node) { - GELOGD("Hadle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); + GELOGD("Handle input_node %s for graph %s.", node->GetName().c_str(), graph->GetName().c_str()); // Data has and only has one output bool input_continues_required_flag = false; OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); @@ -86,7 +85,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP // Data->InputContinuesRequiredOp in subgraph need memcpy. if (input_continues_required_flag) { GELOGD("Data %s output_node required continues input.", node->GetName().c_str()); - std::string name = node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + std::string name = node->GetName() + "_output_0_Memcpy"; if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy after %s failed.", node->GetName().c_str()); return FAILED; @@ -123,7 +122,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP GE_CHECK_NOTNULL(peer_out_anchor); GELOGD("Constant input %s links to While %s.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); - std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + std::string name = parent_node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(parent_graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), parent_node->GetName().c_str()); @@ -136,7 +135,7 @@ Status SubgraphPass::SubgraphInputNode(const ComputeGraphPtr &graph, const NodeP /** * @ingroup ge - * @brief Check Subgraph NetOutput node + * @brief Check Subgraph Output node * @param [in] graph: ComputeGraph. * @param [in] node: NetOutput node in Subgraph. * @return: 0 for SUCCESS / others for FAILED @@ -153,14 +152,14 @@ Status SubgraphPass::SubgraphOutputNode(const ComputeGraphPtr &graph, const Node // 1. Const->NetOutput in subgraph // 2. AtomicOp->NetOutput in subgraph // 3. OutputContinuesRequiredOp->NetOutput in subgraph - // 4. Data->NetOutput in subgraph but not while body + // 4. Data->NetOutput in subgraph but parent_node is not while std::string op_type; bool insert_flag = NodeUtils::GetConstOpType(in_node, op_type) || IsAtomicRequired(in_node, peer_out_anchor->GetIdx()) || IsOutputContinuesRequired(in_node) || - ((in_node->GetType() == DATA) && !IsWhileBodyOutput(in_data_anchor)); + ((in_node->GetType() == DATA) && (kWhileOpTypes.count(graph->GetParentNode()->GetType()) == 0)); if (insert_flag) { - GELOGI("Insert MemcpyAsync node between %s and %s.", node->GetName().c_str(), in_node->GetName().c_str()); - std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Insert MemcpyAsync node between %s and %s.", in_node->GetName().c_str(), node->GetName().c_str()); + std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; @@ -186,8 +185,8 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr GE_CHECK_NOTNULL(in_node); // Input->While and Input link to other nodes need insert memcpy if (peer_out_anchor->GetPeerInDataAnchors().size() > 1) { - GELOGI("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); - std::string name = in_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); + GELOGD("Input %s of While %s links to other nodes.", in_node->GetName().c_str(), node->GetName().c_str()); + std::string name = node->GetName() + "_input_" + std::to_string(in_data_anchor->GetIdx()) + "_Memcpy"; if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { GELOGE(FAILED, "Insert memcpy between %s and %s failed.", in_node->GetName().c_str(), node->GetName().c_str()); return FAILED; @@ -206,231 +205,121 @@ Status SubgraphPass::WhileInputNodes(const ComputeGraphPtr &graph, const NodePtr * @return: 0 for SUCCESS / others for FAILED */ Status SubgraphPass::WhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { - ComputeGraphPtr while_body = GetWhileBodySubgraph(graph, node); + // index of body_subgraph is 1 + ComputeGraphPtr while_body = NodeUtils::GetSubgraph(*node, 1); if (while_body == nullptr) { GELOGE(FAILED, "while_body of %s is NULL.", node->GetName().c_str()); return FAILED; } - NodePtr output_node = while_body->FindFirstNodeMatchType(NETOUTPUT); - if (output_node == nullptr) { - GELOGE(FAILED, "net_output_node not exist in graph %s.", while_body->GetName().c_str()); - return FAILED; - } - OpDescPtr output_desc = output_node->GetOpDesc(); - GE_CHECK_NOTNULL(output_desc); - std::unordered_map> node_to_attr_index; - for (const InDataAnchorPtr &in_data_anchor : output_node->GetAllInDataAnchors()) { - uint32_t index = 0; - if (!AttrUtils::GetInt(output_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, index)) { - GELOGE(FAILED, "Get attr PARENT_NODE_INDEX failed, node %s:%u.", output_node->GetName().c_str(), - in_data_anchor->GetIdx()); - return FAILED; + std::vector data_nodes; + std::set bypass_index; + NodePtr output_node = nullptr; + for (const auto &n : while_body->GetDirectNode()) { + const std::string &type = n->GetType(); + if (type == DATA) { + if (CheckInsertInputMemcpy(n, bypass_index)) { + data_nodes.emplace_back(n); + } + } else if (type == NETOUTPUT) { + if (output_node == nullptr) { + output_node = n; + } else { + GELOGE(FAILED, "while_body %s exists multi NetOutput nodes.", while_body->GetName().c_str()); + return FAILED; + } } - MarkOutputIndex(in_data_anchor->GetPeerOutAnchor(), index, node_to_attr_index); } - - std::set data_nodes; - std::set netoutput_input_indexes; - GetExchangeInOut(node_to_attr_index, data_nodes, netoutput_input_indexes); - return InsertMemcpyInWhileBody(while_body, data_nodes, output_node, netoutput_input_indexes); -} - -/** - * @ingroup ge - * @brief Get body subgraph of While op - * @param [in] graph: ComputeGraph. - * @param [in] node: While node. - * @return: body subgraph - */ -ComputeGraphPtr SubgraphPass::GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node) { - OpDescPtr op_desc = node->GetOpDesc(); - if (op_desc == nullptr) { - GELOGE(FAILED, "op_desc is NULL."); - return nullptr; - } - - const std::vector &subgraph_instance_names = op_desc->GetSubgraphInstanceNames(); - std::string body_instance_name; - for (const std::string &instance_name : subgraph_instance_names) { - std::string subgraph_name; - if (op_desc->GetSubgraphNameByInstanceName(instance_name, subgraph_name) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Get subgraph_name by instance_name %s failed, node:%s.", instance_name.c_str(), - node->GetName().c_str()); - return nullptr; - } - if (subgraph_name == ATTR_NAME_WHILE_BODY) { - body_instance_name = instance_name; - break; - } + if (output_node == nullptr) { + GELOGE(FAILED, "while_body %s has no output.", while_body->GetName().c_str()); + return FAILED; } - ComputeGraphPtr root_graph = GraphUtils::FindRootGraph(graph); - if (root_graph == nullptr) { - GELOGE(FAILED, "root_graph is NULL."); - return nullptr; + if ((InsertInputMemcpy(while_body, data_nodes) != SUCCESS) || + (InsertOutputMemcpy(while_body, output_node, bypass_index) != SUCCESS)) { + GELOGE(FAILED, "Insert memcpy node in while_body %s failed.", while_body->GetName().c_str()); + return FAILED; } - return root_graph->GetSubgraph(body_instance_name); + return SUCCESS; } /** * @ingroup ge - * @brief Mark output parent_node_index - * @param [in] peer_out_anchor: peer_out_anchor of NetOutput - * @param [in] index: parent_node_index of NetOutput - * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @return: void + * @brief Insert input memcpy node in while_body + * @param [in] graph: while_body + * @param [in] data_nodes: data_nodes + * @return: 0 for SUCCESS / others for FAILED */ -void SubgraphPass::MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, - std::unordered_map> &node_to_attr_index) { - if (peer_out_anchor == nullptr) { - return; - } - std::set visited_nodes; - std::stack nodes; - nodes.emplace(peer_out_anchor->GetOwnerNode()); - while (!nodes.empty()) { - NodePtr cur_node = nodes.top(); - nodes.pop(); - if (visited_nodes.count(cur_node) > 0) { - continue; - } - node_to_attr_index[cur_node].emplace_back(index); - for (const NodePtr &in_node : cur_node->GetInDataNodes()) { - nodes.emplace(in_node); - } - visited_nodes.emplace(cur_node); +Status SubgraphPass::InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes) { + if (data_nodes.empty()) { + GELOGD("No need to insert input memcpy node in while_body %s.", graph->GetName().c_str()); + return SUCCESS; } -} - -/** - * @ingroup ge - * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy - * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @param [out] data_nodes: data_nodes need insert memcpy - * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy - * @return: void - */ -void SubgraphPass::GetExchangeInOut(const std::unordered_map> &node_to_attr_index, - std::set &data_nodes, std::set &netoutput_input_indexes) { - for (const auto &item : node_to_attr_index) { - NodePtr node = item.first; - uint32_t input_index = 0; - if ((node->GetType() != DATA) || !AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { - continue; - } - if (item.second.empty() || ((item.second.size() == 1) && (item.second[0] == input_index))) { - continue; - } - data_nodes.emplace(node); + std::string in_name = graph->GetName() + "_input_Memcpy"; + OpDescBuilder in_builder(in_name, MEMCPYASYNC); + for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output - OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); - if (out_data_anchor == nullptr) { - continue; - } - for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - NodePtr out_node = peer_in_anchor->GetOwnerNode(); - if ((out_node->GetType() != NETOUTPUT) || (out_node->GetOpDesc() == nullptr)) { - continue; - } - uint32_t output_index = 0; - GeTensorDesc input_tensor = out_node->GetOpDesc()->GetInputDesc(peer_in_anchor->GetIdx()); - if (!AttrUtils::GetInt(input_tensor, ATTR_NAME_PARENT_NODE_INDEX, output_index)) { - continue; - } - if (input_index != output_index) { - netoutput_input_indexes.emplace(peer_in_anchor->GetIdx()); - } - } + in_builder.AddInput("x" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)) + .AddOutput("y" + std::to_string(i), data_nodes[i]->GetOpDesc()->GetOutputDesc(0)); } -} - -/** - * @ingroup ge - * @brief Insert memcpy node in while_body - * @param [in] graph: while_body - * @param [in] data_nodes: data_nodes need insert memcpy - * @param [in] output_node: NetOutput in while_body - * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy - * @return: 0 for SUCCESS / others for FAILED - */ -Status SubgraphPass::InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, - const NodePtr &output_node, - const std::set &netoutput_input_indexes) { - for (const NodePtr &data_node : data_nodes) { + GELOGD("Insert memcpy after data_nodes of while_body %s.", graph->GetName().c_str()); + NodePtr in_memcpy = graph->AddNode(in_builder.Build()); + GE_CHECK_NOTNULL(in_memcpy); + for (size_t i = 0; i < data_nodes.size(); i++) { // Data node has and only has one output - OutDataAnchorPtr out_data_anchor = data_node->GetOutDataAnchor(0); + OutDataAnchorPtr out_data_anchor = data_nodes[i]->GetOutDataAnchor(0); std::vector in_anchors; for (const InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { in_anchors.emplace_back(peer_in_anchor); } - std::string name = data_node->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); - GELOGD("Insert memcpy after while_body %s input_node %s.", graph->GetName().c_str(), data_node->GetName().c_str()); - if (InsertMemcpyNode(graph, out_data_anchor, in_anchors, name) != SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), data_node->GetName().c_str()); - return FAILED; - } - } - - for (uint32_t index : netoutput_input_indexes) { - InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(index); - GE_CHECK_NOTNULL(in_data_anchor); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_anchor); - std::string name = - peer_out_anchor->GetOwnerNode()->GetName() + "_" + MEMCPYASYNC + "_" + std::to_string(memcpy_num_++); - GELOGD("Insert memcpy after while_body %s output %u.", graph->GetName().c_str(), index); - if (InsertMemcpyNode(graph, peer_out_anchor, {in_data_anchor}, name) != SUCCESS) { - GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), - peer_out_anchor->GetOwnerNode()->GetName().c_str()); + if (InsertNodeBetween(out_data_anchor, in_anchors, in_memcpy, i, i) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", in_name.c_str(), graph->GetName().c_str()); return FAILED; } } - std::set memcpy_nodes; - std::set loop_body_nodes; - for (const NodePtr &data_node : data_nodes) { - // data_node has only one output node - NodePtr memcpy_node = data_node->GetOutDataNodes().at(0); - GE_CHECK_NOTNULL(memcpy_node); - memcpy_nodes.emplace(memcpy_node); - for (const NodePtr &out_node : memcpy_node->GetOutDataNodes()) { - loop_body_nodes.insert(out_node); - } - } - return InsertNoOp(graph, memcpy_nodes, loop_body_nodes); + return SUCCESS; } /** * @ingroup ge - * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes + * @brief Insert output memcpy node in while_body * @param [in] graph: while_body - * @param [in] memcpy_nodes - * @param [in] loop_body_nodes + * @param [in] output_node: NetOutput + * @param [in] bypass_index * @return: 0 for SUCCESS / others for FAILED */ -Status SubgraphPass::InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, - const std::set &loop_body_nodes) { - if (memcpy_nodes.empty() || loop_body_nodes.empty()) { +Status SubgraphPass::InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, + const std::set &bypass_index) { + if (output_node->GetAllInDataAnchorsSize() == bypass_index.size()) { + GELOGD("No need to insert output memcpy node in while_body %s, output_size=%zu, bypass_num=%zu.", + graph->GetName().c_str(), output_node->GetAllInDataAnchorsSize(), bypass_index.size()); return SUCCESS; } - OpDescBuilder noop_desc_builder("NoOp_for_Control", NOOP); - OpDescPtr noop_desc = noop_desc_builder.Build(); - NodePtr noop_node = graph->AddNode(noop_desc); - GE_CHECK_NOTNULL(noop_node); - for (const NodePtr &memcpy_node : memcpy_nodes) { - if (GraphUtils::AddEdge(memcpy_node->GetOutControlAnchor(), noop_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add ctrl edge %s->%s failed.", memcpy_node->GetName().c_str(), noop_node->GetName().c_str()); - return FAILED; + std::string out_name = graph->GetName() + "_output_Memcpy"; + OpDescBuilder out_builder(out_name, MEMCPYASYNC); + for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { + if (bypass_index.count(i) == 0) { + out_builder.AddInput("x" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)) + .AddOutput("y" + std::to_string(i), output_node->GetOpDesc()->GetInputDesc(i)); } } - for (const NodePtr &loop_body_node : loop_body_nodes) { - if (GraphUtils::AddEdge(noop_node->GetOutControlAnchor(), loop_body_node->GetInControlAnchor()) != GRAPH_SUCCESS) { - GELOGE(FAILED, "Add ctrl edge %s->%s failed.", noop_node->GetName().c_str(), loop_body_node->GetName().c_str()); - return FAILED; + GELOGD("Insert memcpy before NetOutput of while_body %s.", graph->GetName().c_str()); + NodePtr out_memcpy = graph->AddNode(out_builder.Build()); + GE_CHECK_NOTNULL(out_memcpy); + size_t cnt = 0; + for (size_t i = 0; i < output_node->GetAllInDataAnchorsSize(); i++) { + if (bypass_index.count(i) == 0) { + InDataAnchorPtr in_data_anchor = output_node->GetInDataAnchor(i); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + if (InsertNodeBetween(peer_out_anchor, {in_data_anchor}, out_memcpy, cnt, cnt) != SUCCESS) { + GELOGE(FAILED, "Insert MemcpyAsync %s in while_body %s failed.", out_name.c_str(), graph->GetName().c_str()); + return FAILED; + } + cnt++; } } @@ -439,28 +328,39 @@ Status SubgraphPass::InsertNoOp(const ComputeGraphPtr &graph, const std::setnetoutput in while body - * @param [in] in_data_anchor - * @return: true for data->netoutput in while body / for false for others + * @brief Check is data->netoutput without change in while body + * @param [in] node: data node + * @param [out] bypass_index + * @return: false for data->netoutput without change in while body / for true for others */ -bool SubgraphPass::IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor) { - // Check is subgraph - NodePtr parent_node = in_data_anchor->GetOwnerNode()->GetOwnerComputeGraph()->GetParentNode(); - if (parent_node == nullptr) { - return false; +bool SubgraphPass::CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index) { + uint32_t input_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, input_index)) { + return true; } - // Check if parent_node is While - if (kWhileOpTypes.count(parent_node->GetType()) == 0) { - return false; + // Data node has and only has one output + OutDataAnchorPtr out_data_anchor = node->GetOutDataAnchor(0); + if ((out_data_anchor == nullptr) || (out_data_anchor->GetPeerInDataAnchors().size() != 1)) { + return true; + } + InDataAnchorPtr peer_in_anchor = out_data_anchor->GetPeerInDataAnchors().at(0); + if (peer_in_anchor->GetOwnerNode()->GetType() != NETOUTPUT) { + return true; } - // While cond / body - OpDescPtr op_desc = in_data_anchor->GetOwnerNode()->GetOpDesc(); - if (op_desc == nullptr) { - return false; + OpDescPtr op_desc = peer_in_anchor->GetOwnerNode()->GetOpDesc(); + uint32_t output_index = 0; + if ((op_desc == nullptr) || + !AttrUtils::GetInt(op_desc->GetInputDesc(peer_in_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX, output_index)) { + return true; } - return AttrUtils::HasAttr(op_desc->GetInputDesc(in_data_anchor->GetIdx()), ATTR_NAME_PARENT_NODE_INDEX); + + if (input_index != output_index) { + return true; + } + bypass_index.insert(peer_in_anchor->GetIdx()); + return false; } /** @@ -542,7 +442,7 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) .Build(); - if (GraphUtils::InsertNodeBefore(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { + if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { GELOGE(FAILED, "Insert MemcpyAsync node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); return FAILED; } @@ -550,4 +450,33 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat return SUCCESS; } +/// +/// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst +/// @param [in] src +/// @param [in] dsts +/// @param [in] insert_node +/// @param [in] input_index +/// @param [in] output_index +/// @return Status +/// +Status SubgraphPass::InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index) { + if (GraphUtils::AddEdge(src, insert_node->GetInDataAnchor(input_index)) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add data_edge %s:%d->%s:%u failed.", src->GetOwnerNode()->GetName().c_str(), src->GetIdx(), + insert_node->GetName().c_str(), input_index); + return FAILED; + } + for (const auto &dst : dsts) { + GELOGD("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src->GetOwnerNode()->GetName().c_str(), + dst->GetOwnerNode()->GetName().c_str()); + if ((GraphUtils::RemoveEdge(src, dst) != GRAPH_SUCCESS) || + (GraphUtils::AddEdge(insert_node->GetOutDataAnchor(output_index), dst) != GRAPH_SUCCESS)) { + GELOGE(FAILED, "Replace data_edge %s:%d->%s:%d by %s:%u->%s:%d failed.", src->GetOwnerNode()->GetName().c_str(), + src->GetIdx(), dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx(), insert_node->GetName().c_str(), + output_index, dst->GetOwnerNode()->GetName().c_str(), dst->GetIdx()); + return FAILED; + } + } + return SUCCESS; +} } // namespace ge diff --git a/src/ge/graph/passes/subgraph_pass.h b/src/ge/graph/passes/subgraph_pass.h index 2308b1bd..7ff2019f 100644 --- a/src/ge/graph/passes/subgraph_pass.h +++ b/src/ge/graph/passes/subgraph_pass.h @@ -17,12 +17,6 @@ #ifndef GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ #define GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ -#include -#include -#include -#include - -#include "graph/types.h" #include "inc/graph_pass.h" namespace ge { @@ -75,65 +69,32 @@ class SubgraphPass : public GraphPass { /** * @ingroup ge - * @brief Get body subgraph of While op - * @param [in] graph: ComputeGraph. - * @param [in] node: While node. - * @return: body subgraph - */ - ComputeGraphPtr GetWhileBodySubgraph(const ComputeGraphPtr &graph, const NodePtr &node); - - /** - * @ingroup ge - * @brief Mark output parent_node_index - * @param [in] peer_out_anchor: peer_out_anchor of NetOutput - * @param [in] index: parent_node_index of NetOutput - * @param [out] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @return: void - */ - void MarkOutputIndex(const OutDataAnchorPtr &peer_out_anchor, uint32_t index, - std::unordered_map> &node_to_attr_index); - - /** - * @ingroup ge - * @brief Get data_nodes / input_indexes of netoutput if need insert memcpy - * @param [in] node_to_attr_index: key for node in subgraph, value for parent_node_index - * @param [out] data_nodes: data_nodes need insert memcpy - * @param [out] netoutput_input_indexes: input_indexes of netoutput need insert memcpy - * @return: void - */ - void GetExchangeInOut(const std::unordered_map> &node_to_attr_index, - std::set &data_nodes, std::set &netoutput_input_indexes); - - /** - * @ingroup ge - * @brief Insert memcpy node in while_body + * @brief Insert input memcpy node in while_body * @param [in] graph: while_body - * @param [in] data_nodes: data_nodes need insert memcpy - * @param [in] output_node: NetOutput in while_body - * @param [in] netoutput_input_indexes: input_indexes of netoutput need insert memcpy + * @param [in] data_nodes: data_nodes * @return: 0 for SUCCESS / others for FAILED */ - Status InsertMemcpyInWhileBody(const ComputeGraphPtr &graph, const std::set &data_nodes, - const NodePtr &output_node, const std::set &netoutput_input_indexes); + Status InsertInputMemcpy(const ComputeGraphPtr &graph, const std::vector &data_nodes); /** * @ingroup ge - * @brief Insert NoOp node between memcpy_nodes and loop_body_nodes + * @brief Insert output memcpy node in while_body * @param [in] graph: while_body - * @param [in] memcpy_nodes - * @param [in] loop_body_nodes + * @param [in] output_node: NetOutput + * @param [in] bypass_index * @return: 0 for SUCCESS / others for FAILED */ - Status InsertNoOp(const ComputeGraphPtr &graph, const std::set &memcpy_nodes, - const std::set &loop_body_nodes); + Status InsertOutputMemcpy(const ComputeGraphPtr &graph, const NodePtr &output_node, + const std::set &bypass_index); /** * @ingroup ge - * @brief Check is Data->NetOutput in while body - * @param [in] in_data_anchor - * @return: true for Data->NetOutput in while body / false for others + * @brief Check is data->netoutput without change in while body + * @param [in] node: data node + * @param [out] bypass_index + * @return: false for data->netoutput without change in while body / for true for others */ - bool IsWhileBodyOutput(const InDataAnchorPtr &in_data_anchor); + bool CheckInsertInputMemcpy(const NodePtr &node, std::set &bypass_index); /** * @ingroup ge @@ -172,8 +133,17 @@ class SubgraphPass : public GraphPass { Status InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, const std::vector &in_anchors, const std::string &name); - // Append index for new memcpy node. - uint32_t memcpy_num_{0}; + /// + /// @brief Insert node: src->insert_node:input_index, insert_node:output_index->dst + /// @param [in] src + /// @param [in] dsts + /// @param [in] insert_node + /// @param [in] input_index + /// @param [in] output_index + /// @return Status + /// + Status InsertNodeBetween(const OutDataAnchorPtr &src, const std::vector &dsts, + const NodePtr &insert_node, uint32_t input_index, uint32_t output_index); }; } // namespace ge #endif // GE_GRAPH_PASSES_SUBGRAPH_PASS_H_ diff --git a/src/ge/graph/passes/switch_dead_branch_elimination.cc b/src/ge/graph/passes/switch_dead_branch_elimination.cc index f398d8df..dd7ace60 100644 --- a/src/ge/graph/passes/switch_dead_branch_elimination.cc +++ b/src/ge/graph/passes/switch_dead_branch_elimination.cc @@ -171,7 +171,7 @@ Status SwitchDeadBranchElimination::Run(NodePtr &node) { AddRePassNode(end_node); } for (const auto &delete_node : del_nodes) { - AddNodeDeleted(delete_node.get()); + AddNodeDeleted(delete_node); } } diff --git a/src/ge/graph/passes/switch_logic_remove_pass.cc b/src/ge/graph/passes/switch_logic_remove_pass.cc index be84a582..dafa3ae1 100644 --- a/src/ge/graph/passes/switch_logic_remove_pass.cc +++ b/src/ge/graph/passes/switch_logic_remove_pass.cc @@ -145,7 +145,7 @@ Status SwitchLogicRemovePass::RemoveSwitchNodeLogically(int parent_index, NodePt GE_CHECK_NOTNULL(node); GELOGD("Remove node %s from inactivate branch from switch %s", node->GetName().c_str(), switch_node->GetName().c_str()); - AddNodeDeleted(node.get()); + AddNodeDeleted(node); } for (auto &node : end_nodes) { GE_CHECK_NOTNULL(node); diff --git a/src/ge/graph/passes/switch_op_pass.cc b/src/ge/graph/passes/switch_op_pass.cc deleted file mode 100644 index ed3e9b36..00000000 --- a/src/ge/graph/passes/switch_op_pass.cc +++ /dev/null @@ -1,1227 +0,0 @@ -/** - * Copyright 2019-2020 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include "graph/passes/switch_op_pass.h" -#include -#include -#include -#include -#include -#include -#include -#include "common/ge/ge_util.h" -#include "framework/common/debug/ge_log.h" -#include "framework/common/debug/log.h" -#include "framework/common/ge_inner_error_codes.h" -#include "framework/common/types.h" -#include "ge/ge_api_types.h" -#include "graph/common/omg_util.h" -#include "graph/debug/ge_attr_define.h" -#include "graph/ge_context.h" -#include "graph/utils/type_utils.h" - -namespace ge { -Status SwitchOpPass::Run(ComputeGraphPtr graph) { - GELOGD("SwitchOpPass Enter"); - GE_CHK_STATUS_RET(CheckCycleDependence(graph), "CheckCycleDependence fail."); - - for (auto &switch_node : switch_nodes_) { - GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Add StreamSwitch node fail."); - } - - for (auto &merge_node : merge_nodes_) { - OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - GE_CHECK_NOTNULL(merge_op_desc); - if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, merge_node, true), "Merge add memcpy node fail."); - GE_CHK_STATUS_RET(SetStreamLabel(merge_node, merge_node->GetName()), "Set stream label failed"); - } else { - GE_CHK_STATUS_RET(ReplaceMergeNode(graph, merge_node), "Add StreamMerge node fail."); - } - } - - GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes fail."); - - for (auto &node : bypass_nodes_) { - GE_CHK_BOOL_EXEC(graph->RemoveNode(node) == GRAPH_SUCCESS, return FAILED, "Remove switch node fail."); - } - - for (auto &node : stream_switch_nodes_) { - for (auto &out_ctrl_node : node->GetOutControlNodes()) { - MarkHeadNodes(out_ctrl_node, node); - } - } - - for (auto &node : need_label_nodes_) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (!op_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { - GE_CHK_STATUS_RET(UpdateCondBranch(node), "Set cond branch fail, start node:%s", node->GetName().c_str()); - } - } - - GE_CHK_STATUS_RET(UpdateEnterNode(), "UpdateEnterNode fail."); - - GELOGD("SwitchOpPass Leave"); - return SUCCESS; -} - -/// -/// @brief Replace Switch Op -/// @param [in] graph -/// @param [in] switch_node -/// @return Status -/// -Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node) { - std::string type; - GE_CHK_STATUS_RET(GetOriginalType(switch_node, type), "Get node type fail."); - GE_CHK_BOOL_EXEC((type == SWITCH) || (type == REFSWITCH), return FAILED, "Type of input node is not switch."); - - OutDataAnchorPtr peer_data_anchor = nullptr; - OutDataAnchorPtr peer_cond_anchor = nullptr; - GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED, - "Bypass switch node %s fail.", switch_node->GetName().c_str()); - GE_CHECK_NOTNULL(peer_data_anchor); - GE_CHECK_NOTNULL(peer_cond_anchor); - OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHECK_NOTNULL(cond_desc); - DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType(); - GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL, return FAILED, - "SwitchNode not support datatype %s, datatype of cond_input should be bool", - TypeUtils::DataTypeToSerialString(cond_data_type).c_str()); - - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG); - - std::set out_node_list; - for (OutDataAnchorPtr &out_data_anchor : switch_node->GetAllOutDataAnchors()) { - bool true_branch_flag = (static_cast(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT); - NodePtr stream_switch = nullptr; - out_node_list.clear(); - for (auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) { - GE_IF_BOOL_EXEC(stream_switch == nullptr, { - std::string suffix = (true_branch_flag ? "_t" : "_f"); - stream_switch = CreateStreamSwitchNode(graph, switch_node, suffix, peer_cond_anchor); - GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node fail."); - if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) { - GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s fail.", stream_switch->GetName().c_str()); - return FAILED; - } - if (MarkBranchs(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) { - GELOGE(FAILED, "MarkBranchs for stream_switch %s fail.", stream_switch->GetName().c_str()); - return FAILED; - } - - if (!cyclic_flag) { - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(), - stream_switch->GetInControlAnchor()), - "StreamSwitch node add ctl edge fail."); - } - }); - - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output fail."); - - NodePtr out_node = peer_in_anchor->GetOwnerNode(); - GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); - if ((type == MERGE) || (type == REFMERGE)) { - NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false); - GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), - "MemcpyAsync node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(memcpy_node->GetOutDataAnchor(0), peer_in_anchor), - "MemcpyAsync node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), memcpy_node->GetInControlAnchor()), - "MemcpyAsync node add ctl edge fail."); - out_node_list.insert(memcpy_node->GetName()); - } else { - GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()), - "StreamSwitch node add ctl edge fail."); - out_node_list.insert(out_node->GetName()); - } - } - GE_IF_BOOL_EXEC(stream_switch != nullptr, { - CopyControlEdges(switch_node, stream_switch, true); - switch_node_map_[stream_switch] = out_node_list; - if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) { - GELOGE(FAILED, "SetOriginalNodeName for node %s fail.", stream_switch->GetName().c_str()); - return FAILED; - } - }); - } - - RemoveControlEdges(switch_node); - (void)bypass_nodes_.insert(switch_node); - - return SUCCESS; -} - -/// -/// @brief Replace Merge Op -/// @param [in] graph -/// @param [in] merge_node -/// @return Status -/// -Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_node) { - std::string type; - GE_CHK_STATUS_RET(GetOriginalType(merge_node, type), "Get node type fail."); - GE_CHK_BOOL_EXEC((type == MERGE) || (type == REFMERGE), return FAILED, "Type of input node is not merge."); - - OpDescPtr merge_op_desc = merge_node->GetOpDesc(); - GE_CHECK_NOTNULL(merge_op_desc); - - const std::string node_name = merge_node->GetName(); - GELOGI("Create StreamMerge Op, name=%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMMERGE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamMerge:%s.", node_name.c_str()); - return FAILED; - } - - for (InDataAnchorPtr &in_anchor : merge_node->GetAllInDataAnchors()) { - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(merge_op_desc->GetInputDesc(in_anchor->GetIdx())) == GRAPH_SUCCESS, - return FAILED, "Create StreamMerge op: add input desc fail."); - } - - for (OutDataAnchorPtr &out_anchor : merge_node->GetAllOutDataAnchors()) { - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(merge_op_desc->GetOutputDesc(out_anchor->GetIdx())) == GRAPH_SUCCESS, - return FAILED, "Create StreamMerge op: add output desc fail."); - } - - NodePtr stream_merge = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(stream_merge != nullptr, return FAILED, "Insert StreamMerge node fail."); - - for (InDataAnchorPtr &in_data_anchor : merge_node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "Remove Merge data input fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, stream_merge->GetInDataAnchor(in_data_anchor->GetIdx())), - "StreamMerge node add edge fail."); - } - - for (OutDataAnchorPtr &out_data_anchor : merge_node->GetAllOutDataAnchors()) { - for (InDataAnchorPtr &peer_in_anchor : out_data_anchor->GetPeerInDataAnchors()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Merge data output fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(stream_merge->GetOutDataAnchor(out_data_anchor->GetIdx()), peer_in_anchor), - "StreamMerge node add edge fail."); - } - } - - ReplaceControlEdges(merge_node, stream_merge); - - if (merge_op_desc->HasAttr(ATTR_NAME_NEXT_ITERATION)) { - std::string next_iteration_name; - GE_IF_BOOL_EXEC(!AttrUtils::GetStr(merge_op_desc, ATTR_NAME_NEXT_ITERATION, next_iteration_name), - GELOGE(INTERNAL_ERROR, "get ATTR_NAME_NEXT_ITERATION failed"); - return INTERNAL_ERROR); - - GE_CHK_STATUS_RET(SetNextIteration(stream_merge, next_iteration_name), "set next iteration failed"); - } else { - need_label_nodes_.emplace_back(stream_merge); - } - - (void)bypass_nodes_.insert(merge_node); - - GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge, false), "StreamMerge add memcpy node fail."); - - return SUCCESS; -} - -/// -/// @brief Create StreamSwitch Node -/// @param [in] graph -/// @param [in] switch_node -/// @param [in] suffix -/// @param [in] peer_cond_anchor -/// @return ge::NodePtr -/// -NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, - const std::string &suffix, OutDataAnchorPtr &peer_cond_anchor) { - GE_CHK_BOOL_EXEC(switch_node != nullptr, return nullptr, "Param of merge node is null."); - OpDescPtr switch_op_desc = switch_node->GetOpDesc(); - GE_CHK_BOOL_EXEC(switch_op_desc != nullptr, return nullptr, "OpDesc of Switch node is invalid."); - GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, { - GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u", switch_op_desc->GetInputsSize(), - SWITCH_INPUT_NUM); - return nullptr; - }); - - const std::string node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix; - GELOGI("Create StreamSwitch, name=%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMSWITCH); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamSwitch:%s.", node_name.c_str()); - return nullptr; - } - // mark hccl group id - std::string hccl_group_id; - if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { - (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); - GELOGI("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch%s, value is %s.", node_name.c_str(), - hccl_group_id.c_str()); - } else { - GELOGI("Can not find attr ATTR_NAME_HCCL_FUSED_GROUP for node %s.", switch_node->GetName().c_str()); - } - - if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || - !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { - GELOGE(INTERNAL_ERROR, "set int failed"); - return nullptr; - } - - // Already checked, first input is Variable will passed, second is condition will checked. - GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT); - GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32); - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, - "Create StreamSwitch node: add input desc fail."); - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, - "Create StreamSwitch node: add input desc fail."); - - NodePtr stream_switch = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(stream_switch != nullptr, return nullptr, "Insert StreamSwitch node fail."); - - GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), - "StreamSwitch node add cond edge fail."); - - return stream_switch; -} - -/// -/// @brief Add MemcpyAsync Node -/// @param [in] graph -/// @param [in] in_node -/// @param [in] multi_batch_flag -/// @return ge::NodePtr -/// -NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, - bool multi_batch_flag) { - GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); - OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); - - std::string memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC; - std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type; - node_name = CheckDuplicateName(node_name); - GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, memcpy_type); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str()); - return nullptr; - } - - GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, - return nullptr, "Create MemcpyAsync op: add input desc fail."); - GE_CHK_BOOL_EXEC(op_desc->AddOutputDesc(pre_op_desc->GetOutputDesc(out_data_anchor->GetIdx())) == GRAPH_SUCCESS, - return nullptr, "Create MemcpyAsync op: add output desc fail."); - - NodePtr memcpy_node = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return nullptr, "Insert MemcpyAsync node fail."); - - return memcpy_node; -} - -/// -/// @brief Combine switch nodes link to same cond -/// @param [in] graph -/// @return Status -/// -Status SwitchOpPass::CombineSwitchNode(ComputeGraphPtr &graph) { - for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { - for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { - OutDataAnchorPtr peer_cond_anchor = iter->first; - GE_CHECK_NOTNULL(peer_cond_anchor); - std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; - std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; - std::set same_cond_switch; - same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); - same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); - - NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); - GELOGI("CombineSwitchNode: cond_node=%s", cond_node->GetName().c_str()); - - NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor); - GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node fail."); - - NodePtr active_node = CreateActiveNode(graph, cond_node); - GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), - "StreamActive add ctl edge fail."); - if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) { - GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); - return FAILED; - } - - const std::string cond_group = cond_node->GetName(); - for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { - bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); - std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); - GE_IF_BOOL_EXEC(switch_list.empty(), continue); - - // select first stream_switch - NodePtr stream_switch = switch_list.front(); - OpDescPtr switch_desc = stream_switch->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - std::string node_name = cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"); - node_name = CheckDuplicateName(node_name); - switch_desc->SetName(node_name); - stream_switch_nodes_.emplace_back(stream_switch); - need_label_nodes_.emplace_back(stream_switch); - - // 0_input: original pred input, 1_input: constant node - GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node fail"); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), - "StreamSwitch remove data edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), - "Cast add data edge fail."); - - for (NodePtr &node : switch_list) { - GE_CHECK_NOTNULL(node); - GE_IF_BOOL_EXEC(node != stream_switch, { - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), - "StreamSwitch remove data edge fail."); - }); - GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges fail"); - GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges fail"); - } - - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), - "StreamActive add ctl edge fail."); - } - } - } - return SUCCESS; -} - -/// -/// @brief Create Active Op -/// @param [in] graph -/// @param [in] cond_node -/// @return ge::NodePtr -/// -NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) { - GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Param of pre cond_node is null."); - std::string node_name = node->GetName() + "_" + STREAMACTIVE; - node_name = CheckDuplicateName(node_name); - GELOGI("Create StreamActive op:%s.", node_name.c_str()); - OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); - if (op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, StreamActive:%s.", node_name.c_str()); - return nullptr; - } - - NodePtr active_node = graph->AddNode(op_desc); - GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node fail."); - - GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, - GELOGE(INTERNAL_ERROR, "add edge failed"); - return nullptr); - - GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, - GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); - return nullptr); - - return active_node; -} - -/// -/// @brief Add MemcpyAsync Op as StreamMerge in_node -/// @param [in] graph -/// @param [in] node -/// @param [in] multi_batch_flag -/// @return Status -/// -Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) { - GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); - for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); - NodePtr in_node = peer_out_anchor->GetOwnerNode(); - - const std::string type = in_node->GetType(); - // For WhileLoop no need memcpy & active for merge. - GE_IF_BOOL_EXEC((type == ENTER) || (type == REFENTER) || (type == NEXTITERATION) || (type == REFNEXTITERATION), - continue); - - GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { - in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag); - GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); - GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), - "MemcpyAsync node add edge fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(in_node->GetOutDataAnchor(0), in_data_anchor), - "MemcpyAsync node add edge fail."); - }); - - NodePtr active_node = CreateActiveNode(graph, in_node); - GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "StreamActive add ctl edge fail."); - if (SetActiveLabelList(active_node, {node->GetName()}) != SUCCESS) { - GELOGE(FAILED, "SetActiveLabelList for node %s fail.", active_node->GetName().c_str()); - return FAILED; - } - } - - return SUCCESS; -} - -/// -/// @brief Bypass Switch Node -/// @param [in] switch_node -/// @param [out] peer_data_anchor -/// @param [out] peer_cond_anchor -/// @return Status -/// -Status SwitchOpPass::BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, - OutDataAnchorPtr &peer_cond_anchor) { - GE_CHK_BOOL_EXEC(switch_node != nullptr, return FAILED, "Switch_node is null."); - for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) { - InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); - GE_CHK_BOOL_EXEC(in_data_anchor != nullptr, return FAILED, "node[%s]Check Switch input anchor fail.", - switch_node->GetName().c_str()); - OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return FAILED, "node[%s]Check Pre node output anchor fail.", - switch_node->GetName().c_str()); - // Remove Switch data input. - GE_CHK_STATUS_RET(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "remove edge failed"); - - if (idx == SWITCH_DATA_INPUT) { - peer_data_anchor = peer_out_anchor; - } else { - if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "FindSwitchCondInput fail, switch=%s", switch_node->GetName().c_str()); - return FAILED; - } - peer_cond_anchor = peer_out_anchor; - } - } - - return SUCCESS; -} - -/// -/// @brief Find Switch cond input -/// @param [in] pass_switch_flag -/// @param [out] peer_cond_anchor -/// @return Status -/// -Status SwitchOpPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { - NodePtr tmp_node = nullptr; - string type; - bool need_pass_type = true; - while (need_pass_type) { - if (tmp_node == nullptr) { - GE_CHECK_NOTNULL(peer_cond_anchor); - tmp_node = peer_cond_anchor->GetOwnerNode(); - } else { - InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT); - GE_CHECK_NOTNULL(in_data_anchor); - peer_cond_anchor = in_data_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_cond_anchor); - tmp_node = peer_cond_anchor->GetOwnerNode(); - } - - GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type fail"); - need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); - } - - return SUCCESS; -} - -int64_t SwitchOpPass::GetGroupId(const NodePtr &node) { - string tailing_optimization_option; - bool is_tailing_optimization = false; - auto ret = GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option); - if (ret == GRAPH_SUCCESS) { - // "1" means it's True from frontend option - is_tailing_optimization = (tailing_optimization_option == "1"); - GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str()); - } - if (!is_tailing_optimization) { - return 0; - } - - string hccl_group_id; - if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { - GELOGI("Node is %s, can not find hccl group id", node->GetName().c_str()); - return 0; - } - auto key_index = hccl_group_id.find_last_of('_'); - auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index); - GELOGI("Node is %s,Hccl group id is %s, key_num is %s", node->GetName().c_str(), hccl_group_id.c_str(), - key_num.c_str()); - int64_t num = atoi(key_num.c_str()); - if (num == 0) { - return 0; - } - GELOGI("Hccl group id is %s, group id is %ld", hccl_group_id.c_str(), num); - return num; -} - -/// -/// @brief Mark Switch Branch -/// @param [in] peer_cond_anchor -/// @param [in] stream_switch -/// @param [in] true_branch_flag -/// @return Status -/// -Status SwitchOpPass::MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &stream_switch, bool true_branch_flag) { - uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT; - GE_CHECK_NOTNULL(stream_switch); - auto it = cond_node_map_.find(peer_cond_anchor); - if (it != cond_node_map_.end()) { - int64_t switch_group_id = GetGroupId(stream_switch); - auto switch_group_it = it->second.find(switch_group_id); - if (switch_group_it == it->second.end()) { - std::list false_node_list; - std::list true_node_list; - std::list &node_list = true_branch_flag ? true_node_list : false_node_list; - node_list.emplace_back(stream_switch); - std::vector> switch_list; - switch_list.emplace_back(false_node_list); - switch_list.emplace_back(true_node_list); - (void)it->second.emplace(switch_group_id, switch_list); - } else { - GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, { - GELOGE(INTERNAL_ERROR, "cond_node_map_ check size fail, node: %s", stream_switch->GetName().c_str()); - return FAILED; - }); - switch_group_it->second[index].emplace_back(stream_switch); - } - } else { - int64_t switch_group_id = GetGroupId(stream_switch); - map>> switch_group_map; - std::list false_node_list; - std::list true_node_list; - std::list &node_list = true_branch_flag ? true_node_list : false_node_list; - node_list.emplace_back(stream_switch); - std::vector> switch_list; - switch_list.emplace_back(false_node_list); - switch_list.emplace_back(true_node_list); - (void)switch_group_map.emplace(switch_group_id, switch_list); - auto result = cond_node_map_.insert( - std::pair>>>(peer_cond_anchor, switch_group_map)); - GE_IF_BOOL_EXEC(!result.second, { - GELOGE(INTERNAL_ERROR, "cond_node_map_ insert fail, node: %s", stream_switch->GetName().c_str()); - return FAILED; - }); - } - return SUCCESS; -} - -/// -/// @brief Create cast node -/// @param [in] graph -/// @param [in] peer_cond_anchor -/// @return NodePtr -/// -NodePtr SwitchOpPass::CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &peer_cond_anchor) { - GE_CHK_BOOL_EXEC(peer_cond_anchor != nullptr, return nullptr, "Param of pre cond_node is null."); - OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); - GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc fail."); - - std::string cast_name = cond_desc->GetName() + "_" + CAST; - cast_name = CheckDuplicateName(cast_name); - GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str()); - OpDescPtr cast_desc = MakeShared(cast_name, CAST); - if (cast_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, Cast:%s.", cast_name.c_str()); - return nullptr; - } - if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) && - AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) && - AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) && - AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { - GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE fail, node: %s.", - cast_name.c_str()); - return nullptr; - } - GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()); - tensor_desc.SetDataType(DT_BOOL); - GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add input desc fail."); - tensor_desc.SetDataType(DT_INT32); - GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add output desc fail."); - - NodePtr cast_node = graph->AddNode(cast_desc); - GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node fail."); - - GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge fail."); - - return cast_node; -} - -/// -/// @brief Add const node as switch input1 -/// @param [in] graph -/// @param [in] stream_switch -/// @return Status -/// -Status SwitchOpPass::AddConstNode(ComputeGraphPtr &graph, NodePtr &stream_switch) { - GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "stream_switch is null."); - OpDescPtr op_desc = stream_switch->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - bool value = false; - GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, - "StreamSwitch get attr TRUE_BRANCH_STREAM fail."); - - const std::string const_node_name = op_desc->GetName() + "_Constant_" + (value ? "t" : "f"); - GELOGI("Create const op: %s", const_node_name.c_str()); - OpDescPtr const_op_desc = MakeShared(const_node_name, CONSTANT); - if (const_op_desc == nullptr) { - GELOGE(FAILED, "Create op_desc fail, Constant:%s.", const_node_name.c_str()); - return FAILED; - } - - auto resize_value = (int32_t)value; - GeTensorDesc data_desc = op_desc->GetInputDesc(1); - GeTensorPtr const_value = - MakeShared(data_desc, reinterpret_cast(&resize_value), sizeof(int32_t)); - if (const_value == nullptr) { - GELOGE(FAILED, "Create tensor fail."); - return FAILED; - } - GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value), return FAILED); - GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS, return FAILED, - "Create Const op: add output desc fail."); - - NodePtr const_node = graph->AddNode(const_op_desc); - GE_CHK_BOOL_EXEC(const_node != nullptr, return FAILED, "Insert Const node fail."); - GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)), - "StreamSwitch node add ctl edge fail."); - - return SUCCESS; -} - -/// -/// @brief update cond branch -/// @param [in] node -/// @return Status -/// -Status SwitchOpPass::UpdateCondBranch(NodePtr &node) { - std::string stream_label; - std::unordered_set branch_nodes; - std::unordered_set handled_set; - std::stack nodes; - nodes.push(node); - - static const std::set end_type_set = {STREAMSWITCH, STREAMMERGE, MERGE}; - bool merge_flag = false; - bool exit_flag = false; - bool net_output_flag = false; - - while (!nodes.empty()) { - NodePtr cur_node = nodes.top(); - nodes.pop(); - if (handled_set.count(cur_node) > 0) { - continue; - } - GE_CHECK_NOTNULL(cur_node); - if (UpdateAttachFlag(cur_node, stream_label, merge_flag, exit_flag, net_output_flag) != SUCCESS) { - GELOGE(FAILED, "UpdateAttachFlag fail, cur_node: %s.", cur_node->GetName().c_str()); - return FAILED; - } - - const std::string type = cur_node->GetType(); - for (auto &out_node : cur_node->GetOutAllNodes()) { - const std::string out_type = out_node->GetType(); - bool stop_flag = (end_type_set.count(out_type) > 0) || - ((branch_head_nodes_.count(out_node) > 0) && (branch_head_nodes_[out_node] != node)) || - (((type == ENTER) || (type == REFENTER)) && (out_type != STREAMACTIVE)); - if (!stop_flag) { - nodes.push(out_node); - GELOGD("branch_nodes insert %s", out_node->GetName().c_str()); - branch_nodes.insert(out_node); - } - } - handled_set.insert(cur_node); - } - - if (node->GetType() == STREAMSWITCH) { - GE_CHK_STATUS_RET(SetActiveLabelList(node, {stream_label}), "set active_label_list failed"); - } - - bool attach_flag = (merge_flag || exit_flag) && net_output_flag; - if (attach_flag) { - GELOGI("No need to keep on attaching label."); - return SUCCESS; - } - - for (NodePtr tmp_node : branch_nodes) { - GELOGD("Attach label %s to node: %s", stream_label.c_str(), tmp_node->GetName().c_str()); - GE_CHK_STATUS_RET(SetStreamLabel(tmp_node, stream_label), "set stream label failed"); - } - - return SUCCESS; -} - -/// -/// @brief update attach flag -/// @param [in] node -/// @param [out] stream_label -/// @param [out] merge_flag -/// @param [out] exit_flag -/// @param [out] net_output_flag -/// @return Status -/// -Status SwitchOpPass::UpdateAttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, - bool &net_output_flag) { - const std::string type = node->GetType(); - if (type == STREAMSWITCH) { - if (node->GetInDataNodes().empty()) { - GELOGE(INTERNAL_ERROR, "cur_node %s has no input_data_node", node->GetName().c_str()); - return INTERNAL_ERROR; - } - stream_label = node->GetInDataNodes().at(0)->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed"); - bool value = false; - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, - "StreamSwitch get attr TRUE_BRANCH_STREAM fail."); - stream_label += (value ? "_t" : "_f"); - } else if (type == STREAMMERGE) { - stream_label = node->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed"); - merge_flag = true; - } else if ((type == EXIT) || (type == REFEXIT)) { - GE_CHK_STATUS_RET(SetStreamLabel(node, stream_label), "set stream label failed"); - exit_flag = true; - } else if (type == NETOUTPUT) { - net_output_flag = true; - } - - return SUCCESS; -} - -/// -/// @brief update loop branch -/// @param [in] enter_nodes -/// @param [in] stream_label -/// @return Status -/// -Status SwitchOpPass::UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label) { - std::stack nodes(enter_nodes); - NodePtr cur_node = nullptr; - while (!nodes.empty()) { - cur_node = nodes.top(); - nodes.pop(); - for (NodePtr &out_node : cur_node->GetOutAllNodes()) { - OpDescPtr out_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(out_desc); - if (out_desc->HasAttr(ATTR_NAME_STREAM_LABEL)) { - continue; - } - GELOGD("Attach label %s to node: %s", stream_label.c_str(), out_node->GetName().c_str()); - GE_CHK_STATUS_RET(SetStreamLabel(out_node, stream_label), "set stream label failed"); - nodes.push(out_node); - } - } - - return SUCCESS; -} - -/// -/// @brief update enter nodes -/// @return Status -/// -Status SwitchOpPass::UpdateEnterNode() { - std::unordered_map> enter_active_map; - for (auto &enter_node : enter_nodes_) { - for (auto &out_ctrl_node : enter_node->GetOutControlNodes()) { - if (out_ctrl_node->GetType() != STREAMACTIVE) { - continue; - } - auto iter = enter_active_map.find(out_ctrl_node); - if (iter == enter_active_map.end()) { - enter_active_map[out_ctrl_node] = {enter_node}; - } else { - iter->second.emplace_back(enter_node); - } - } - } - - for (auto &pair : enter_active_map) { - std::string stream_label; - NodePtr active_node = pair.first; - GE_CHECK_NOTNULL(active_node); - OpDescPtr active_desc = active_node->GetOpDesc(); - GE_CHECK_NOTNULL(active_desc); - (void)AttrUtils::GetStr(active_desc, ATTR_NAME_STREAM_LABEL, stream_label); - if (stream_label.empty()) { - stream_label = active_desc->GetName(); - GE_CHK_STATUS_RET(SetStreamLabel(active_node, stream_label), "set stream label failed"); - } - std::stack enter_nodes; - for (auto &enter_node : pair.second) { - GE_CHK_STATUS_RET(SetStreamLabel(enter_node, stream_label), "set stream label failed"); - enter_nodes.emplace(enter_node); - } - - std::vector active_label_list; - if (!AttrUtils::GetListStr(active_desc, ATTR_NAME_ACTIVE_LABEL_LIST, active_label_list) || - (active_label_list.size() != 1) || active_label_list[0].empty()) { - GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ACTIVE_LABEL_LIST fail, node: %s", active_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - if (UpdateLoopBranch(enter_nodes, active_label_list[0]) != SUCCESS) { - GELOGE(FAILED, "UpdateLoopBranch fail."); - return FAILED; - } - } - - return SUCCESS; -} - -/// -/// @brief Check duplicate node_name -/// @param [in] node_name -/// @return std::string -/// -std::string SwitchOpPass::CheckDuplicateName(const std::string &node_name) { - std::string tmp_name = node_name; - auto iter = node_num_map_.find(tmp_name); - if (iter != node_num_map_.end()) { - tmp_name = tmp_name + "_" + std::to_string(iter->second); - (iter->second)++; - } else { - node_num_map_[tmp_name] = 1; - } - return tmp_name; -} - -/// -/// @brief Check cyclic dependence -/// @param [in] graph -/// @return Status -/// -Status SwitchOpPass::CheckCycleDependence(ComputeGraphPtr &graph) { - std::string type; - std::unordered_map> cond_switch_map; - for (NodePtr &node : graph->GetDirectNode()) { - GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type fail"); - if ((type == SWITCH) || (type == REFSWITCH)) { - InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); - GE_CHK_BOOL_EXEC(in_cond_anchor != nullptr, return INTERNAL_ERROR, "Check Switch in_cond_anchor fail."); - OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); - GE_CHK_BOOL_EXEC(peer_out_anchor != nullptr, return INTERNAL_ERROR, "Check Switch peer_out_anchor fail."); - if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { - GELOGE(FAILED, "FindSwitchCondInput fail, switch=%s", node->GetName().c_str()); - return FAILED; - } - - NodePtr cond_node = peer_out_anchor->GetOwnerNode(); - auto iter = cond_switch_map.find(cond_node); - if (iter == cond_switch_map.end()) { - cond_switch_map[cond_node] = {node}; - } else { - iter->second.emplace_back(node); - } - - switch_nodes_.emplace_back(node); - } else if ((type == MERGE) || (type == REFMERGE)) { - merge_nodes_.emplace_back(node); - } else if ((type == ENTER) || (type == REFENTER)) { - enter_nodes_.emplace_back(node); - } - } - - MarkCycleDependence(cond_switch_map); - - return SUCCESS; -} - -/// -/// @brief Mark cyclic dependence -/// @param [in] graph -/// @param [in] cond_switch_map -/// @return void -/// -void SwitchOpPass::MarkCycleDependence(const std::unordered_map> &cond_switch_map) { - std::stack out_nodes; - NodePtr tmp_node = nullptr; - std::unordered_set handled_set; - for (auto &iter : cond_switch_map) { - std::set switch_nodes(iter.second.begin(), iter.second.end()); - for (auto &switch_node : switch_nodes) { - GE_CHECK_NOTNULL_JUST_RETURN(switch_node); - GELOGD("CheckCycleDependence: cond_node=%s, switch=%s", iter.first->GetName().c_str(), - switch_node->GetName().c_str()); - for (const NodePtr &node : switch_node->GetOutAllNodes()) { - out_nodes.push(node); - } - } - handled_set.clear(); - while (!out_nodes.empty()) { - tmp_node = out_nodes.top(); - GE_CHECK_NOTNULL_JUST_RETURN(tmp_node); - out_nodes.pop(); - if (handled_set.count(tmp_node) > 0) { - continue; - } - GELOGD("CheckCycleDependence: tmp_node=%s", tmp_node->GetName().c_str()); - for (NodePtr &out_node : tmp_node->GetOutAllNodes()) { - if (switch_nodes.find(out_node) == switch_nodes.end()) { - out_nodes.push(out_node); - continue; - } - GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GELOGW("set cyclic dependence failed"); return ); - auto map_iter = switch_cyclic_map_.find(out_node); - if (map_iter == switch_cyclic_map_.end()) { - switch_cyclic_map_[out_node] = {tmp_node->GetName()}; - } else { - map_iter->second.insert(tmp_node->GetName()); - } - } - handled_set.insert(tmp_node); - } - } - - return; -} - -/// -/// @brief Modify in ctl edge for switch_node -/// @param [in] switch_node -/// @param [in] cast_node -/// @param [in] same_cond_switch -/// @return Status -/// -Status SwitchOpPass::ModifySwitchInCtlEdges(NodePtr &switch_node, NodePtr &cast_node, - const std::set &same_cond_switch) { - GE_CHECK_NOTNULL(switch_node); - GE_CHECK_NOTNULL(cast_node); - GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), - cast_node->GetName().c_str()); - - std::string orig_switch_name = switch_node->GetName(); - OpDescPtr switch_desc = switch_node->GetOpDesc(); - GE_CHECK_NOTNULL(switch_desc); - if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) { - GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME fail, node: %s", switch_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - - for (NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), - "Remove ctl edge fail."); - GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), - "Add ctl edge fail."); - }); - - GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); - if (same_cond_switch.count(in_ctl_node) > 0) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), - "Remove ctl edge fail."); - continue; - } - auto find_res1 = switch_node_map_.find(in_ctl_node); - GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { - GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); - return INTERNAL_ERROR; - }); - auto find_res2 = find_res1->second.find(orig_switch_name); - auto find_res3 = find_res1->second.find(cast_node->GetName()); - GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), { - find_res1->second.erase(find_res2); - find_res1->second.insert(cast_node->GetName()); - continue; - }); - } - - return SUCCESS; -} - -/// -/// @brief Modify out ctl edge for switch_node -/// @param [in] switch_node -/// @param [in] stream_switch -/// @param [in] active_node -/// @return Status -/// -Status SwitchOpPass::ModifySwitchOutCtlEdges(NodePtr &switch_node, NodePtr &stream_switch, NodePtr &active_node) { - GE_CHECK_NOTNULL(switch_node); - GE_CHECK_NOTNULL(stream_switch); - GE_CHECK_NOTNULL(active_node); - GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), - stream_switch->GetName().c_str(), active_node->GetName().c_str()); - auto find_res = switch_node_map_.find(switch_node); - GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { - GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", switch_node->GetName().c_str()); - return INTERNAL_ERROR; - }); - GE_IF_BOOL_EXEC(find_res->second.empty(), { - GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str()); - return INTERNAL_ERROR; - }); - - for (NodePtr &node : switch_node->GetOutControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "Remove ctl edge fail."); - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - std::string orig_name = op_desc->GetName(); - GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), { - if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) { - GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME fail, node: %s.", op_desc->GetName().c_str()); - return INTERNAL_ERROR; - } - }); - if (find_res->second.find(orig_name) == find_res->second.end()) { - auto active_out_control_anchor = active_node->GetOutControlAnchor(); - GE_CHECK_NOTNULL(active_out_control_anchor); - GE_IF_BOOL_EXEC(!active_out_control_anchor->IsLinkedWith(node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(active_out_control_anchor, node->GetInControlAnchor()), "Add ctl edge fail."); - }); - } else { - auto stream_switch_out_control_anchor = stream_switch->GetOutControlAnchor(); - GE_CHECK_NOTNULL(stream_switch_out_control_anchor); - GE_IF_BOOL_EXEC(!stream_switch_out_control_anchor->IsLinkedWith(node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch_out_control_anchor, node->GetInControlAnchor()), - "Add ctl edge fail."); - }); - } - } - - GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node)); - - return SUCCESS; -} - -/// -/// @brief Copy Control Edges -/// @param [in] old_node -/// @param [in] new_node -/// @param [in] input_check_flag -/// @return void -/// -void SwitchOpPass::CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag) { - GE_CHECK_NOTNULL_JUST_RETURN(old_node); - GE_CHECK_NOTNULL_JUST_RETURN(new_node); - GE_IF_BOOL_EXEC(old_node == new_node, return ); - auto iter = switch_cyclic_map_.find(old_node); - bool check_flag = input_check_flag && (iter != switch_cyclic_map_.end()); - for (NodePtr &node : old_node->GetInControlNodes()) { - if (check_flag && (iter->second.count(node->GetName()) > 0)) { - for (auto &out_node : old_node->GetOutAllNodes()) { - auto out_control_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); - GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(out_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, out_node->GetInControlAnchor()), "Add ctl edge fail."); - }); - } - } else { - auto out_control_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); - GE_IF_BOOL_EXEC(!out_control_anchor->IsLinkedWith(new_node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(out_control_anchor, new_node->GetInControlAnchor()), "Add in ctl edge fail."); - }); - } - } - - for (NodePtr &node : old_node->GetOutControlNodes()) { - GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(node->GetInControlAnchor()), { - GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "Add out ctl edge fail."); - }); - } -} - -/// -/// @brief Remove Control Edges -/// @param [in] node -/// @return void -/// -void SwitchOpPass::RemoveControlEdges(NodePtr &node) { - GE_CHECK_NOTNULL_JUST_RETURN(node); - for (NodePtr &in_node : node->GetInControlNodes()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(in_node->GetOutControlAnchor(), node->GetInControlAnchor()), - "Remove in ctl edge fail."); - } - - for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { - for (auto &in_ctrl_anchor : out_data_anchor->GetPeerInControlAnchors()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, in_ctrl_anchor), "Remove in ctl edge fail."); - } - } - - auto out_control_anchor = node->GetOutControlAnchor(); - GE_CHECK_NOTNULL_JUST_RETURN(out_control_anchor); - for (auto &peer_anchor : out_control_anchor->GetPeerAnchors()) { - GE_CHK_STATUS(GraphUtils::RemoveEdge(out_control_anchor, peer_anchor), "Remove out ctl edge fail."); - } -} - -/// -/// @brief Replace Control Edges -/// @param [in] old_node -/// @param [in] new_node -/// @return void -/// -void SwitchOpPass::ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node) { - GE_IF_BOOL_EXEC(old_node == new_node, return ); - CopyControlEdges(old_node, new_node); - RemoveControlEdges(old_node); -} - -/// -/// @brief Mark node as head_node of stream_switch -/// @param [in] node -/// @param [in] stream_switch -/// @return void -/// -void SwitchOpPass::MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch) { - std::stack nodes; - nodes.push(node); - std::set visited; - while (!nodes.empty()) { - NodePtr cur_node = nodes.top(); - nodes.pop(); - if (visited.count(cur_node) > 0) { - continue; - } - GELOGD("branch_head_node %s of stream_switch %s", cur_node->GetName().c_str(), stream_switch->GetName().c_str()); - branch_head_nodes_[cur_node] = stream_switch; - if ((cur_node->GetType() == IDENTITY) || (cur_node->GetType() == IDENTITYN)) { - for (auto &out_node : cur_node->GetOutAllNodes()) { - nodes.push(out_node); - } - } - visited.insert(cur_node); - } -} - -/// -/// @brief Clear Status, uesd for subgraph pass -/// @return -/// -Status SwitchOpPass::ClearStatus() { - switch_nodes_.clear(); - merge_nodes_.clear(); - enter_nodes_.clear(); - switch_cyclic_map_.clear(); - bypass_nodes_.clear(); - branch_head_nodes_.clear(); - stream_switch_nodes_.clear(); - need_label_nodes_.clear(); - cond_node_map_.clear(); - switch_node_map_.clear(); - node_num_map_.clear(); - return SUCCESS; -} -} // namespace ge diff --git a/src/ge/graph/passes/switch_to_stream_switch_pass.cc b/src/ge/graph/passes/switch_to_stream_switch_pass.cc new file mode 100644 index 00000000..ef8879dd --- /dev/null +++ b/src/ge/graph/passes/switch_to_stream_switch_pass.cc @@ -0,0 +1,755 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "graph/passes/switch_to_stream_switch_pass.h" +#include +#include "common/ge/ge_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/debug/log.h" +#include "framework/common/ge_inner_error_codes.h" +#include "framework/common/types.h" +#include "ge/ge_api_types.h" +#include "graph/common/omg_util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/ge_context.h" +#include "graph/utils/type_utils.h" + +namespace ge { +Status SwitchToStreamSwitchPass::Run(ComputeGraphPtr graph) { + GELOGD("SwitchToStreamSwitchPass Enter"); + + GE_CHK_STATUS_RET(CheckCycleDependence(graph), "Check cyclic dependence failed."); + for (const auto &switch_node : switch_nodes_) { + GE_CHK_STATUS_RET(ReplaceSwitchNode(graph, switch_node), "Replace Switch by StreamSwitch failed."); + } + GE_CHK_STATUS_RET(CombineSwitchNode(graph), "Combine StreamSwitch nodes failed."); + + for (const auto &node : bypass_nodes_) { + GE_CHK_BOOL_EXEC(graph->IsolateNode(node) == GRAPH_SUCCESS, return FAILED, "Isolate node failed."); + GE_CHK_BOOL_EXEC(GraphUtils::RemoveNodeWithoutRelink(graph, node) == GRAPH_SUCCESS, return FAILED, + "Remove switch node failed."); + } + + GELOGD("SwitchToStreamSwitchPass Leave"); + return SUCCESS; +} + +/// +/// @brief Clear Status, used for subgraph pass +/// @return +/// +Status SwitchToStreamSwitchPass::ClearStatus() { + switch_nodes_.clear(); + switch_cyclic_map_.clear(); + bypass_nodes_.clear(); + stream_switch_nodes_.clear(); + cond_node_map_.clear(); + switch_node_map_.clear(); + node_num_map_.clear(); + return SUCCESS; +} + +/// +/// @brief Check cyclic dependence +/// @param [in] graph +/// @return Status +/// +Status SwitchToStreamSwitchPass::CheckCycleDependence(const ComputeGraphPtr &graph) { + std::string type; + std::unordered_map> cond_switch_map; + for (const NodePtr &node : graph->GetDirectNode()) { + GE_CHK_STATUS_RET(GetOriginalType(node, type), "Get node type failed."); + if ((type == SWITCH) || (type == REFSWITCH)) { + InDataAnchorPtr in_cond_anchor = node->GetInDataAnchor(SWITCH_PRED_INPUT); + GE_CHECK_NOTNULL(in_cond_anchor); + OutDataAnchorPtr peer_out_anchor = in_cond_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + if (FindSwitchCondInput(true, peer_out_anchor) != SUCCESS) { + GELOGE(FAILED, "Find pred_input for switch_node %s failed.", node->GetName().c_str()); + return FAILED; + } + + NodePtr cond_node = peer_out_anchor->GetOwnerNode(); + auto iter = cond_switch_map.find(cond_node); + if (iter == cond_switch_map.end()) { + cond_switch_map[cond_node] = {node}; + } else { + iter->second.emplace_back(node); + } + switch_nodes_.emplace_back(node); + } + } + + MarkCycleDependence(cond_switch_map); + return SUCCESS; +} + +/// +/// @brief Mark cyclic dependence +/// @param [in] graph +/// @param [in] cond_switch_map +/// @return void +/// +void SwitchToStreamSwitchPass::MarkCycleDependence( + const std::unordered_map> &cond_switch_map) { + std::stack out_nodes; + NodePtr tmp_node = nullptr; + std::unordered_set visited; + for (const auto &iter : cond_switch_map) { + std::set switch_nodes(iter.second.begin(), iter.second.end()); + for (const auto &switch_node : switch_nodes) { + GELOGD("MarkCycleDependence: cond_node=%s, switch=%s.", iter.first->GetName().c_str(), + switch_node->GetName().c_str()); + for (const auto &node : switch_node->GetOutAllNodes()) { + out_nodes.push(node); + } + } + visited.clear(); + while (!out_nodes.empty()) { + tmp_node = out_nodes.top(); + out_nodes.pop(); + if (visited.count(tmp_node) > 0) { + continue; + } + GELOGD("MarkCycleDependence: tmp_node=%s.", tmp_node->GetName().c_str()); + for (const NodePtr &out_node : tmp_node->GetOutAllNodes()) { + if (switch_nodes.find(out_node) == switch_nodes.end()) { + out_nodes.push(out_node); + continue; + } + GE_IF_BOOL_EXEC(SetCyclicDependenceFlag(out_node) != SUCCESS, GELOGW("set cyclic dependence attr failed."); + return ); + auto map_iter = switch_cyclic_map_.find(out_node); + if (map_iter == switch_cyclic_map_.end()) { + switch_cyclic_map_[out_node] = {tmp_node->GetName()}; + } else { + map_iter->second.insert(tmp_node->GetName()); + } + } + visited.insert(tmp_node); + } + } + + return; +} + +/// +/// @brief Replace Switch Op +/// @param [in] graph +/// @param [in] switch_node +/// @return Status +/// +Status SwitchToStreamSwitchPass::ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node) { + OutDataAnchorPtr peer_data_anchor = nullptr; + OutDataAnchorPtr peer_cond_anchor = nullptr; + GE_CHK_BOOL_EXEC(BypassSwitchNode(switch_node, peer_data_anchor, peer_cond_anchor) == SUCCESS, return FAILED, + "Bypass switch node %s failed.", switch_node->GetName().c_str()); + GE_CHECK_NOTNULL(peer_data_anchor); + GE_CHECK_NOTNULL(peer_cond_anchor); + OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHECK_NOTNULL(cond_desc); + DataType cond_data_type = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()).GetDataType(); + GE_CHK_BOOL_EXEC(cond_data_type == DT_BOOL, return FAILED, + "pred_input of Switch only support DT_BOOL data_type, but %s exactly.", + TypeUtils::DataTypeToSerialString(cond_data_type).c_str()); + + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + bool cyclic_flag = switch_desc->HasAttr(ATTR_NAME_CYCLIC_DEPENDENCE_FLAG); + std::set out_node_list; + for (const auto &out_data_anchor : switch_node->GetAllOutDataAnchors()) { + bool true_branch_flag = (static_cast(out_data_anchor->GetIdx()) == SWITCH_TRUE_OUTPUT); + NodePtr stream_switch = nullptr; + out_node_list.clear(); + for (const auto &peer_in_anchor : out_data_anchor->GetPeerAnchors()) { + GE_IF_BOOL_EXEC(stream_switch == nullptr, { + stream_switch = CreateStreamSwitchNode(graph, switch_node, true_branch_flag ? "_t" : "_f", peer_cond_anchor); + GE_CHK_BOOL_EXEC(stream_switch != nullptr, return FAILED, "Create stream_switch node failed."); + if (SetSwitchTrueBranchFlag(stream_switch, true_branch_flag) != SUCCESS) { + GELOGE(FAILED, "SetSwitchTrueBranchFlag for node %s failed.", stream_switch->GetName().c_str()); + return FAILED; + } + if (MarkBranches(peer_cond_anchor, stream_switch, true_branch_flag) != SUCCESS) { + GELOGE(FAILED, "Mark branches for stream_switch %s failed.", stream_switch->GetName().c_str()); + return FAILED; + } + + if (!cyclic_flag) { + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor->GetOwnerNode()->GetOutControlAnchor(), + stream_switch->GetInControlAnchor()), + "StreamSwitch node add ctl edge failed."); + } + }); + + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_data_anchor, peer_in_anchor), "Remove Switch data output failed."); + + NodePtr out_node = peer_in_anchor->GetOwnerNode(); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, peer_in_anchor), "StreamSwitch node add edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(stream_switch->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "StreamSwitch node add ctl edge failed."); + out_node_list.insert(out_node->GetName()); + } + + GE_IF_BOOL_EXEC(stream_switch != nullptr, { + MoveCtrlEdges(switch_node, stream_switch); + switch_node_map_[stream_switch] = out_node_list; + if (SetOriginalNodeName(stream_switch, switch_node->GetName()) != SUCCESS) { + GELOGE(FAILED, "SetOriginalNodeName for node %s failed.", stream_switch->GetName().c_str()); + return FAILED; + } + }); + } + + (void)bypass_nodes_.insert(switch_node); + return SUCCESS; +} + +/// +/// @brief Bypass Switch Node +/// @param [in] switch_node +/// @param [out] peer_data_anchor +/// @param [out] peer_cond_anchor +/// @return Status +/// +Status SwitchToStreamSwitchPass::BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, + OutDataAnchorPtr &peer_cond_anchor) { + for (uint32_t idx = 0; idx < SWITCH_INPUT_NUM; ++idx) { + InDataAnchorPtr in_data_anchor = switch_node->GetInDataAnchor(idx); + GE_CHECK_NOTNULL(in_data_anchor); + OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + // Remove Switch data input. + if (GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Remove data edge %s->%s failed.", peer_out_anchor->GetOwnerNode()->GetName().c_str(), + switch_node->GetName().c_str()); + return FAILED; + } + + if (idx == SWITCH_DATA_INPUT) { + peer_data_anchor = peer_out_anchor; + } else { + if (FindSwitchCondInput(false, peer_out_anchor) != SUCCESS) { + GELOGE(FAILED, "Find pred_input for switch_node %s failed.", switch_node->GetName().c_str()); + return FAILED; + } + peer_cond_anchor = peer_out_anchor; + } + } + + return SUCCESS; +} + +/// +/// @brief Find Switch cond input +/// @param [in] pass_switch_flag +/// @param [out] peer_cond_anchor +/// @return Status +/// +Status SwitchToStreamSwitchPass::FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor) { + NodePtr tmp_node = nullptr; + string type; + bool need_pass_type = true; + while (need_pass_type) { + if (tmp_node == nullptr) { + tmp_node = peer_cond_anchor->GetOwnerNode(); + } else { + InDataAnchorPtr in_data_anchor = tmp_node->GetInDataAnchor(SWITCH_DATA_INPUT); + GE_CHECK_NOTNULL(in_data_anchor); + peer_cond_anchor = in_data_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_cond_anchor); + tmp_node = peer_cond_anchor->GetOwnerNode(); + } + + GE_CHK_STATUS_RET(GetOriginalType(tmp_node, type), "Get node type failed."); + need_pass_type = (pass_switch_flag && ((type == SWITCH) || (type == REFSWITCH))); + } + + return SUCCESS; +} + +/// +/// @brief Create StreamSwitch Node +/// @param [in] graph +/// @param [in] switch_node +/// @param [in] suffix +/// @param [in] peer_cond_anchor +/// @return ge::NodePtr +/// +NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, + const std::string &suffix, + const OutDataAnchorPtr &peer_cond_anchor) { + OpDescPtr switch_op_desc = switch_node->GetOpDesc(); + GE_CHK_BOOL_EXEC(switch_op_desc != nullptr, return nullptr, "OpDesc of Switch node is invalid."); + GE_IF_BOOL_EXEC(switch_op_desc->GetInputsSize() != SWITCH_INPUT_NUM, { + GELOGE(FAILED, "Switch input param invalid, input_size=%lu, should be %u.", switch_op_desc->GetInputsSize(), + SWITCH_INPUT_NUM); + return nullptr; + }); + + const std::string &node_name = switch_node->GetName() + "_" + STREAMSWITCH + suffix; + GELOGI("Create StreamSwitch, name=%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMSWITCH); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, StreamSwitch:%s.", node_name.c_str()); + return nullptr; + } + + // mark hccl group id + std::string hccl_group_id; + if (AttrUtils::GetStr(switch_node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + (void)AttrUtils::SetStr(op_desc, ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id); + GELOGD("Set attr ATTR_NAME_HCCL_FUSED_GROUP for Stream_Switch %s, value is %s.", node_name.c_str(), + hccl_group_id.c_str()); + } + + if (!AttrUtils::SetInt(op_desc, ATTR_NAME_SWITCH_DATA_TYPE, RT_SWITCH_INT32) || + !AttrUtils::SetInt(op_desc, ATTR_NAME_STREAM_SWITCH_COND, (int64_t)RT_EQUAL)) { + GELOGE(INTERNAL_ERROR, "set int failed"); + return nullptr; + } + + // Already checked, first input is Variable will passed, second is condition will checked. + GeTensorDesc cond_input_desc = switch_op_desc->GetInputDesc(SWITCH_PRED_INPUT); + GeTensorDesc input_desc(GeShape(cond_input_desc.GetShape().GetDims()), cond_input_desc.GetFormat(), DT_INT32); + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, + "Create StreamSwitch node: add input desc failed."); + GE_CHK_BOOL_EXEC(op_desc->AddInputDesc(input_desc) == GRAPH_SUCCESS, return nullptr, + "Create StreamSwitch node: add input desc failed."); + + NodePtr stream_switch = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(stream_switch != nullptr, return nullptr, "Insert StreamSwitch node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), + "StreamSwitch node add cond edge failed."); + + return stream_switch; +} + +/// +/// @brief Mark Switch Branch +/// @param [in] peer_cond_anchor +/// @param [in] stream_switch +/// @param [in] true_branch_flag +/// @return Status +/// +Status SwitchToStreamSwitchPass::MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch, + bool true_branch_flag) { + uint32_t index = true_branch_flag ? SWITCH_TRUE_OUTPUT : SWITCH_FALSE_OUTPUT; + auto it = cond_node_map_.find(peer_cond_anchor); + if (it != cond_node_map_.end()) { + int64_t switch_group_id = GetGroupId(stream_switch); + auto switch_group_it = it->second.find(switch_group_id); + if (switch_group_it == it->second.end()) { + std::list false_node_list; + std::list true_node_list; + std::list &node_list = true_branch_flag ? true_node_list : false_node_list; + node_list.emplace_back(stream_switch); + std::vector> switch_list; + switch_list.emplace_back(false_node_list); + switch_list.emplace_back(true_node_list); + it->second[switch_group_id] = switch_list; + } else { + GE_IF_BOOL_EXEC(switch_group_it->second.size() != SWITCH_OUTPUT_NUM, { + GELOGE(INTERNAL_ERROR, "Check size failed, node: %s", stream_switch->GetName().c_str()); + return FAILED; + }); + switch_group_it->second[index].emplace_back(stream_switch); + } + } else { + int64_t switch_group_id = GetGroupId(stream_switch); + map>> switch_group_map; + std::list false_node_list; + std::list true_node_list; + std::list &node_list = true_branch_flag ? true_node_list : false_node_list; + node_list.emplace_back(stream_switch); + std::vector> switch_list; + switch_list.emplace_back(false_node_list); + switch_list.emplace_back(true_node_list); + switch_group_map[switch_group_id] = switch_list; + cond_node_map_[peer_cond_anchor] = switch_group_map; + } + return SUCCESS; +} + +/// +/// @brief Get group_id for switch_node +/// @param [in] node +/// @return group_id +/// +int64_t SwitchToStreamSwitchPass::GetGroupId(const NodePtr &node) { + string tailing_optimization_option; + bool is_tailing_optimization = false; + if (GetContext().GetOption(OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, tailing_optimization_option) == GRAPH_SUCCESS) { + // "1" means it's True from frontend option + is_tailing_optimization = (tailing_optimization_option == "1"); + GELOGI("Option ge.exec.isTailingOptimization is %s", tailing_optimization_option.c_str()); + } + if (!is_tailing_optimization) { + return 0; + } + + string hccl_group_id; + if (!AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, hccl_group_id)) { + GELOGI("Node %s can not find hccl group id.", node->GetName().c_str()); + return 0; + } + auto key_index = hccl_group_id.find_last_of('_'); + auto key_num = hccl_group_id.substr(key_index + 1, hccl_group_id.length() - key_index); + GELOGI("Node:%s, hccl_group_id=%s, key_num=%s", node->GetName().c_str(), hccl_group_id.c_str(), key_num.c_str()); + int64_t num = atoi(key_num.c_str()); + if (num == 0) { + return 0; + } + + GELOGI("Hccl_group_id is %s, group_id is %ld", hccl_group_id.c_str(), num); + return num; +} + +/// +/// @brief Combine switch nodes link to same cond +/// @param [in] graph +/// @return Status +/// +Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) { + for (auto iter = cond_node_map_.begin(); iter != cond_node_map_.end(); ++iter) { + for (auto group_iter = iter->second.begin(); group_iter != iter->second.end(); ++group_iter) { + std::list false_switch_list = group_iter->second[SWITCH_FALSE_OUTPUT]; + std::list true_switch_list = group_iter->second[SWITCH_TRUE_OUTPUT]; + std::set same_cond_switch; + same_cond_switch.insert(false_switch_list.begin(), false_switch_list.end()); + same_cond_switch.insert(true_switch_list.begin(), true_switch_list.end()); + + OutDataAnchorPtr peer_cond_anchor = iter->first; + NodePtr cond_node = peer_cond_anchor->GetOwnerNode(); + GELOGI("CombineSwitchNode: cond_node=%s.", cond_node->GetName().c_str()); + + NodePtr cast_node = CreateCastOp(graph, peer_cond_anchor); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return FAILED, "Create cast_node failed."); + + NodePtr active_node = CreateActiveNode(graph, cond_node); + GE_CHK_BOOL_EXEC(active_node != nullptr, return FAILED, "Create StreamActive node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutControlAnchor(), active_node->GetInControlAnchor()), + "StreamActive add ctl edge failed."); + if (SetActiveLabelList(active_node, {cast_node->GetName()}) != SUCCESS) { + GELOGE(FAILED, "Set active_label_list attr for node %s failed.", active_node->GetName().c_str()); + return FAILED; + } + + const std::string &cond_group = cond_node->GetName(); + for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { + bool true_branch_flag = (i == SWITCH_TRUE_OUTPUT); + std::list &switch_list = (true_branch_flag ? true_switch_list : false_switch_list); + GE_IF_BOOL_EXEC(switch_list.empty(), continue); + + // select first stream_switch + NodePtr stream_switch = switch_list.front(); + OpDescPtr switch_desc = stream_switch->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + switch_desc->SetName(CheckDuplicateName(cond_group + "/" + STREAMSWITCH + (true_branch_flag ? "_t" : "_f"))); + stream_switch_nodes_.emplace_back(stream_switch); + + // 0_input: original pred input, 1_input: constant node + GE_CHK_STATUS_RET(AddConstNode(graph, stream_switch), "Add const node failed."); + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), + "StreamSwitch remove data edge failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), + "Cast add data edge failed."); + + for (const NodePtr &node : switch_list) { + GE_IF_BOOL_EXEC(node != stream_switch, { + GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), + "StreamSwitch remove data edge failed."); + }); + GE_CHK_STATUS(ModifySwitchInCtlEdges(node, cast_node, same_cond_switch), "ModifySwitchInCtlEdges failed."); + GE_CHK_STATUS(ModifySwitchOutCtlEdges(node, stream_switch, active_node), "ModifySwitchOutCtlEdges failed."); + } + + GE_CHK_STATUS(GraphUtils::AddEdge(active_node->GetOutControlAnchor(), stream_switch->GetInControlAnchor()), + "StreamActive add ctl edge failed."); + } + } + } + return SUCCESS; +} + +/// +/// @brief Create Active Op +/// @param [in] graph +/// @param [in] cond_node +/// @return ge::NodePtr +/// +NodePtr SwitchToStreamSwitchPass::CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node) { + const std::string &node_name = CheckDuplicateName(node->GetName() + "_" + STREAMACTIVE); + GELOGI("Create StreamActive op:%s.", node_name.c_str()); + OpDescPtr op_desc = MakeShared(node_name, STREAMACTIVE); + if (op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, StreamActive:%s.", node_name.c_str()); + return nullptr; + } + + NodePtr active_node = graph->AddNode(op_desc); + GE_CHK_BOOL_EXEC(active_node != nullptr, return nullptr, "Create StreamActive node failed."); + + GE_IF_BOOL_EXEC(GraphUtils::AddEdge(node->GetOutControlAnchor(), active_node->GetInControlAnchor()) != SUCCESS, + GELOGE(INTERNAL_ERROR, "add edge failed"); + return nullptr); + + GE_IF_BOOL_EXEC(SetSwitchBranchNodeLabel(active_node, node_name) != SUCCESS, + GELOGE(INTERNAL_ERROR, "set switch branch node label failed"); + return nullptr); + + return active_node; +} + +/// +/// @brief Create cast node +/// @param [in] graph +/// @param [in] peer_cond_anchor +/// @return NodePtr +/// +NodePtr SwitchToStreamSwitchPass::CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor) { + OpDescPtr cond_desc = peer_cond_anchor->GetOwnerNode()->GetOpDesc(); + GE_CHK_BOOL_EXEC(cond_desc != nullptr, return nullptr, "Get cond_desc failed."); + + const std::string &cast_name = CheckDuplicateName(cond_desc->GetName() + "_" + CAST); + GELOGI("Create cast_node: %s, input datatype:DT_BOOL, out datatype:DT_INT32", cast_name.c_str()); + OpDescPtr cast_desc = MakeShared(cast_name, CAST); + if (cast_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, Cast:%s.", cast_name.c_str()); + return nullptr; + } + if (!(AttrUtils::SetInt(cast_desc, CAST_ATTR_SRCT, (int64_t)DT_BOOL) && + AttrUtils::SetInt(cast_desc, CAST_ATTR_DSTT, (int64_t)DT_INT32) && + AttrUtils::SetInt(cast_desc, CAST_ATTR_DST_TYPE, (int64_t)DT_INT32) && + AttrUtils::SetBool(cast_desc, CAST_ATTR_TRUNCATE, false))) { + GELOGE(FAILED, "Set CAST_ATTR_SRCT or CAST_ATTR_DSTT or CAST_ATTR_DST_TYPE or CAST_ATTR_TRUNCATE failed, node: %s.", + cast_name.c_str()); + return nullptr; + } + + GeTensorDesc tensor_desc = cond_desc->GetOutputDesc(peer_cond_anchor->GetIdx()); + tensor_desc.SetDataType(DT_BOOL); + GE_CHK_BOOL_EXEC(cast_desc->AddInputDesc(tensor_desc) == SUCCESS, return nullptr, "Cast_node add input desc failed."); + tensor_desc.SetDataType(DT_INT32); + GE_CHK_BOOL_EXEC(cast_desc->AddOutputDesc(tensor_desc) == SUCCESS, return nullptr, + "Cast_node add output desc failed."); + + NodePtr cast_node = graph->AddNode(cast_desc); + GE_CHK_BOOL_EXEC(cast_node != nullptr, return nullptr, "Create cast_node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, cast_node->GetInDataAnchor(0)), "Cast add data edge failed."); + + return cast_node; +} + +/// +/// @brief Add const node as switch input1 +/// @param [in] graph +/// @param [in] stream_switch +/// @return Status +/// +Status SwitchToStreamSwitchPass::AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch) { + OpDescPtr op_desc = stream_switch->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + bool value = false; + GE_CHK_BOOL_EXEC(AttrUtils::GetBool(op_desc, ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG, value), return FAILED, + "StreamSwitch get attr TRUE_BRANCH_STREAM failed."); + + const std::string &const_node_name = op_desc->GetName() + "_Constant_" + (value ? "t" : "f"); + GELOGI("Create const op: %s", const_node_name.c_str()); + OpDescPtr const_op_desc = MakeShared(const_node_name, CONSTANT); + if (const_op_desc == nullptr) { + GELOGE(FAILED, "Create op_desc failed, Constant:%s.", const_node_name.c_str()); + return FAILED; + } + + auto resize_value = (int32_t)value; + GeTensorDesc data_desc = op_desc->GetInputDesc(1); + GeTensorPtr const_value = + MakeShared(data_desc, reinterpret_cast(&resize_value), sizeof(int32_t)); + if (const_value == nullptr) { + GELOGE(FAILED, "Create tensor failed."); + return FAILED; + } + GE_CHK_BOOL_EXEC(AttrUtils::SetTensor(const_op_desc, ATTR_NAME_WEIGHTS, const_value), return FAILED); + GE_CHK_BOOL_EXEC(const_op_desc->AddOutputDesc(data_desc) == GRAPH_SUCCESS, return FAILED, + "Create Const op: add output desc failed."); + + NodePtr const_node = graph->AddNode(const_op_desc); + GE_CHK_BOOL_EXEC(const_node != nullptr, return FAILED, "Insert Const node failed."); + GE_CHK_STATUS(GraphUtils::AddEdge(const_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(1)), + "StreamSwitch node add ctl edge failed."); + + return SUCCESS; +} + +/// +/// @brief Modify in ctl edge for switch_node +/// @param [in] switch_node +/// @param [in] cast_node +/// @param [in] same_cond_switch +/// @return Status +/// +Status SwitchToStreamSwitchPass::ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, + const std::set &same_cond_switch) { + GELOGI("ModifySwitchInCtlEdges: switch_node=%s, active_node=%s", switch_node->GetName().c_str(), + cast_node->GetName().c_str()); + std::string orig_switch_name = switch_node->GetName(); + OpDescPtr switch_desc = switch_node->GetOpDesc(); + GE_CHECK_NOTNULL(switch_desc); + if (!AttrUtils::GetStr(switch_desc, ATTR_NAME_ORIG_NODE_NAME, orig_switch_name) || orig_switch_name.empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s", switch_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + + for (const NodePtr &in_ctl_node : switch_node->GetInControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), switch_node->GetInControlAnchor()), + "Remove ctl edge failed."); + GE_IF_BOOL_EXEC(!in_ctl_node->GetOutControlAnchor()->IsLinkedWith(cast_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + "Add ctl edge failed."); + }); + + GE_IF_BOOL_EXEC(in_ctl_node->GetType() != STREAMSWITCH, continue); + if (same_cond_switch.count(in_ctl_node) > 0) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(in_ctl_node->GetOutControlAnchor(), cast_node->GetInControlAnchor()), + "Remove ctl edge failed."); + continue; + } + + auto find_res1 = switch_node_map_.find(in_ctl_node); + GE_IF_BOOL_EXEC(find_res1 == switch_node_map_.end(), { + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", in_ctl_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + auto find_res2 = find_res1->second.find(orig_switch_name); + auto find_res3 = find_res1->second.find(cast_node->GetName()); + GE_IF_BOOL_EXEC((find_res2 != find_res1->second.end()) && (find_res3 == find_res1->second.end()), { + find_res1->second.erase(find_res2); + find_res1->second.insert(cast_node->GetName()); + continue; + }); + } + + return SUCCESS; +} + +/// +/// @brief Modify out ctl edge for switch_node +/// @param [in] switch_node +/// @param [in] stream_switch +/// @param [in] active_node +/// @return Status +/// +Status SwitchToStreamSwitchPass::ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, + const NodePtr &active_node) { + GELOGI("ModifySwitchOutCtlEdges: switch_node=%s, stream_switch=%s, active_node=%s", switch_node->GetName().c_str(), + stream_switch->GetName().c_str(), active_node->GetName().c_str()); + auto find_res = switch_node_map_.find(switch_node); + GE_IF_BOOL_EXEC(find_res == switch_node_map_.end(), { + GELOGE(INTERNAL_ERROR, "StreamSwitch node %s not found in switch_node_map_.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + GE_IF_BOOL_EXEC(find_res->second.empty(), { + GELOGE(INTERNAL_ERROR, "true_nodes of StreamSwitch node %s is empty.", switch_node->GetName().c_str()); + return INTERNAL_ERROR; + }); + + for (const NodePtr &node : switch_node->GetOutControlNodes()) { + GE_CHK_STATUS(GraphUtils::RemoveEdge(switch_node->GetOutControlAnchor(), node->GetInControlAnchor()), + "Remove ctl edge failed."); + OpDescPtr op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + std::string orig_name = op_desc->GetName(); + GE_IF_BOOL_EXEC(op_desc->HasAttr(ATTR_NAME_ORIG_NODE_NAME), { + if (!AttrUtils::GetStr(op_desc, ATTR_NAME_ORIG_NODE_NAME, orig_name) || orig_name.empty()) { + GELOGE(INTERNAL_ERROR, "Get attr ATTR_NAME_ORIG_NODE_NAME failed, node: %s.", op_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + }); + if (find_res->second.find(orig_name) == find_res->second.end()) { + auto active_out_ctrl_anchor = active_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL(active_out_ctrl_anchor); + GE_IF_BOOL_EXEC(!active_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(active_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed."); + }); + } else { + auto switch_out_ctrl_anchor = stream_switch->GetOutControlAnchor(); + GE_CHECK_NOTNULL(switch_out_ctrl_anchor); + GE_IF_BOOL_EXEC(!switch_out_ctrl_anchor->IsLinkedWith(node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(switch_out_ctrl_anchor, node->GetInControlAnchor()), "Add ctl edge failed."); + }); + } + } + + GE_IF_BOOL_EXEC(switch_node != stream_switch, (void)bypass_nodes_.insert(switch_node)); + return SUCCESS; +} + +/// +/// @brief Check duplicate node_name +/// @param [in] node_name +/// @return std::string +/// +std::string SwitchToStreamSwitchPass::CheckDuplicateName(const std::string &node_name) { + std::string tmp_name = node_name; + auto iter = node_num_map_.find(tmp_name); + if (iter != node_num_map_.end()) { + tmp_name = tmp_name + "_" + std::to_string(iter->second); + (iter->second)++; + } else { + node_num_map_[tmp_name] = 1; + } + return tmp_name; +} + +/// +/// @brief Move Control Edges +/// @param [in] old_node +/// @param [in] new_node +/// @return void +/// +void SwitchToStreamSwitchPass::MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node) { + GE_IF_BOOL_EXEC(old_node == new_node, return ); + auto iter = switch_cyclic_map_.find(old_node); + bool check_flag = (iter != switch_cyclic_map_.end()); + for (const NodePtr &in_node : old_node->GetInControlNodes()) { + auto out_ctrl_anchor = in_node->GetOutControlAnchor(); + GE_CHECK_NOTNULL_JUST_RETURN(out_ctrl_anchor); + if (check_flag && (iter->second.count(in_node->GetName()) > 0)) { + for (const auto &out_node : old_node->GetOutAllNodes()) { + GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(out_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, out_node->GetInControlAnchor()), + "Add in ctrl edge failed."); + }); + } + } else { + GE_IF_BOOL_EXEC(!out_ctrl_anchor->IsLinkedWith(new_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(out_ctrl_anchor, new_node->GetInControlAnchor()), "Add in ctrl edge failed."); + }); + } + GE_CHK_STATUS(GraphUtils::RemoveEdge(out_ctrl_anchor, old_node->GetInControlAnchor()), + "Remove in ctrl edge failed."); + } + + for (const NodePtr &out_node : old_node->GetOutControlNodes()) { + GE_IF_BOOL_EXEC(!new_node->GetOutControlAnchor()->IsLinkedWith(out_node->GetInControlAnchor()), { + GE_CHK_STATUS(GraphUtils::AddEdge(new_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "Add out ctrl edge failed."); + }); + GE_CHK_STATUS(GraphUtils::RemoveEdge(old_node->GetOutControlAnchor(), out_node->GetInControlAnchor()), + "Remove out ctrl edge failed."); + } +} +} // namespace ge diff --git a/src/ge/graph/passes/switch_op_pass.h b/src/ge/graph/passes/switch_to_stream_switch_pass.h similarity index 61% rename from src/ge/graph/passes/switch_op_pass.h rename to src/ge/graph/passes/switch_to_stream_switch_pass.h index 202b919c..15fe9dce 100644 --- a/src/ge/graph/passes/switch_op_pass.h +++ b/src/ge/graph/passes/switch_to_stream_switch_pass.h @@ -14,15 +14,9 @@ * limitations under the License. */ -#ifndef GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ -#define GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ - -#include -#include -#include -#include -#include -#include +#ifndef GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ +#define GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ + #include "inc/graph_pass.h" namespace ge { @@ -91,78 +85,158 @@ namespace ge { +-----------+ +-----------+ +-----------+ +-----| Less |----+ +-----------+ */ -class SwitchOpPass : public GraphPass { +class SwitchToStreamSwitchPass : public GraphPass { public: Status Run(ComputeGraphPtr graph); + + /// + /// @brief Clear Status, used for subgraph pass + /// @return + /// Status ClearStatus() override; private: - Status ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_node); - - Status ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_node); - - NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, - OutDataAnchorPtr &peer_cond_anchor); - - NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); - - Status CombineSwitchNode(ComputeGraphPtr &graph); - - NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); - - Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag); - - Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); + /// + /// @brief Check cyclic dependence + /// @param [in] graph + /// @return Status + /// + Status CheckCycleDependence(const ComputeGraphPtr &graph); + + /// + /// @brief Mark cyclic dependence + /// @param [in] graph + /// @param [in] cond_switch_map + /// @return void + /// + void MarkCycleDependence(const std::unordered_map> &cond_switch_map); + /// + /// @brief Replace Switch Op + /// @param [in] graph + /// @param [in] switch_node + /// @return Status + /// + Status ReplaceSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node); + + /// + /// @brief Bypass Switch Node + /// @param [in] switch_node + /// @param [out] peer_data_anchor + /// @param [out] peer_cond_anchor + /// @return Status + /// + Status BypassSwitchNode(const NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, + OutDataAnchorPtr &peer_cond_anchor); + + /// + /// @brief Find Switch cond input + /// @param [in] pass_switch_flag + /// @param [out] peer_cond_anchor + /// @return Status + /// Status FindSwitchCondInput(bool pass_switch_flag, OutDataAnchorPtr &peer_cond_anchor); - Status MarkBranchs(OutDataAnchorPtr &peer_cond_anchor, NodePtr &stream_switch_node, bool true_branch_flag); - - NodePtr CreateCastOp(ComputeGraphPtr &graph, OutDataAnchorPtr &peer_cond_anchor); - - Status AddConstNode(ComputeGraphPtr &graph, NodePtr &stream_switch_node); - - Status UpdateCondBranch(NodePtr &node); - - Status UpdateAttachFlag(const NodePtr &node, std::string &stream_label, bool &merge_flag, bool &exit_flag, - bool &net_output_flag); - - Status UpdateLoopBranch(const std::stack &enter_nodes, const std::string &stream_label); - - Status UpdateEnterNode(); + /// + /// @brief Create StreamSwitch Node + /// @param [in] graph + /// @param [in] switch_node + /// @param [in] suffix + /// @param [in] peer_cond_anchor + /// @return ge::NodePtr + /// + NodePtr CreateStreamSwitchNode(const ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, + const OutDataAnchorPtr &peer_cond_anchor); + + /// + /// @brief Mark Switch Branch + /// @param [in] peer_cond_anchor + /// @param [in] stream_switch + /// @param [in] true_branch_flag + /// @return Status + /// + Status MarkBranches(const OutDataAnchorPtr &peer_cond_anchor, const NodePtr &stream_switch_node, + bool true_branch_flag); + + /// + /// @brief Get group_id for switch_node + /// @param [in] node + /// @return group_id + /// + int64_t GetGroupId(const NodePtr &node); + /// + /// @brief Combine switch nodes link to same cond + /// @param [in] graph + /// @return Status + /// + Status CombineSwitchNode(const ComputeGraphPtr &graph); + + /// + /// @brief Create cast node + /// @param [in] graph + /// @param [in] peer_cond_anchor + /// @return NodePtr + /// + NodePtr CreateCastOp(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_cond_anchor); + + /// + /// @brief Create Active Op + /// @param [in] graph + /// @param [in] cond_node + /// @return ge::NodePtr + /// + NodePtr CreateActiveNode(const ComputeGraphPtr &graph, const NodePtr &node); + + /// + /// @brief Add const node as switch input1 + /// @param [in] graph + /// @param [in] stream_switch + /// @return Status + /// + Status AddConstNode(const ComputeGraphPtr &graph, const NodePtr &stream_switch_node); + + /// + /// @brief Modify in ctl edge for switch_node + /// @param [in] switch_node + /// @param [in] cast_node + /// @param [in] same_cond_switch + /// @return Status + /// + Status ModifySwitchInCtlEdges(const NodePtr &switch_node, const NodePtr &cast_node, + const std::set &same_cond_switch); + + /// + /// @brief Modify out ctl edge for switch_node + /// @param [in] switch_node + /// @param [in] stream_switch + /// @param [in] active_node + /// @return Status + /// + Status ModifySwitchOutCtlEdges(const NodePtr &switch_node, const NodePtr &stream_switch, const NodePtr &active_node); + + /// + /// @brief Check duplicate node_name + /// @param [in] node_name + /// @return std::string + /// std::string CheckDuplicateName(const std::string &node_name); - Status CheckCycleDependence(ComputeGraphPtr &graph); - - void MarkCycleDependence(const std::unordered_map> &cond_switch_map); - - Status ModifySwitchInCtlEdges(NodePtr &switch_node, NodePtr &cast_node, const std::set &same_cond_switch); - - Status ModifySwitchOutCtlEdges(NodePtr &switch_node, NodePtr &stream_switch, NodePtr &active_node); - - void CopyControlEdges(NodePtr &old_node, NodePtr &new_node, bool input_check_flag = false); - - void RemoveControlEdges(NodePtr &node); - - void ReplaceControlEdges(NodePtr &old_node, NodePtr &new_node); - - int64_t GetGroupId(const NodePtr &node); - - void MarkHeadNodes(const NodePtr &node, const NodePtr &stream_switch); + /// + /// @brief Move Control Edges + /// @param [in] old_node + /// @param [in] new_node + /// @return void + /// + void MoveCtrlEdges(const NodePtr &old_node, const NodePtr &new_node); std::vector switch_nodes_; - std::vector merge_nodes_; - std::vector enter_nodes_; std::unordered_map> switch_cyclic_map_; - std::set bypass_nodes_; - std::unordered_map branch_head_nodes_; std::vector stream_switch_nodes_; - std::vector need_label_nodes_; std::unordered_map>>> cond_node_map_; std::unordered_map> switch_node_map_; std::unordered_map node_num_map_; }; } // namespace ge -#endif // GE_GRAPH_PASSES_SWITCH_OP_PASS_H_ +#endif // GE_GRAPH_PASSES_SWITCH_TO_STREAM_SWITCH_PASS_H_ diff --git a/src/ge/graph/passes/transop_breadth_fusion_pass.cc b/src/ge/graph/passes/transop_breadth_fusion_pass.cc index 53f9e825..d8df4a22 100644 --- a/src/ge/graph/passes/transop_breadth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_breadth_fusion_pass.cc @@ -19,14 +19,12 @@ #include #include -#include "framework/common/debug/ge_log.h" #include "common/types.h" #include "graph/common/transop_util.h" #include "graph/utils/node_utils.h" namespace ge { Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { - GE_TIMESTAMP_START(TransOpBreadthFusionPass); if (graph == nullptr) { return SUCCESS; } @@ -47,7 +45,6 @@ Status TransOpBreadthFusionPass::Run(ge::ComputeGraphPtr graph) { } } } - GE_TIMESTAMP_END(TransOpBreadthFusionPass, "GraphManager::TransOpBreadthFusionPass"); return SUCCESS; } diff --git a/src/ge/graph/passes/transop_depth_fusion_pass.cc b/src/ge/graph/passes/transop_depth_fusion_pass.cc index c0c854b6..afeca3c4 100644 --- a/src/ge/graph/passes/transop_depth_fusion_pass.cc +++ b/src/ge/graph/passes/transop_depth_fusion_pass.cc @@ -17,7 +17,6 @@ #include "graph/passes/transop_depth_fusion_pass.h" #include -#include "framework/common/debug/ge_log.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" #include "graph/compute_graph.h" @@ -29,7 +28,6 @@ namespace ge { graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(TransOpDepthFusionPass); GELOGI("[TransOpDepthFusionPass]: optimize in depth begin..."); if (graph == nullptr) { return GRAPH_SUCCESS; @@ -53,7 +51,6 @@ graphStatus TransOpDepthFusionPass::Run(ComputeGraphPtr graph) { } } GELOGI("[TransOpDepthFusionPass]: Optimize in depth success..."); - GE_TIMESTAMP_END(TransOpDepthFusionPass, "GraphManager::TransOpDepthFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc index 38b6684b..887079f8 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.cc +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.cc @@ -15,22 +15,26 @@ */ #include "transop_symmetry_elimination_pass.h" +#include "common/formats/utils/formats_trans_utils.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" #include "graph/common/transop_util.h" +#include "graph/debug/ge_attr_define.h" #include "graph/utils/graph_utils.h" +#include "graph/utils/node_utils.h" #include "graph/utils/type_utils.h" +#include "types.h" namespace { const int kTransOpOutIndex = 0; -static std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; - +const std::set white_list_op{ge::TRANSPOSED, ge::RESHAPE, ge::REFORMAT, ge::CAST, ge::TRANSDATA}; +std::map precision_loss_transfer_map = {{ge::DT_FLOAT, ge::DT_BOOL}}; } // namespace namespace ge { Status TransOpSymmetryEliminationPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(node); GE_CHECK_NOTNULL(node->GetOpDesc()); - if (!TransOpUtil::IsTransOp(node)) { + if (white_list_op.find(node->GetType()) == white_list_op.end()) { return SUCCESS; } GELOGD("Symmetry Elimination Pass in."); @@ -41,9 +45,8 @@ Status TransOpSymmetryEliminationPass::Run(NodePtr &node) { GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode()); GE_CHECK_NOTNULL(peer_in_anchor->GetOwnerNode()->GetOpDesc()); if (!CheckCanBeEliminated(node, peer_in_anchor)) { - break; + continue; } - auto dst_node = peer_in_anchor->GetOwnerNode(); Status ret = EliminateTransOp(node, out_anchor, dst_node, peer_in_anchor); if (ret != SUCCESS) { @@ -71,12 +74,33 @@ bool TransOpSymmetryEliminationPass::CheckCanBeEliminated(const ge::NodePtr &src dst_node->GetType().c_str(), dst_in_anchor->GetIdx()); return false; } - if (!DescAreSymmetry(src_node, dst_node) || !CheckPrecisionLoss(src_node)) { - GELOGD("Not satisfied symmetry or has precision loss, ignore pass."); - return false; + if (src_node->GetType() == ge::RESHAPE) { + GE_CHECK_NOTNULL(src_node->GetOpDesc()); + auto unknown_dims_num = GetUnknownDimsNum(src_node->GetOpDesc()->GetInputDesc(0)); + if (unknown_dims_num != 0 && (unknown_dims_num == UNKNOWN_DIM_NUM || unknown_dims_num > 1)) { + GELOGD( + "Pre node %s is reshape op which input is dynamic shape and has more than one unknown dimension. " + "Ignore pass.", + src_node->GetName().c_str()); + return false; + } + } else if (src_node->GetType() == ge::TRANSPOSED) { + if (!JudgeTransposeDBack2Raw(src_node, dst_node)) { + GELOGD("Two Transpose op src node %s dst node %s will change the raw data. Ignore pass.", + src_node->GetName().c_str(), dst_node->GetName().c_str()); + return false; + } + } else if (src_node->GetType() == ge::TRANSDATA) { + auto unknown_dims_num = GetUnknownDimsNum(src_node->GetOpDesc()->GetInputDesc(0)); + if (unknown_dims_num == UNKNOWN_DIM_NUM) { + GELOGD("Pre node %s is transdata op which input is dynamic shape and all dimension are unknown(-2). Ignore pass.", + src_node->GetName().c_str()); + return false; + } } - return true; + return CheckPrecisionLoss(src_node) && DescAreSymmetry(src_node, dst_node); } + bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node) { const auto &src_input_desc = src_node->GetOpDesc()->MutableInputDesc(0); const auto &dst_output_desc = dst_node->GetOpDesc()->MutableOutputDesc(0); @@ -89,15 +113,28 @@ bool TransOpSymmetryEliminationPass::DescAreSymmetry(const NodePtr &src_node, co const auto &dst_output_format = dst_output_desc->GetFormat(); const auto &dst_output_shape = dst_output_desc->GetShape().GetDims(); + bool is_symmetry = true; if (src_node->GetType() == CAST && dst_node->GetType() == CAST) { bool is_format_symmetry = (src_input_format == dst_output_format) || (dst_output_format == FORMAT_ND) || (src_input_format == FORMAT_ND); - return (src_input_dtype == dst_output_dtype) && is_format_symmetry; + is_symmetry = (src_input_dtype == dst_output_dtype) && is_format_symmetry; } else { - return (src_input_dtype == dst_output_dtype) && (src_input_shape == dst_output_shape) && - (src_input_format == dst_output_format); + is_symmetry = (src_input_dtype == dst_output_dtype) && (src_input_shape == dst_output_shape) && + (src_input_format == dst_output_format); } + if (!is_symmetry) { + GELOGD( + "Not satisfied symmetry. ignore pass.\n" + "Src node %s input type: %s format: %s shape: %s, " + "dst node %s output type: %s format: %s shape: %s. ", + src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_input_dtype).c_str(), + TypeUtils::FormatToSerialString(src_input_format).c_str(), formats::ShapeToString(src_input_shape).c_str(), + dst_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(dst_output_dtype).c_str(), + TypeUtils::FormatToSerialString(dst_output_format).c_str(), formats::ShapeToString(dst_output_shape).c_str()); + } + return is_symmetry; } + bool TransOpSymmetryEliminationPass::CheckPrecisionLoss(const ge::NodePtr &src_node) { auto idx = TransOpUtil::GetTransOpDataIndex(src_node); auto input_desc = src_node->GetOpDesc()->GetInputDesc(idx); @@ -106,13 +143,62 @@ bool TransOpSymmetryEliminationPass::CheckPrecisionLoss(const ge::NodePtr &src_n auto dst_dtype = output_desc.GetDataType(); auto iter = precision_loss_transfer_map.find(src_dtype); if (iter != precision_loss_transfer_map.end() && iter->second == dst_dtype) { - GELOGW("Node %s transfer data type from %s to %s ,it will cause precision loss.", src_node->GetName().c_str(), - TypeUtils::DataTypeToSerialString(src_dtype).c_str(), TypeUtils::DataTypeToSerialString(dst_dtype).c_str()); + GELOGW("Node %s transfer data type from %s to %s ,it will cause precision loss. ignore pass.", + src_node->GetName().c_str(), TypeUtils::DataTypeToSerialString(src_dtype).c_str(), + TypeUtils::DataTypeToSerialString(dst_dtype).c_str()); return false; } return true; } +int TransOpSymmetryEliminationPass::GetUnknownDimsNum(const GeTensorDesc &node_desc) { + // + // unknown_dims_num != 0 , is dynamic shape + // unknown_dims_num = UNKNOWN_DIM_NUM , all dims are unknown + // unknown_dims_num = n , n > 0 , has n dims unknown + // + int unknown_dims_num = 0; + auto ge_shape = node_desc.GetShape(); + for (const auto dim : ge_shape.GetDims()) { + if (dim == UNKNOWN_DIM_NUM) { + return UNKNOWN_DIM_NUM; + } + if (dim == UNKNOWN_DIM) { + ++unknown_dims_num; + } + } + return unknown_dims_num; +} + +bool TransOpSymmetryEliminationPass::JudgeTransposeDBack2Raw(const NodePtr &src_node, const NodePtr &dst_node) { + // + // A transpose to C : A---->(perm_1)---->B---->(perm_2)---->C + // we want to judge A is equal with C or not + // suppose A = C then: + // 1. B[i] = A[perm_1[i]] + // 2. C[i] = B[perm_2[i]] + // 3. combine 1 and 2 then: C[i] = A[perm_1[perm_2[i]]] + // which we get through 3: i = perm_1[perm_2[i]] + // + vector src_node_perm; + AttrUtils::GetListInt(src_node->GetOpDesc(), ge::PERMUTE_ATTR_PERM, src_node_perm); + vector dst_node_perm; + AttrUtils::GetListInt(dst_node->GetOpDesc(), ge::PERMUTE_ATTR_PERM, dst_node_perm); + + if (src_node_perm.size() != dst_node_perm.size()) { + return false; + } + for (size_t src_index = 0; src_index < src_node_perm.size(); ++src_index) { + if (dst_node_perm[src_index] >= static_cast(src_node_perm.size())) { + return false; + } + if (static_cast(src_index) != src_node_perm[dst_node_perm[src_index]]) { + return false; + } + } + return true; +} + Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const OutDataAnchorPtr &src_out_anchor, NodePtr &dst_node, const InDataAnchorPtr &dst_in_anchor) { // Two transform nodes can be offset like A->T1->T2->B @@ -140,7 +226,18 @@ Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const GELOGE(FAILED, "Copy control edge from %s to %s failed.", src_node->GetName().c_str(), dst_node->GetName().c_str()); return ret; } - // 4.IsolateAndDelete T2, A will link to B automatically, and all control edge will also relink. + // 4.Add control edge from T1 other input to T2, like reshape second input + for (const auto &in_node : src_node->GetInDataNodes()) { + if (in_node->GetName() == pre_normal_node->GetName()) { + continue; + } + ret = GraphUtils::AddEdge(in_node->GetOutControlAnchor(), dst_node->GetInControlAnchor()); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Add control edge from %s to %s failed.", in_node->GetName().c_str(), dst_node->GetName().c_str()); + return ret; + } + } + // 5.IsolateAndDelete T2, A will link to B automatically, and all control edge will also relink. ret = IsolateAndDeleteNode(dst_node, {0}); if (ret != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", dst_node->GetName().c_str(), @@ -148,16 +245,16 @@ Status TransOpSymmetryEliminationPass::EliminateTransOp(NodePtr &src_node, const return ret; } GELOGI("Trans op symmetry eliminate successfully. Node %s has been removed.", dst_node->GetName().c_str()); - // 5.If T1 has no data out, isolate and deleted it. + // 6.If T1 has no data out, isolate and deleted it. if (src_node->GetOutDataNodesSize() == 0) { - // 5.1 Copy out control to pre normal node + // 6.1 Copy out control to pre normal node ret = GraphUtils::CopyOutCtrlEdges(src_node, pre_normal_node); if (ret != GRAPH_SUCCESS) { GELOGE(FAILED, "Copy control edge from %s to %s failed.", src_node->GetName().c_str(), dst_node->GetName().c_str()); return ret; } - // 5.2 Isolate and delete T1 + // 6.2 Isolate and delete T1 ret = IsolateAndDeleteNode(src_node, {}); if (ret != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", src_node->GetName().c_str(), diff --git a/src/ge/graph/passes/transop_symmetry_elimination_pass.h b/src/ge/graph/passes/transop_symmetry_elimination_pass.h index b0cff0c9..7f7409b7 100644 --- a/src/ge/graph/passes/transop_symmetry_elimination_pass.h +++ b/src/ge/graph/passes/transop_symmetry_elimination_pass.h @@ -43,6 +43,21 @@ class TransOpSymmetryEliminationPass : public BaseNodePass { /// static bool DescAreSymmetry(const NodePtr &src_node, const NodePtr &dst_node); + /// + /// get the number of unknown shape of node + /// @param node_desc: node to be checked + /// @return 0 , is not dynamic shape; UNKNOWN_DIM_NUM , all dims are unknown; n , n > 0 , has n dims unknown + /// + static int GetUnknownDimsNum(const GeTensorDesc &node_desc); + + /// + /// judge after two transposed op transform the raw data will be the same + /// @param src_node: first transposed op + /// @param dst_node: second transposed op + /// @return True or False, same or not + /// + static bool JudgeTransposeDBack2Raw(const NodePtr &src_node, const NodePtr &dst_node); + /// /// two transform nodes can not be offset if there is precision loss, like FP32->BOOL BOOL->FP32. /// keep this pair of transform nodes if it has precision loss. diff --git a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc index ba4cd031..1d97d9a1 100644 --- a/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc +++ b/src/ge/graph/passes/transop_without_reshape_fusion_pass.cc @@ -22,7 +22,6 @@ #include "common/ge/ge_util.h" #include "common/ge_inner_error_codes.h" #include "common/types.h" -#include "framework/common/debug/ge_log.h" #include "graph/compute_graph.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_tensor.h" @@ -733,7 +732,6 @@ void TransOpWithoutReshapeFusionPass::RemoveNousedNodes(const ComputeGraphPtr &g } graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { - GE_TIMESTAMP_START(TransOpWithoutReshapeFusionPass); GELOGI("[TransOpWithoutReshapeFusionPass]: optimize begin."); if (graph == nullptr) { return GRAPH_SUCCESS; @@ -786,7 +784,6 @@ graphStatus TransOpWithoutReshapeFusionPass::Run(ComputeGraphPtr graph) { } } GELOGI("[TransOpWithoutReshapeFusionPass]: Optimize end."); - GE_TIMESTAMP_END(TransOpWithoutReshapeFusionPass, "GraphManager::TransOpWithoutReshapeFusionPass"); return GRAPH_SUCCESS; } diff --git a/src/ge/graph/passes/transpose_transdata_pass.cc b/src/ge/graph/passes/transpose_transdata_pass.cc index 7ac7b7a3..3ac6dea5 100644 --- a/src/ge/graph/passes/transpose_transdata_pass.cc +++ b/src/ge/graph/passes/transpose_transdata_pass.cc @@ -135,7 +135,7 @@ Status TransposeTransDataPass::RemoveTranspose(NodePtr &node) { GE_CHECK_NOTNULL(anchor); anchor->UnlinkAll(); } - AddNodeDeleted(node.get()); + AddNodeDeleted(node); if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { GELOGE(FAILED, "[%s] RemoveNodeWithoutRelink failed.", node->GetName().c_str()); return FAILED; diff --git a/src/ge/graph/passes/var_is_initialized_op_pass.cc b/src/ge/graph/passes/var_is_initialized_op_pass.cc index c88db80c..73456a7b 100644 --- a/src/ge/graph/passes/var_is_initialized_op_pass.cc +++ b/src/ge/graph/passes/var_is_initialized_op_pass.cc @@ -191,7 +191,7 @@ Status VarIsInitializedOpPass::ChangeNodeToConstant(NodePtr &node, bool inited) AddRePassNodesWithInOut(const_node); // delete VarIsInitializedOp node from the graph - AddNodeDeleted(node.get()); + AddNodeDeleted(node); return SUCCESS; } diff --git a/src/ge/graph/passes/variable_op_pass.cc b/src/ge/graph/passes/variable_op_pass.cc index 175a049a..8c34cd36 100644 --- a/src/ge/graph/passes/variable_op_pass.cc +++ b/src/ge/graph/passes/variable_op_pass.cc @@ -20,7 +20,6 @@ #include "common/formats/formats.h" #include "common/formats/utils/formats_trans_utils.h" -#include "framework/common/debug/ge_log.h" #include "graph/ge_context.h" #include "graph/graph.h" #include "graph/manager/graph_var_manager.h" @@ -115,7 +114,6 @@ bool IsTransSupport(const TransNodeInfo &trans_info) { } // namespace Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { - GE_TIMESTAMP_START(VariableOpPass); if (graph == nullptr) { GELOGE(INTERNAL_ERROR, "Failed to run variable op pass, null graph"); return INTERNAL_ERROR; @@ -190,9 +188,15 @@ Status VariableOpPass::Run(ge::ComputeGraphPtr graph) { if (UpdateIOFormatInfo(end_iter->output, node_set) != SUCCESS) { return GE_GRAPH_VARIABLE_OP_PASS_FAILED; } + + // renew var desc if the trans_road is all reshape or reformat + ret = RenewVarDesc(graph->GetSessionID(), node, fusion_road); + if (ret != SUCCESS) { + GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str()); + return FAILED; + } } - GE_TIMESTAMP_END(VariableOpPass, "GraphManager::VariableOpPass"); return SUCCESS; } @@ -604,4 +608,28 @@ Status VariableOpPass::RenewVarDesc(ge::ComputeGraphPtr &graph) { } return SUCCESS; } + +Status VariableOpPass::RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road) { + // renew var desc if the trans_road is all reshape or reformat + for (auto &road : fusion_road) { + if (road.node_type != RESHAPE && road.node_type != REFORMAT) { + return SUCCESS; + } + } + + if (!ge::VarManager::Instance(session_id)->IsVarExist(node->GetName())) { + GELOGD("var manager does not exist var node[%s]", node->GetName().c_str()); + return SUCCESS; + } + GELOGD("var manager exist var node[%s]", node->GetName().c_str()); + GE_CHECK_NOTNULL(node->GetOpDesc()); + Status ret = ge::VarManager::Instance(session_id)->RenewCurVarDesc(node->GetName(), node->GetOpDesc()); + if (ret != SUCCESS) { + GELOGE(FAILED, "var manager renew var[%s] descriptor failed!", node->GetName().c_str()); + return FAILED; + } + + return SUCCESS; +} + } // namespace ge diff --git a/src/ge/graph/passes/variable_op_pass.h b/src/ge/graph/passes/variable_op_pass.h index 4e194a0c..e17980e9 100644 --- a/src/ge/graph/passes/variable_op_pass.h +++ b/src/ge/graph/passes/variable_op_pass.h @@ -66,6 +66,7 @@ class VariableOpPass : public GraphPass { Status UpdateIOFormatInfo(const GeTensorDesc &final_output, std::set &nodes); Status RenewVarDesc(ge::ComputeGraphPtr &graph); + Status RenewVarDesc(uint64_t session_id, const NodePtr &node, const VarTransRoad &fusion_road); std::map> var_and_var_ref_map_; diff --git a/src/ge/graph/passes/variable_prepare_op_pass.cc b/src/ge/graph/passes/variable_prepare_op_pass.cc index 4db78a46..d93e1003 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.cc +++ b/src/ge/graph/passes/variable_prepare_op_pass.cc @@ -30,6 +30,7 @@ namespace ge { std::map> VariablePrepareOpPass::ref_node_without_prototype_map_{ {REFSWITCH, {{0, 0}, {0, 1}}}}; + Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { @@ -62,7 +63,6 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { GELOGI("{ %d : %d }", index_iter->first, index_iter->second); } } - return SUCCESS; } @@ -73,10 +73,13 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { GE_CHECK_NOTNULL(dst_node); InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; GE_CHECK_NOTNULL(dst_in_data_anchor); - int out_index = GetWritableNodeOutIndex(dst_node, dst_in_data_anchor->GetIdx()); + auto input_index = dst_in_data_anchor->GetIdx(); + int out_index = GetWritableNodeOutIndex(dst_node, input_index); if (out_index >= 0) { - Status ret = DealWritableNode(dst_node, var_node, out_index); + Status ret = DealWritableNode(dst_node, input_index, var_node); if (ret != SUCCESS) { + GELOGE(FAILED, "Deal writable node[%s] failed, input index: %d, var: %s.", dst_node->GetName().c_str(), + input_index, var_node->GetName().c_str()); return FAILED; } } @@ -84,84 +87,97 @@ Status VariablePrepareOpPass::DealVariableNode(NodePtr &var_node) { return SUCCESS; } -Status VariablePrepareOpPass::DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index) { - GE_CHECK_NOTNULL(writable_node); - GE_CHECK_NOTNULL(var_node); - NodePtr final_writable_node = writable_node; - bool is_have_peer_node = false; - for (auto &dst_node_and_inanchor : writable_node->GetOutDataNodesAndAnchors()) { - NodePtr dst_node = dst_node_and_inanchor.first; - GE_CHECK_NOTNULL(dst_node); - InDataAnchorPtr dst_in_data_anchor = dst_node_and_inanchor.second; - GE_CHECK_NOTNULL(dst_in_data_anchor); - is_have_peer_node = true; - int current_out_index = GetWritableNodeOutIndex(dst_node, dst_in_data_anchor->GetIdx()); - if (current_out_index >= 0) { - final_writable_node = GetFinalWritableNode(dst_node, current_out_index); - out_index = current_out_index; - } - - GE_CHECK_NOTNULL(final_writable_node); - Status ret = AddVariableRef(final_writable_node, var_node, out_index); - if (ret != SUCCESS) { - GELOGE(FAILED, "add variable ref failed"); - return FAILED; +Status VariablePrepareOpPass::DealWritableNode(const ge::NodePtr &writable_node, int input_index, + const ge::NodePtr &var_node) { + // Find the last ref node: + // If the ref input has corresponding output, add variable ref after it. + // If the ref input has no corresponding output, insert RefIdentity and variable ref before it. + // If ref node with control output was found while finding the last ref node, add variable ref after it. + std::stack> nodes_to_check; + nodes_to_check.push({writable_node, input_index}); + while (!nodes_to_check.empty()) { + auto node_index = nodes_to_check.top(); + nodes_to_check.pop(); + auto cur_node = node_index.first; + int cur_input_index = node_index.second; + // Collect ref node after cur node + const auto nodes_size = nodes_to_check.size(); + // Add peer ref output node of current node to stack + CHECK_FALSE_EXEC(GetPeerNodeOfRefInput(cur_node, cur_input_index, nodes_to_check) == SUCCESS, + GELOGE(FAILED, "GetPeerNodeOfRefInput for node[%s] failed.", cur_node->GetName().c_str()); + return FAILED); + auto output_index = GetWritableNodeOutIndex(cur_node, cur_input_index); + CHECK_FALSE_EXEC(output_index >= 0, + GELOGE(FAILED, "Get writable node[%s] ref input[%d]'s corresponding out index failed: %d.", + cur_node->GetName().c_str(), cur_input_index, output_index); + return FAILED); + if (nodes_size == nodes_to_check.size()) { + const auto &op_desc = cur_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // No need to add variable_ref for frameworkop + if (op_desc->GetType() == FRAMEWORKOP) { + GELOGD("No need to add variable_ref for frameworkop"); + continue; + } + if (static_cast(output_index) < op_desc->GetOutputsSize()) { + // Add variable ref node after ref output for final ref node + CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, + GELOGE(FAILED, "Add variable ref failed"); + return FAILED); + } else { + // Insert variable ref node before ref input without corresponding ref output + CHECK_FALSE_EXEC(InsertVariableRef(cur_node, cur_input_index, var_node) == SUCCESS, + GELOGE(FAILED, "Insert variable ref and ref identity failed"); + return FAILED); + } + continue; } - } - if (final_writable_node->GetName() == writable_node->GetName() && !is_have_peer_node) { - Status ret = AddVariableRef(final_writable_node, var_node, out_index); - if (ret != SUCCESS) { - return FAILED; + if (HasControlOut(cur_node)) { + // Add variable ref node after ref output for ref node has control output. + CHECK_FALSE_EXEC(AddVariableRef(cur_node, var_node, output_index) == SUCCESS, + GELOGE(FAILED, "Add variable ref failed"); + return FAILED); } } return SUCCESS; } -NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index) { - NodePtr current_node = writable_node; - std::unordered_set seen_node; - while (true) { - if (seen_node.count(current_node.get())) { - GELOGE(FAILED, "There is a ring structure in the graph"); - return nullptr; - } - seen_node.insert(current_node.get()); - OutDataAnchorPtr out_anchor = current_node->GetOutDataAnchor(out_index); - if (out_anchor == nullptr) { - GELOGE(FAILED, "Failed to get data anchor by index %d", out_index); - return nullptr; - } - bool found_writeable_node = false; - auto peer_in_anchors = out_anchor->GetPeerInDataAnchors(); - for (auto &peer_in_anchor : peer_in_anchors) { - if (peer_in_anchor == nullptr) { - GELOGE(FAILED, "peer in data anchor is nullptr, node %s:%s", current_node->GetType().c_str(), - current_node->GetName().c_str()); - continue; - } - - NodePtr peer_node = peer_in_anchor->GetOwnerNode(); - int current_out_index = GetWritableNodeOutIndex(peer_node, peer_in_anchor->GetIdx()); - if (current_out_index >= 0) { - current_node = peer_node; - out_index = current_out_index; - found_writeable_node = true; - break; - } +Status VariablePrepareOpPass::GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, + std::stack> &nodes) { + auto output_index = GetWritableNodeOutIndex(node, input_index); + if (output_index == -1) { + GELOGE(PARAM_INVALID, "Node[%s] is not a ref node.", node->GetName().c_str()); + return PARAM_INVALID; + } + const auto &op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + if (static_cast(output_index) == op_desc->GetOutputsSize()) { + return SUCCESS; + } + if (output_index >= static_cast(node->GetAllOutDataAnchorsSize())) { + GELOGW("Can not get %d th output anchor of %s", output_index, node->GetName().c_str()); + return SUCCESS; + } + const auto &out_anchor = node->GetOutDataAnchor(output_index); + GE_CHECK_NOTNULL(out_anchor); + for (const auto &peer_in_anchor : out_anchor->GetPeerInDataAnchors()) { + auto peer_node = peer_in_anchor->GetOwnerNode(); + if (peer_node == nullptr) { + continue; } - if (!found_writeable_node) { - GELOGD("final writable node is %s", current_node->GetName().c_str()); - return current_node; + const int peer_in_index = peer_in_anchor->GetIdx(); + if (GetWritableNodeOutIndex(peer_node, peer_in_index) != -1) { + nodes.push({peer_node, peer_in_index}); } } + return SUCCESS; } -Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node, int index) { +Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, const ge::NodePtr &var_node, int index) { GE_CHECK_NOTNULL(final_writable_node); GE_CHECK_NOTNULL(var_node); - - if (final_writable_node->GetType() == FRAMEWORKOP) { - GELOGD("No need to add variable_ref for frameworkop"); + if (index >= static_cast(final_writable_node->GetAllOutDataAnchorsSize())) { + GELOGW("Can not get %d th output anchor of %s", index, final_writable_node->GetName().c_str()); return SUCCESS; } // Check for duplicate creation @@ -181,7 +197,8 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g // creat variable_ref std::stringstream variable_ref_name; variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index; - NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); + NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); + GE_CHECK_NOTNULL(variable_ref_node); Status ret_check = CheckStreamLabel(variable_ref_node, final_writable_node); if (ret_check != SUCCESS) { GELOGE(FAILED, "check stream lable failed"); @@ -189,23 +206,12 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g } GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str()); - GE_CHECK_NOTNULL(variable_ref_node); - // add control anchor between variable_ref and final peer node + // add control anchor between variable_ref and final peer node // variable_ref_node need to execute before other nodes - auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); - for (auto &final_writable_outAnchor : final_writable_outAnchors) { - GE_CHECK_NOTNULL(final_writable_outAnchor); - for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { - GE_CHECK_NOTNULL(final_writable_peerAnchor); - NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); - graphStatus ret = - ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); - if (ret != GRAPH_SUCCESS) { - GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed"); - return FAILED; - } - } - } + CHECK_FALSE_EXEC(AddControlEdge(final_writable_node, variable_ref_node) == SUCCESS, + GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed"); + return FAILED); + graphStatus ret = ge::GraphUtils::AddEdge(out_anchor, variable_ref_node->GetInDataAnchor(0)); if (ret != GRAPH_SUCCESS) { GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed"); @@ -214,7 +220,110 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g return SUCCESS; } -ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) { +Status VariablePrepareOpPass::InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(var_node); + // Check connection between two nodes + const auto in_anchor = node->GetInDataAnchor(in_index); + GE_CHECK_NOTNULL(in_anchor); + auto peer_out_anchor = in_anchor->GetPeerOutAnchor(); + GE_CHECK_NOTNULL(peer_out_anchor); + auto peer_in_node = peer_out_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_in_node); + + // Create ref_identity + std::stringstream ref_identity_name; + ref_identity_name << "RefIdentity_" << peer_in_node->GetName() << "_" << peer_out_anchor->GetIdx() << "_TO_" + << node->GetName() << "_" << in_index; + NodePtr ref_identity_node = CreateRefIdentity(ref_identity_name.str(), node, static_cast(in_index)); + GE_CHECK_NOTNULL(ref_identity_node); + + // Create variable_ref + std::stringstream variable_ref_name; + variable_ref_name << "_TO_" << node->GetName() << "_REF_" << in_index; + NodePtr variable_ref_node = CreateVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); + GE_CHECK_NOTNULL(variable_ref_node); + Status ret_check = CheckStreamLabel(variable_ref_node, node); + if (ret_check != SUCCESS) { + GELOGE(FAILED, "check stream lable failed"); + return FAILED; + } + + GELOGI("Insert variable_ref of [%s] between [%s] and [%s]", var_node->GetName().c_str(), + peer_in_node->GetName().c_str(), node->GetName().c_str()); + // add control anchor between variable_ref and node + // variable_ref_node need to execute before other nodes + CHECK_FALSE_EXEC(AddControlEdge(node, variable_ref_node) == SUCCESS, + GELOGE(FAILED, "Add control edges between variable ref node and output nodes of ref node failed"); + return FAILED); + + // Insert variable ref node between two nodes and remove the original edge. + CHECK_FALSE_EXEC(ge::GraphUtils::RemoveEdge(peer_out_anchor, in_anchor) == SUCCESS, + GELOGE(FAILED, "Remove edge between ref node and its peer node failed"); + return FAILED); + CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(peer_out_anchor, ref_identity_node->GetInDataAnchor(0)) == SUCCESS, + GELOGE(FAILED, "Add data edge between pre node and ref_identity failed"); + return FAILED); + CHECK_FALSE_EXEC(ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), in_anchor) == SUCCESS, + GELOGE(FAILED, "Add data edge between ref_identity and ref node failed"); + return FAILED); + + // Add edge from ref identity node to variable ref node. + CHECK_FALSE_EXEC( + ge::GraphUtils::AddEdge(ref_identity_node->GetOutDataAnchor(0), variable_ref_node->GetInDataAnchor(0)) == SUCCESS, + GELOGE(FAILED, "Add data edge between ref_identity and variable_ref failed"); + return FAILED); + CHECK_FALSE_EXEC( + ge::GraphUtils::AddEdge(node->GetOutControlAnchor(), variable_ref_node->GetInControlAnchor()) == SUCCESS, + GELOGE(FAILED, "Add control edge between ref_identity and variable_ref failed"); + return FAILED); + return SUCCESS; +} + +Status VariablePrepareOpPass::AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node) { + auto out_anchors = node->GetAllOutAnchors(); + for (auto &out_anchor : out_anchors) { + GE_CHECK_NOTNULL(out_anchor); + for (auto &peer_in_anchor : out_anchor->GetPeerAnchors()) { + GE_CHECK_NOTNULL(peer_in_anchor); + NodePtr peer_node = peer_in_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(peer_node); + CHECK_FALSE_EXEC( + ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()) == SUCCESS, + GELOGE(FAILED, "Add control edge between variable_ref and ref node's peer node failed"); + return FAILED); + } + } + return SUCCESS; +} + +ge::NodePtr VariablePrepareOpPass::CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, + uint32_t input_index) { + OpDescPtr op_desc = node->GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "opdesc is nullptr"); + return nullptr; + } + + OpDescPtr ref_identity_op_desc = MakeShared(ref_identity_name.c_str(), REFIDENTITY); + if (ref_identity_op_desc == nullptr) { + GELOGE(FAILED, "ref_identity op desc is nullptr"); + return nullptr; + } + + GE_IF_BOOL_EXEC(ref_identity_op_desc->AddOutputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS, + GELOGW("add output desc edge failed"); + return nullptr); + GE_IF_BOOL_EXEC(ref_identity_op_desc->AddInputDesc(op_desc->GetInputDesc(input_index)) != SUCCESS, + GELOGW("add input desc edge failed"); + return nullptr); + NodePtr ref_identity_node = node->GetOwnerComputeGraph()->AddNode(ref_identity_op_desc); + GE_IF_BOOL_EXEC(ref_identity_node == nullptr, GELOGW("ref_identity_node is null"); return nullptr); + return ref_identity_node; +} + +ge::NodePtr VariablePrepareOpPass::CreateVariableRef(const std::string &variable_ref_name, + const ge::NodePtr &var_node) { OpDescPtr var_op_desc = var_node->GetOpDesc(); if (var_op_desc == nullptr) { GELOGE(FAILED, "get var opdesc is nullptr"); @@ -250,7 +359,6 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu } GELOGD("get writable node and input index %s:%d", node->GetName().c_str(), input_index); auto node_type = node->GetType(); - if (node_type == FRAMEWORKOP) { std::string original_type; GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); @@ -266,25 +374,17 @@ void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node GELOGW("op_desc in null, please check node:[%s]", node->GetName().c_str()); return; } - for (const auto &out_ancohor : node->GetAllOutDataAnchors()) { - int output_index = out_ancohor->GetIdx(); - string output_name = op_desc->GetOutputNameByIndex(output_index); - GELOGD("output name:[%s]", output_name.c_str()); - - int input_index = op_desc->GetInputIndexByName(output_name); - if (input_index == -1) { + for (const auto &name_index : op_desc->GetAllInputName()) { + // Record the index of output with the same name as input, thinking of them as a pair of ref input and output. + const int out_index = op_desc->GetOutputIndexByName(name_index.first); + if (out_index != -1) { + ref_input_output_map_[node->GetType()][name_index.second] = out_index; continue; } - auto ref_type_and_input_output_iter = ref_input_output_map_.find(node->GetType()); - if (ref_type_and_input_output_iter != ref_input_output_map_.end()) { - auto &input_output_index_map = ref_type_and_input_output_iter->second; - if (input_output_index_map.find(input_index) == input_output_index_map.end()) { - input_output_index_map.emplace(input_index, output_index); - GELOGD("Add RefInputOutputMap %s:{ %d, %d }", node->GetType().c_str(), input_index, output_index); - } - } else { - ref_input_output_map_.insert({node->GetType(), {{input_index, output_index}}}); - GELOGD("Create RefInputOutputMap { %s:{ %d, %d } }", node->GetType().c_str(), input_index, output_index); + // Record the ref input without corresponding output. + const auto &input_desc = op_desc->GetInputDesc(name_index.second); + if (!input_desc.GetRefPortIndex().empty()) { + ref_input_output_map_[node->GetType()][name_index.second] = static_cast(op_desc->GetOutputsSize()); } } } @@ -317,4 +417,15 @@ Status VariablePrepareOpPass::CheckStreamLabel(const ge::NodePtr &var_ref_node, } return SUCCESS; } + +bool VariablePrepareOpPass::HasControlOut(const ge::NodePtr &node) { + const auto &out_control_anchor = node->GetOutControlAnchor(); + for (const auto &peer_in_control_anchor : out_control_anchor->GetPeerInControlAnchors()) { + if (peer_in_control_anchor == nullptr || peer_in_control_anchor->GetOwnerNode() == nullptr) { + continue; + } + return true; + } + return false; +} } // namespace ge diff --git a/src/ge/graph/passes/variable_prepare_op_pass.h b/src/ge/graph/passes/variable_prepare_op_pass.h index c8b9883e..f024a464 100644 --- a/src/ge/graph/passes/variable_prepare_op_pass.h +++ b/src/ge/graph/passes/variable_prepare_op_pass.h @@ -18,6 +18,7 @@ #define GE_GRAPH_PASSES_VARIABLE_PREPARE_OP_PASS_H_ #include +#include #include #include "framework/common/ge_inner_error_codes.h" @@ -30,15 +31,19 @@ class VariablePrepareOpPass : public GraphPass { private: Status DealVariableNode(ge::NodePtr &node); - Status DealWritableNode(ge::NodePtr &writable_node, ge::NodePtr &var_node, int out_index); - NodePtr GetFinalWritableNode(ge::NodePtr &writable_node, int &out_index); - Status AddVariableRef(ge::NodePtr &node, ge::NodePtr &var_node, int index); - NodePtr CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node); + Status DealWritableNode(const ge::NodePtr &writable_node, int input_index, const ge::NodePtr &var_node); + Status GetPeerNodeOfRefInput(const ge::NodePtr &node, int input_index, std::stack> &nodes); + Status AddVariableRef(ge::NodePtr &node, const ge::NodePtr &var_node, int index); + Status InsertVariableRef(ge::NodePtr &node, int in_index, const ge::NodePtr &var_node); + Status AddControlEdge(const ge::NodePtr &node, const ge::NodePtr &variable_ref_node); + NodePtr CreateVariableRef(const std::string &variable_ref_name, const ge::NodePtr &var_node); + NodePtr CreateRefIdentity(const std::string &ref_identity_name, const ge::NodePtr &node, uint32_t input_index); int GetWritableNodeOutIndex(const NodePtr &node, int input_index); void GenerateRefTypeAndInputOutputMap(const NodePtr &node); int FindRefOutIndex(const std::string &node_type, int input_index, const std::map> &ref_map); Status CheckStreamLabel(const ge::NodePtr &var_ref_node, const ge::NodePtr &final_writable_node); + bool HasControlOut(const ge::NodePtr &node); std::map> ref_input_output_map_; static std::map> ref_node_without_prototype_map_; diff --git a/src/ge/graph/passes/variable_ref_delete_op_pass.cc b/src/ge/graph/passes/variable_ref_delete_op_pass.cc index cd5b9fe9..3487df47 100644 --- a/src/ge/graph/passes/variable_ref_delete_op_pass.cc +++ b/src/ge/graph/passes/variable_ref_delete_op_pass.cc @@ -16,18 +16,16 @@ #include "graph/passes/variable_ref_delete_op_pass.h" #include -#include "framework/common/debug/ge_log.h" namespace ge { Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { - GE_TIMESTAMP_START(VariableRefDeleteOpPass); GE_CHECK_NOTNULL(graph); - - for (auto &node : graph->GetDirectNode()) { - GELOGD("before VariableRefDeleteOpPass, graph has node: %s, and node name: %s", node->GetType().c_str(), - node->GetName().c_str()); + std::set all_var_names; + auto root_graph = GraphUtils::FindRootGraph(graph); + GE_CHECK_NOTNULL(root_graph); + for (const auto &n : root_graph->GetAllNodes()) { + all_var_names.insert(n->GetName()); } - for (auto &node : graph->GetDirectNode()) { GE_CHECK_NOTNULL(node->GetOpDesc()); std::string ref_var_src_var_name; @@ -36,19 +34,17 @@ Status VariableRefDeleteOpPass::Run(ge::ComputeGraphPtr graph) { if (!is_variable_ref) { continue; } + if (all_var_names.count(ref_var_src_var_name) == 0) { + GELOGE(FAILED, "Can not find source variable[%s] of variable ref[%s]", ref_var_src_var_name.c_str(), + node->GetName().c_str()); + return FAILED; + } Status ret = DealVariableRef(graph, node, ref_var_src_var_name); if (ret != SUCCESS) { GELOGE(ret, "variable ref [%s] delete failed", node->GetName().c_str()); return FAILED; } } - - for (auto &node : graph->GetDirectNode()) { - GELOGD("after VariableRefDeleteOpPass, graph has node: %s, and node name: %s", node->GetType().c_str(), - node->GetName().c_str()); - } - GE_TIMESTAMP_END(VariableRefDeleteOpPass, "GraphManager::VariableRefDeleteOpPass"); - return SUCCESS; } @@ -68,23 +64,15 @@ Status VariableRefDeleteOpPass::DealVariableRef(ge::ComputeGraphPtr &graph, ge:: // get previous node of variable_ref NodePtr peer_node = inAnchor0->GetPeerOutAnchor()->GetOwnerNode(); - // add attr [REF_VAR_SRC_VAR_NAME] to the previous node of the variable_ref - GE_CHECK_NOTNULL(peer_node->GetOpDesc()); - bool is_set_str = ge::AttrUtils::SetStr(peer_node->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); - - ge::NodePtr ref_var_src_var = GraphUtils::FindNodeFromAllNodes(graph, ref_var_src_var_name); - if (ref_var_src_var == nullptr) { - GELOGE(FAILED, "get ref_var_src_var failed"); - return FAILED; + // add attr [REF_VAR_SRC_VAR_NAME] to the previous op output desc of the variable_ref + auto op_desc = peer_node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + auto out_desc = op_desc->MutableOutputDesc(static_cast(index)); + bool is_set_str = ge::AttrUtils::SetStr(out_desc, REF_VAR_SRC_VAR_NAME, ref_var_src_var_name); + if (is_set_str) { + GELOGI("[%s-%d]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), index, + ref_var_src_var_name.c_str()); } - - GE_CHECK_NOTNULL(ref_var_src_var->GetOpDesc()); - bool is_set_index = ge::AttrUtils::SetInt(ref_var_src_var->GetOpDesc(), REF_VAR_PRE_PEER_OUT_INDEX, index); - if (is_set_str && is_set_index) { - GELOGI("[%s]: add attr [REF_VAR_SRC_VAR_NAME: %s ] ", peer_node->GetName().c_str(), ref_var_src_var_name.c_str()); - GELOGI("[%s]: add attr [REF_VAR_PRE_PEER_OUT_INDEX: %d]", ref_var_src_var->GetName().c_str(), index); - } - // remove variable_ref if (GraphUtils::IsolateNode(variable_ref, {0}) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Isolate removed node: %s, type: %s failed", variable_ref->GetName().c_str(), diff --git a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc index bd153184..1321cf20 100644 --- a/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc +++ b/src/ge/graph/passes/variable_ref_useless_control_out_delete_pass.cc @@ -17,7 +17,6 @@ #include "variable_ref_useless_control_out_delete_pass.h" namespace ge { - Status VariableRefUselessControlOutDeletePass::Run(ge::ComputeGraphPtr graph) { GE_CHECK_NOTNULL(graph); for (const auto &node : graph->GetDirectNode()) { diff --git a/src/ge/graph/preprocess/graph_preprocess.cc b/src/ge/graph/preprocess/graph_preprocess.cc index 9c82a06d..3d0f1514 100644 --- a/src/ge/graph/preprocess/graph_preprocess.cc +++ b/src/ge/graph/preprocess/graph_preprocess.cc @@ -19,9 +19,12 @@ #include #include #include +#include "common/formats/format_transfers/format_transfer_fractal_nz.h" +#include "common/formats/format_transfers/format_transfer_fractal_z.h" #include "common/formats/format_transfers/format_transfer_nchw_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_nhwc_nc1hwc0.h" #include "common/formats/format_transfers/format_transfer_transpose.h" +#include "common/formats/utils/formats_trans_utils.h" #include "common/helper/model_helper.h" #include "common/math/math_util.h" #include "common/op/ge_op_utils.h" @@ -32,6 +35,7 @@ #include "graph/common/transop_util.h" #include "graph/debug/ge_attr_define.h" #include "graph/ge_context.h" +#include "graph/shape_refiner.h" #include "graph/manager/graph_var_manager.h" #include "graph/manager/util/rt_context_util.h" #include "graph/optimize/graph_optimize.h" @@ -80,7 +84,9 @@ #include "graph/passes/switch_dead_branch_elimination.h" #include "graph/passes/switch_fusion_pass.h" #include "graph/passes/switch_logic_remove_pass.h" -#include "graph/passes/switch_op_pass.h" +#include "graph/passes/merge_to_stream_merge_pass.h" +#include "graph/passes/switch_to_stream_switch_pass.h" +#include "graph/passes/attach_stream_label_pass.h" #include "graph/passes/switch_split_pass.h" #include "graph/passes/unused_const_pass.h" #include "graph/passes/unused_op_remove_pass.h" @@ -96,7 +102,6 @@ #include "runtime/dev.h" #include "graph/passes/dimension_adjust_pass.h" -#include "graph/passes/identify_reference_pass.h" #include "graph/passes/link_gen_mask_nodes_pass.h" #include "graph/passes/permute_pass.h" #include "graph/passes/reshape_remove_pass.h" @@ -134,14 +139,14 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { auto dim_cnt = static_cast(dst_ge_shape.GetDimNum()); if (dim_cnt == 0) { // if the dim_cnt is 0, the tensor is a scalar tensor->MutableTensorDesc().SetShape(GeShape()); - int64_t dst_shape = 1; - if (tensor->SetData(reinterpret_cast(&dst_shape), sizeof(int64_t)) != GRAPH_SUCCESS) { + int32_t dst_shape = 1; + if (tensor->SetData(reinterpret_cast(&dst_shape), sizeof(int32_t)) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "tensor set data failed"); return nullptr; } } else { tensor->MutableTensorDesc().SetShape(GeShape(std::vector({dim_cnt}))); - unique_ptr dst_shape(new (std::nothrow) int64_t[dim_cnt]()); + unique_ptr dst_shape(new (std::nothrow) int32_t[dim_cnt]()); if (dst_shape == nullptr) { GELOGE(INTERNAL_ERROR, "Create unique ptr failed"); return nullptr; @@ -151,7 +156,7 @@ OpDescPtr CreateTensorShape(const GeTensorDesc &data_tensor) { } GE_IF_BOOL_EXEC( - tensor->SetData(reinterpret_cast(dst_shape.get()), dim_cnt * sizeof(int64_t)) != GRAPH_SUCCESS, + tensor->SetData(reinterpret_cast(dst_shape.get()), dim_cnt * sizeof(int32_t)) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "tensor set data failed"); return nullptr;) } @@ -451,135 +456,6 @@ VarNamesToRefs CollectVarNamesToRefs(const ComputeGraphPtr &graph) { } return names_to_refs; } -Status AddTransNodeBetweenTwoNodes(OutDataAnchorPtr &src_out, InDataAnchorPtr &insert_in, - OutDataAnchorPtr &insert_out) { - if ((src_out == nullptr) || (insert_in == nullptr) || (insert_out == nullptr)) { - GELOGE(INTERNAL_ERROR, "anchor is nullptr"); - return FAILED; - } - auto vistor = src_out->GetPeerInDataAnchors(); - for (auto it = vistor.begin(); it != vistor.end(); ++it) { - InDataAnchorPtr dst_in = *it; - GE_CHK_STATUS_RET(src_out->Unlink(dst_in), "Unlink the anchor failed"); - GE_CHK_STATUS_RET(insert_out->LinkTo(dst_in), "Link the anchor failed"); - } - GE_CHK_STATUS_RET(src_out->LinkTo(insert_in), "Link the anchor failed"); - return SUCCESS; -} - -NodePtr CreateCastOp(const ge::GeShape &shape, const ge::DataType input_data_type, const ge::DataType output_data_type, - const ge::Format format, NodePtr &node) { - static uint32_t transop_count = 0; - std::string name = std::string("cast_node").append(std::to_string(transop_count++)); - - GELOGI("create cast op:%s, input datatype:%s, out datatype:%s.", name.c_str(), - TypeUtils::DataTypeToSerialString(input_data_type).c_str(), - TypeUtils::DataTypeToSerialString(output_data_type).c_str()); - GeTensorDesc input(shape, format, input_data_type); - input.SetOriginFormat(format); - input.SetOriginShape(shape); - input.SetOriginDataType(input_data_type); - ge::TensorUtils::SetRealDimCnt(input, static_cast(shape.GetDims().size())); - - GeTensorDesc output(shape, format, output_data_type); - output.SetOriginFormat(format); - output.SetOriginShape(shape); - output.SetOriginDataType(output_data_type); - ge::TensorUtils::SetRealDimCnt(output, static_cast(shape.GetDims().size())); - - auto cast_node = CreateTransNode(name, CAST, input, output, node); - GELOGD("Create cast node success."); - return cast_node; -} - -Status ProcessInputFP16(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { - GE_CHECK_NOTNULL(node_ptr); - auto op_desc = node_ptr->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - const GeTensorDescPtr &input = op_desc->MutableInputDesc(0); - GE_CHECK_NOTNULL(input); - ge::DataType src_dtype = input->GetDataType(); - if (src_dtype == DT_FLOAT16) { - GELOGI("The node name, %s dtype is fp16", node_ptr->GetName().c_str()); - return SUCCESS; - } - input->SetDataType(DT_FLOAT16); - input->SetOriginDataType(DT_FLOAT16); - int64_t input_shape_size = 0; - int64_t output_shape_size = 0; - ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); - ge::graphStatus output_graph_status = ge::TensorUtils::GetTensorMemorySizeInBytes(*input, output_shape_size); - if (input_graph_status != ge::GRAPH_SUCCESS && output_graph_status != ge::GRAPH_SUCCESS) { - GELOGE(GRAPH_FAILED, "GetTensorSize failed!"); - return FAILED; - } - ge::TensorUtils::SetSize(*input, input_shape_size); - const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); - GE_CHECK_NOTNULL(output); - output->SetDataType(DT_FLOAT16); - output->SetOriginDataType(DT_FLOAT16); - ge::TensorUtils::SetSize(*output, output_shape_size); - - if (!is_dynamic_batch) { - NodePtr cast_node = CreateCastOp(output->GetShape(), DT_FLOAT16, src_dtype, output->GetFormat(), node_ptr); - GE_CHECK_NOTNULL(cast_node); - OutDataAnchorPtr src_out = node_ptr->GetOutDataAnchor(0); - InDataAnchorPtr cast_in = cast_node->GetInDataAnchor(0); - OutDataAnchorPtr cast_out = cast_node->GetOutDataAnchor(0); - if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", - node_ptr->GetName().c_str(), cast_node->GetName().c_str()); - return FAILED; - } - } else { - auto switchn_op_desc = switchn_node->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_op_desc); - const GeTensorDescPtr &switchn_input = switchn_op_desc->MutableInputDesc(0); - GE_CHECK_NOTNULL(switchn_input); - switchn_input->SetDataType(DT_FLOAT16); - switchn_input->SetOriginDataType(DT_FLOAT16); - for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { - const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); - GE_CHECK_NOTNULL(switchn_output); - switchn_output->SetDataType(DT_FLOAT16); - switchn_output->SetOriginDataType(DT_FLOAT16); - NodePtr cast_node = - CreateCastOp(switchn_output->GetShape(), DT_FLOAT16, src_dtype, switchn_output->GetFormat(), node_ptr); - GE_CHECK_NOTNULL(cast_node); - OutDataAnchorPtr src_out = switchn_node->GetOutDataAnchor(i); - InDataAnchorPtr cast_in = cast_node->GetInDataAnchor(0); - OutDataAnchorPtr cast_out = cast_node->GetOutDataAnchor(0); - if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", - switchn_node->GetName().c_str(), cast_node->GetName().c_str()); - return FAILED; - } - } - } - return SUCCESS; -} - -NodePtr CreateTransdataNode(const ge::GeShape &in_shape, const ge::Format input_format, const ge::GeShape &out_shape, - const ge::Format output_format, const ge::DataType dt, NodePtr &node) { - static uint32_t transop_count = 0; - // Does not involve multithreading. - std::string name = std::string("transdata_node").append(std::to_string(transop_count++)); - - GELOGI("create trandata op:%s, input format:%s, out format:%s.", name.c_str(), - TypeUtils::FormatToSerialString(input_format).c_str(), TypeUtils::FormatToSerialString(output_format).c_str()); - - GeTensorDesc input(in_shape, input_format, dt); - input.SetOriginFormat(input_format); - input.SetOriginShape(in_shape); - input.SetOriginDataType(dt); - - GeTensorDesc output(out_shape, output_format, dt); - output.SetOriginFormat(output_format); - output.SetOriginShape(out_shape); - output.SetOriginDataType(dt); - - return CreateTransNode(name, TRANSDATA, input, output, node); -} Status TransferShape2NC1HWC0(Format src_format, const std::vector &src_shape, DataType dt, Format dst_format, std::vector &dst_shape) { @@ -649,68 +525,40 @@ Status ModifyFormatAndShapeForSingleTensor(const GeTensorDescPtr &input_output) return SUCCESS; } -Status ProcessInputNC1HWC0(NodePtr &node_ptr, bool &is_dynamic_batch, NodePtr &switchn_node) { - GE_CHECK_NOTNULL(node_ptr); - auto op_desc = node_ptr->GetOpDesc(); +Status ModifyDataNetOutputFormatAndShape(OpDescPtr &op_desc, uint32_t index, Format storage_format, + vector &dst_shape_dims) { GE_CHECK_NOTNULL(op_desc); - const GeTensorDescPtr &input = op_desc->MutableInputDesc(0); + const GeTensorDescPtr &input = op_desc->MutableInputDesc(index); GE_CHECK_NOTNULL(input); ge::Format old_format = input->GetFormat(); - ge::GeShape old_shape = input->GetShape(); - bool support = ((old_format == FORMAT_NC1HWC0) || (old_format == FORMAT_NCHW) || (old_format == FORMAT_NHWC)); - if (!support) { - GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); - return FAILED; - } - if (old_format == FORMAT_NC1HWC0) { - GELOGI("No need to transfer format"); - return SUCCESS; - } - if (ModifyInputFormatAndShape(node_ptr) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "modify format and shape failed"); - return FAILED; - } - if (!is_dynamic_batch) { - NodePtr trans_node = - CreateTransdataNode(input->GetShape(), FORMAT_NC1HWC0, old_shape, old_format, input->GetDataType(), node_ptr); - GE_CHECK_NOTNULL(trans_node); - OutDataAnchorPtr src_out = node_ptr->GetOutDataAnchor(0); - InDataAnchorPtr trans_in = trans_node->GetInDataAnchor(0); - OutDataAnchorPtr trans_out = trans_node->GetOutDataAnchor(0); - if (AddTransNodeBetweenTwoNodes(src_out, trans_in, trans_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "add node between two nodes failed"); - return FAILED; - } - } else { - auto switchn_op_desc = switchn_node->GetOpDesc(); - GE_CHECK_NOTNULL(switchn_op_desc); - const GeTensorDescPtr &switchn_input = switchn_op_desc->MutableInputDesc(0); - if (ModifyFormatAndShapeForSingleTensor(switchn_input) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "modify format and shape failed"); + std::vector old_shape = input->GetShape().GetDims(); + + input->SetShape(ge::GeShape(dst_shape_dims)); + input->SetFormat(storage_format); + + auto output = op_desc->MutableOutputDesc(index); + GE_CHECK_NOTNULL(output); + output->SetShape(ge::GeShape(dst_shape_dims)); + output->SetFormat(storage_format); + + if (!output->MutableShape().IsUnknownShape()) { + int64_t size = 0; + graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(*output, size); + if (graph_status != ge::GRAPH_SUCCESS) { + GELOGE(graph_status, "GetTensorSizeInBytes failed!"); return FAILED; } - for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { - const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); - GE_CHECK_NOTNULL(switchn_output); - old_format = switchn_output->GetFormat(); - old_shape = switchn_output->GetShape(); - if (ModifyFormatAndShapeForSingleTensor(switchn_output) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "modify format and shape failed"); - return FAILED; - } - NodePtr trans_node = CreateTransdataNode(switchn_output->GetShape(), FORMAT_NC1HWC0, old_shape, old_format, - switchn_output->GetDataType(), node_ptr); - GE_CHECK_NOTNULL(trans_node); - OutDataAnchorPtr src_out = switchn_node->GetOutDataAnchor(i); - InDataAnchorPtr cast_in = trans_node->GetInDataAnchor(0); - OutDataAnchorPtr cast_out = trans_node->GetOutDataAnchor(0); - if (AddTransNodeBetweenTwoNodes(src_out, cast_in, cast_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "add node between two nodes failed, src name:%s, cast node name:%s.", - switchn_node->GetName().c_str(), trans_node->GetName().c_str()); - return FAILED; - } - } + ge::TensorUtils::SetSize(*input, size); + ge::TensorUtils::SetSize(*output, size); + + GELOGI( + "Modify Data NetOutput format and shape success, node:%s, index:%d, old_shape:%s, old_Format:%s, " + "new_shape:%s, new_format:%s, new_size:%lu", + op_desc->GetName().c_str(), index, formats::JoinToString(old_shape).c_str(), + ge::TypeUtils::FormatToSerialString(old_format).c_str(), formats::JoinToString(dst_shape_dims).c_str(), + ge::TypeUtils::FormatToSerialString(storage_format).c_str(), size); } + return SUCCESS; } @@ -739,44 +587,6 @@ Status CheckIfDynamicBatchScene(NodePtr &data_node, bool &is_dynamic_batch, Node return SUCCESS; } -Status ProcessDataNode(NodePtr &node_ptr) { - bool set_fp16 = false; - if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_fp16", set_fp16) || !set_fp16) { - return SUCCESS; - } - for (auto const &next_node : node_ptr->GetOutNodes()) { - if (next_node->GetType() == AIPP) { - GELOGE(INTERNAL_ERROR, - "This input node [%s] is linked to aipp, can not be set to fp16," - "please check your atc parma insert_op_conf, input_fp16_nodes.", - node_ptr->GetName().c_str()); - return FAILED; - } - } - GELOGI("input_fp16 is found, the node name is %s.", node_ptr->GetName().c_str()); - bool is_dynamic_batch = false; - NodePtr switchn_node = nullptr; - if (CheckIfDynamicBatchScene(node_ptr, is_dynamic_batch, switchn_node)) { - GELOGE(INTERNAL_ERROR, "CheckIfDynamicBatchScene failed"); - return FAILED; - } - if (ProcessInputFP16(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "ProcessInputFP16 failed"); - return FAILED; - } - // check if need to set format - bool set_format = false; - if (!ge::AttrUtils::GetBool(node_ptr->GetOpDesc(), "input_set_nc1hwc0", set_format) || !set_format) { - return SUCCESS; - } - GELOGI("The format of node [%s] should be set NC1HWC0.", node_ptr->GetName().c_str()); - if (ProcessInputNC1HWC0(node_ptr, is_dynamic_batch, switchn_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "ProcessInputNC1HWC0 failed"); - return FAILED; - } - return SUCCESS; -} - bool CheckIfSetOutputType(std::string &output_type, ge::DataType &output_data_type) { if (output_type_str_to_datatype.find(output_type) != output_type_str_to_datatype.end()) { output_data_type = output_type_str_to_datatype[output_type]; @@ -794,221 +604,6 @@ bool CheckOpType(const NodePtr &node, const std::string type) { return false; } -Status ProcessFp16Nc1hwc0Dynamic(const OpDescPtr &src_op_desc, NodePtr &node) { - auto merge_out = src_op_desc->MutableOutputDesc(0); - GE_CHECK_NOTNULL(merge_out); - if (ModifyFormatAndShapeForSingleTensor(merge_out) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "modify format and shape failed"); - return FAILED; - } - for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { - auto merge_in = src_op_desc->MutableInputDesc(i); - GE_CHECK_NOTNULL(merge_in); - ge::Format old_format = merge_in->GetFormat(); - ge::GeShape old_shape = merge_in->GetShape(); - if (ModifyFormatAndShapeForSingleTensor(merge_in) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "modify format and shape failed"); - return FAILED; - } - ge::GeShape new_shape = merge_in->GetShape(); - NodePtr trans_node = CreateTransdataNode(old_shape, old_format, new_shape, FORMAT_NC1HWC0, DT_FLOAT16, node); - GE_CHECK_NOTNULL(trans_node); - const InDataAnchorPtr &dst_in_anchor = node->GetInDataAnchor(i); - GE_CHECK_NOTNULL(dst_in_anchor); - const OutDataAnchorPtr &src_out_anchor = dst_in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - if (GraphUtils::InsertNodeBetweenDataAnchors(src_out_anchor, dst_in_anchor, trans_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return FAILED; - } - } - return SUCCESS; -} - -Status ProcessNetoutputNodeFp16Nc1hwc0(GeTensorDesc &src_desc, const InDataAnchorPtr &in_anchor, - GeTensorDescPtr &net_output_input_desc, NodePtr &node) { - bool is_dynamic = CheckOpType(node, MERGE); - auto src_op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(src_op_desc); - ge::GeShape src_shape = src_desc.GetShape(); - ge::Format src_format = src_desc.GetFormat(); - ge::DataType src_dtype = src_desc.GetDataType(); - if (src_dtype != DT_FLOAT16) { - if (!is_dynamic) { - auto peer_out = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out); - NodePtr cast_node = CreateCastOp(src_shape, src_dtype, DT_FLOAT16, src_format, node); - GE_CHECK_NOTNULL(cast_node); - if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out, in_anchor, cast_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return FAILED; - } - } else { - // Update outputdesc - const GeTensorDescPtr &merge_output = src_op_desc->MutableOutputDesc(0); - GE_CHECK_NOTNULL(merge_output); - merge_output->SetDataType(DT_FLOAT16); - merge_output->SetOriginDataType(DT_FLOAT16); - // Update input - for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { - const GeTensorDescPtr &merge_input = src_op_desc->MutableInputDesc(i); - GE_CHECK_NOTNULL(merge_input); - src_shape = merge_input->GetShape(); - src_format = merge_input->GetFormat(); - src_dtype = merge_input->GetDataType(); - merge_input->SetDataType(DT_FLOAT16); - merge_input->SetOriginDataType(DT_FLOAT16); - const InDataAnchorPtr &dst_in_anchor = node->GetInDataAnchor(i); - const OutDataAnchorPtr &src_out_anchor = dst_in_anchor->GetPeerOutAnchor(); - NodePtr cast_node = CreateCastOp(src_shape, src_dtype, DT_FLOAT16, src_format, node); - if (GraphUtils::InsertNodeBetweenDataAnchors(src_out_anchor, dst_in_anchor, cast_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return FAILED; - } - } - } - net_output_input_desc->SetDataType(DT_FLOAT16); - net_output_input_desc->SetOriginDataType(DT_FLOAT16); - } - if (src_format == FORMAT_NC1HWC0) { - GELOGI("Format is NC1HWC0, no need to transfer"); - return SUCCESS; - } - std::vector dst_shape_dims; - std::vector src_shape_dims = src_shape.GetDims(); - if (TransferShape2NC1HWC0(src_format, src_shape_dims, DT_FLOAT16, FORMAT_NC1HWC0, dst_shape_dims) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Trans shape failed"); - return FAILED; - } - ge::GeShape dst_shape(dst_shape_dims); - net_output_input_desc->SetFormat(FORMAT_NC1HWC0); - net_output_input_desc->SetOriginFormat(FORMAT_NC1HWC0); - net_output_input_desc->SetShape(dst_shape); - net_output_input_desc->SetOriginShape(dst_shape); - if (!is_dynamic) { - NodePtr trans_node = CreateTransdataNode(src_shape, src_format, dst_shape, FORMAT_NC1HWC0, DT_FLOAT16, node); - GE_CHECK_NOTNULL(trans_node); - auto peer_out_new = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out_new); - if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out_new, in_anchor, trans_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return FAILED; - } - } else { - if (ProcessFp16Nc1hwc0Dynamic(src_op_desc, node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "ProcessFp16Nc1hwc0Dynamic failed"); - return FAILED; - } - } - return SUCCESS; -} - -Status ProcessOutputDynamic(const NodePtr &src_node, NodePtr &node, ge::DataType &output_data_type) { - OpDescPtr src_op_desc = src_node->GetOpDesc(); - const GeTensorDescPtr &merge_output = src_op_desc->MutableOutputDesc(0); - GE_CHECK_NOTNULL(merge_output); - merge_output->SetDataType(output_data_type); - merge_output->SetOriginDataType(output_data_type); - // Update input - for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { - const GeTensorDescPtr &merge_input = src_op_desc->MutableInputDesc(i); - GE_CHECK_NOTNULL(merge_input); - ge::GeShape src_shape = merge_input->GetShape(); - ge::Format src_format = merge_input->GetFormat(); - ge::DataType src_dtype = merge_input->GetDataType(); - merge_input->SetDataType(output_data_type); - merge_input->SetOriginDataType(output_data_type); - const InDataAnchorPtr &dst_in_anchor = src_node->GetInDataAnchor(i); - GE_CHECK_NOTNULL(dst_in_anchor); - const OutDataAnchorPtr &src_out_anchor = dst_in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - NodePtr cast_node = CreateCastOp(src_shape, src_dtype, output_data_type, src_format, node); - if (GraphUtils::InsertNodeBetweenDataAnchors(src_out_anchor, dst_in_anchor, cast_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return FAILED; - } - } - return SUCCESS; -} - -Status ProcessNetoutputNode(NodePtr &node, std::string &output_type) { - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - ge::DataType output_data_type = ge::DT_FLOAT; - bool is_set_output_type = CheckIfSetOutputType(output_type, output_data_type); - - for (const auto &in_anchor : node->GetAllInDataAnchors()) { - auto index = static_cast(in_anchor->GetIdx()); - auto peer_out = in_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(peer_out); - auto src_index = static_cast(peer_out->GetIdx()); - auto src_node = peer_out->GetOwnerNode(); - GE_CHECK_NOTNULL(src_node); - bool is_dynamic = CheckOpType(src_node, MERGE); - - OpDescPtr src_op_desc = src_node->GetOpDesc(); - GE_CHECK_NOTNULL(src_op_desc); - auto net_output_input_desc = op_desc->MutableInputDesc(index); - GE_CHECK_NOTNULL(net_output_input_desc); - - ge::GeShape src_shape = src_op_desc->GetOutputDesc(src_index).GetShape(); - ge::Format src_format = src_op_desc->GetOutputDesc(src_index).GetFormat(); - ge::DataType src_dtype = src_op_desc->GetOutputDesc(src_index).GetDataType(); - // Update datatype - if (is_set_output_type) { - GELOGI("Enter into process output_type schedule"); - if (src_dtype == output_data_type) { - GELOGI("Data type is same ,no need to transfer."); - continue; - } - if (!is_dynamic) { - NodePtr cast_node = CreateCastOp(src_shape, src_dtype, output_data_type, src_format, node); - if (GraphUtils::InsertNodeBetweenDataAnchors(peer_out, in_anchor, cast_node) != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "InsertNodeBetweenDataAnchors failed"); - return FAILED; - } - } else { - // Update outputdesc - if (ProcessOutputDynamic(src_node, node, output_data_type) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "ProcessOutputDynamic failed"); - return FAILED; - } - } - net_output_input_desc->SetDataType(output_data_type); - net_output_input_desc->SetOriginDataType(output_data_type); - continue; - } - // output_node is not set,check if is_output_adjust_hw_layout is set - bool set_fp16_nc1hwc0 = false; - if (!is_dynamic) { - (void)AttrUtils::GetBool(src_op_desc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); - } else { - // need check dynamic scene, graph structure: node->merge->netoutput - const InDataAnchorPtr &merge_input_anchor = src_node->GetInDataAnchor(0); - GE_CHECK_NOTNULL(merge_input_anchor); - const OutDataAnchorPtr &src_out_anchor = merge_input_anchor->GetPeerOutAnchor(); - GE_CHECK_NOTNULL(src_out_anchor); - auto src_merge_node = src_out_anchor->GetOwnerNode(); - GE_CHECK_NOTNULL(src_merge_node); - auto src_merge_node_opdesc = src_merge_node->GetOpDesc(); - (void)AttrUtils::GetBool(src_merge_node_opdesc, "output_set_fp16_nc1hwc0", set_fp16_nc1hwc0); - } - if (set_fp16_nc1hwc0) { - GELOGI("Node [%s] should be set FP16 and NC1HWC0", src_op_desc->GetName().c_str()); - if ((src_format != FORMAT_NCHW) && (src_format != FORMAT_NHWC) && (src_format != FORMAT_NC1HWC0)) { - GELOGE(INTERNAL_ERROR, "Format is not one of NCHW, NHWC, NC1HWC0."); - return FAILED; - } - GeTensorDesc src_desc(src_shape, src_format, src_dtype); - if (ProcessNetoutputNodeFp16Nc1hwc0(src_desc, in_anchor, net_output_input_desc, src_node) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Process netoutput fp16 nc1hwc0."); - return FAILED; - } - } - } - return SUCCESS; -} - Status CheckIfNeedSetNdFormat(const NodePtr &node_ptr) { auto op = node_ptr->GetOpDesc(); GE_CHECK_NOTNULL(op); @@ -1054,7 +649,6 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP return SUCCESS; } input->SetDataType(DT_FLOAT16); - input->SetOriginDataType(DT_FLOAT16); int64_t input_shape_size = 0; int64_t output_shape_size = 0; ge::graphStatus input_graph_status = ge::TensorUtils::GetTensorSizeInBytes(*input, input_shape_size); @@ -1067,7 +661,6 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP const GeTensorDescPtr &output = op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(output); output->SetDataType(DT_FLOAT16); - output->SetOriginDataType(DT_FLOAT16); ge::TensorUtils::SetSize(*output, output_shape_size); if (is_dynamic_batch) { GELOGI("The node [%s] dtype set fp16", switchn_node->GetName().c_str()); @@ -1076,12 +669,10 @@ Status ProcessInputFP16DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, NodeP auto switchn_input = switchn_op_desc->MutableInputDesc(0); GE_CHECK_NOTNULL(switchn_input); switchn_input->SetDataType(DT_FLOAT16); - switchn_input->SetOriginDataType(DT_FLOAT16); for (uint32_t i = 0; i < switchn_node->GetAllOutDataAnchorsSize(); ++i) { const GeTensorDescPtr &switchn_output = switchn_op_desc->MutableOutputDesc(i); GE_CHECK_NOTNULL(switchn_output); switchn_output->SetDataType(DT_FLOAT16); - switchn_output->SetOriginDataType(DT_FLOAT16); } } return SUCCESS; @@ -1100,10 +691,6 @@ Status ProcessInputNC1HWC0DynShape(NodePtr &node_ptr, bool &is_dynamic_batch, No GELOGE(INTERNAL_ERROR, "The format [%s] is unsupported", TypeUtils::FormatToSerialString(old_format).c_str()); return FAILED; } - if (old_format == FORMAT_NC1HWC0) { - GELOGI("No need to transfer format"); - return SUCCESS; - } if (ModifyInputFormatAndShape(node_ptr) != SUCCESS) { GELOGE(INTERNAL_ERROR, "modify format and shape failed"); return FAILED; @@ -1139,7 +726,7 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { } for (auto const &next_node : node_ptr->GetOutNodes()) { if (next_node->GetType() == AIPP) { - ErrorManager::GetInstance().ATCReportErrMessage("E10049", {"opname"}, {node_ptr->GetName()}); + ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"opname"}, {node_ptr->GetName()}); GELOGE(INTERNAL_ERROR, "This input op [%s] is linked to aipp, can not be set to fp16, " "please check your atc parameter --insert_op_conf, --input_fp16_nodes.", @@ -1171,6 +758,42 @@ Status ProcessDataNodeDynShape(NodePtr &node_ptr) { return SUCCESS; } +Status GetStorageFormatAndShape(OpDescPtr &op_desc, const GeTensorDescPtr &tensor_desc_ptr, Format &storage_format, + vector &dst_shape_dims) { + GE_CHECK_NOTNULL(op_desc); + GE_CHECK_NOTNULL(tensor_desc_ptr); + + storage_format = FORMAT_RESERVED; + int64_t format = FORMAT_RESERVED; + dst_shape_dims.clear(); + if (ge::AttrUtils::GetInt(*tensor_desc_ptr, ATTR_NAME_STORAGE_FORMAT, format)) { + storage_format = static_cast(format); + vector storage_shape; + if (ge::AttrUtils::GetListInt(*tensor_desc_ptr, ATTR_NAME_STORAGE_SHAPE, storage_shape)) { + for (auto dim : storage_shape) { + dst_shape_dims.push_back(static_cast(dim)); + } + GELOGI("Update node by storage format, node: [%s], storage_format: [%s], storage_shape:[%s]", + op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str(), + formats::JoinToString(storage_shape).c_str()); + } else { + GELOGE(PARAM_INVALID, + "Update node by storage format failed, storage_shape not set. " + "node: [%s], storage_format [%s]", + op_desc->GetName().c_str(), TypeUtils::FormatToSerialString(storage_format).c_str()); + return FAILED; + } + + ge::Format old_format = tensor_desc_ptr->GetFormat(); + auto old_shape = tensor_desc_ptr->GetShape().GetDims(); + if (old_format == storage_format && old_shape == dst_shape_dims) { + GELOGI("Update node by storage format, not changed."); + storage_format = FORMAT_RESERVED; + return SUCCESS; + } + } + return SUCCESS; +} Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorDescPtr &net_output_input_desc, NodePtr &node) { bool is_dynamic = CheckOpType(node, MERGE); @@ -1180,24 +803,16 @@ Status ProcessNetoutputNodeFp16Nc1hwc0DynShape(GeTensorDesc &src_desc, GeTensorD ge::Format src_format = src_desc.GetFormat(); net_output_input_desc->SetDataType(DT_FLOAT16); - net_output_input_desc->SetOriginDataType(DT_FLOAT16); if (is_dynamic) { auto merge_output = src_op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(merge_output); merge_output->SetDataType(DT_FLOAT16); - merge_output->SetOriginDataType(DT_FLOAT16); for (uint32_t i = 0; i < node->GetAllInDataAnchorsSize(); ++i) { auto merge_input = src_op_desc->MutableInputDesc(i); GE_CHECK_NOTNULL(merge_input); merge_input->SetDataType(DT_FLOAT16); - merge_input->SetOriginDataType(DT_FLOAT16); } } - - if (src_format == FORMAT_NC1HWC0) { - GELOGI("Format is NC1HWC0, no need to transfer"); - return SUCCESS; - } std::vector dst_shape_dims; std::vector src_shape_dims = src_shape.GetDims(); if (TransferShape2NC1HWC0(src_format, src_shape_dims, DT_FLOAT16, FORMAT_NC1HWC0, dst_shape_dims) != SUCCESS) { @@ -1291,17 +906,14 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { if (NeedUpdateOutputByOutputTypeParm(output_type, src_node, src_index, output_data_type)) { GELOGI("Enter into process output_type schedule"); net_output_input_desc->SetDataType(output_data_type); - net_output_input_desc->SetOriginDataType(output_data_type); if (is_dynamic) { auto merge_output = src_op_desc->MutableOutputDesc(0); GE_CHECK_NOTNULL(merge_output); merge_output->SetDataType(output_data_type); - merge_output->SetOriginDataType(output_data_type); for (uint32_t i = 0; i < src_node->GetAllInDataAnchorsSize(); ++i) { auto merge_input = src_op_desc->MutableInputDesc(i); GE_CHECK_NOTNULL(merge_input); merge_input->SetDataType(output_data_type); - merge_input->SetOriginDataType(output_data_type); } } continue; @@ -1337,7 +949,6 @@ Status ProcessNetoutputNodeDynShape(NodePtr &node, std::string &output_type) { } return SUCCESS; } - } // namespace GraphPrepare::GraphPrepare() : compute_graph_(nullptr) {} @@ -1431,6 +1042,8 @@ Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { if (compute_graph_ != nullptr) { compute_graph_->SetSessionID(session_id); } + session_id_ = session_id; + Status ret = CheckGraph(); if (ret != SUCCESS) { GELOGE(ret, "RunGraph graph check fail, ret:%u", ret); @@ -1442,7 +1055,6 @@ Status GraphPrepare::Init(const ge::Graph &graph, uint64_t session_id) { GELOGE(ret, "RunGraph check ref op fail, ret:%u", ret); return ret; } - return SUCCESS; } @@ -1467,13 +1079,13 @@ Status GraphPrepare::CheckGraph() { } Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &input_name, - const std::unordered_set &ref_nodes) { + const std::set &ref_nodes) { // Acceptable input types should be ref node, variable or Switch operator, which is issued by ME for dynamic - // lossscale and would be optimized in SwitchOpPass. Since ME dont differentiate between RefSwitch and Switch, - // and only issue Switch. - static std::unordered_set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, - ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, - ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; + // lossscale and would be optimized in SwitchToStreamSwitchPass. + // Since ME dont differentiate between RefSwitch and Switch, and only issue Switch. + static std::set acceptable_types = {ge::VARIABLE, ge::VARIABLEV2, ge::VARHANDLEOP, + ge::REFSWITCH, ge::REFMERGE, ge::REFENTER, + ge::REFNEXTITERATION, ge::REFEXIT, ge::SWITCH}; GE_CHECK_NOTNULL(node); const auto &op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); @@ -1499,7 +1111,6 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i } } bool is_acceptable = (acceptable_types.find(input_type) != acceptable_types.end()); - if (!is_acceptable) { GELOGE(PARAM_INVALID, "The ref input of ref node %s[%s] must be ref node or variable, but %s[%s]isn't.", node->GetName().c_str(), node->GetType().c_str(), input_op_desc->GetName().c_str(), @@ -1512,7 +1123,7 @@ Status GraphPrepare::CheckRefInputNode(const NodePtr &node, const std::string &i Status GraphPrepare::CheckRefOp() { GE_CHECK_NOTNULL(compute_graph_); - std::unordered_set ref_nodes; + std::set ref_nodes; for (const NodePtr &node : compute_graph_->GetDirectNode()) { if (node == nullptr) { GELOGE(PARAM_INVALID, "param [node] must not be null."); @@ -1524,20 +1135,15 @@ Status GraphPrepare::CheckRefOp() { return PARAM_INVALID; } - auto input_names = op_desc->GetAllInputNames(); + auto input_name_index = op_desc->GetAllInputName(); auto outputs = op_desc->GetAllOutputName(); - std::unordered_set all_output_name; - - for (auto &output : outputs) { - all_output_name.insert(output.first); - } - for (const auto &input_name : input_names) { - if (all_output_name.find(input_name) != all_output_name.end()) { - if (CheckRefInputNode(node, input_name, ref_nodes) != SUCCESS) { + for (const auto &name_index : input_name_index) { + if (op_desc->GetOutputIndexByName(name_index.first) != -1) { + if (CheckRefInputNode(node, name_index.first, ref_nodes) != SUCCESS) { GELOGE(PARAM_INVALID, "CheckRefInputNode failed."); return PARAM_INVALID; } - (void)ref_nodes.insert(node); + (void)ref_nodes.insert(node); // no need to check value } } } @@ -1548,7 +1154,7 @@ Status GraphPrepare::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode) { GELOGI("set rt_context %d, device id:%u.", static_cast(mode), ge::GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId())); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddrtContext(rt_context); + RtContextUtil::GetInstance().AddRtContext(session_id_, rt_context); return SUCCESS; } @@ -1566,6 +1172,8 @@ Status GraphPrepare::AdjustDataOpOutput(const NodePtr &node) { int64_t tensor_size = 0; graphStatus graph_status = TensorUtils::GetTensorMemorySizeInBytes(output, tensor_size); if (graph_status != GRAPH_SUCCESS) { + ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, + {"GetTensorMemorySizeInBytes", "opname is " + node->GetName()}); GELOGE(graph_status, "GetTensorMemorySizeInBytes failed!"); return FAILED; } @@ -1599,12 +1207,16 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GeTensorDesc desc(user_input[index].GetTensorDesc()); auto format = desc.GetFormat(); auto origin_format = desc.GetOriginFormat(); - bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); - bool need_check_internal_format = (!options_.is_single_op) && is_internal; + // data maybe internal format [FRACTAL_NZ] at singleop process such as GEMM. + bool need_check_internal_format = (!IsTansDataOpData(input_node)) && (!options_.is_single_op); if (need_check_internal_format) { - GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", - TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::FormatToSerialString(origin_format).c_str()); - return FAILED; + bool is_internal = TypeUtils::IsInternalFormat(format) || TypeUtils::IsInternalFormat(origin_format); + if (is_internal) { + GELOGE(PARAM_INVALID, "Input format %s or origin_format %s is not support.", + TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::FormatToSerialString(origin_format).c_str()); + return FAILED; + } } auto data_type = desc.GetDataType(); @@ -1623,7 +1235,8 @@ Status GraphPrepare::UpdateInput(const std::vector &user_input) { GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc, size) != GRAPH_SUCCESS, GELOGE(INTERNAL_ERROR, "TensorUtils GetSize failed"); return FAILED); - if ((size != 0) && (shape_size != size)) { + bool size_check = (size != 0 && shape_size != size); + if (size_check) { GELOGE(PARAM_INVALID, "input data size =%ld, shape_size =%ld.", size, shape_size); return FAILED; } @@ -1742,29 +1355,49 @@ Status GraphPrepare::ResourcePairProcess(const std::string &action) { return SUCCESS; } -Status GraphPrepare::OptimizeAfterInfershapeByAtcParams() { - if (options_.train_graph_flag) { - GELOGI("This is train mode, no need to do this schedule."); - return SUCCESS; - } - GE_RETURN_IF_ERROR(InsertNewOpUtil::Instance().UpdateDataNodeByAipp(compute_graph_)); - for (auto &node_ptr : compute_graph_->GetDirectNode()) { +Status GraphPrepare::UpdateDataNetOutputByStorageFormat() { + for (auto &node_ptr : compute_graph_->GetAllNodes()) { GE_CHECK_NOTNULL(node_ptr); - if (CheckIfNeedSetNdFormat(node_ptr) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Set node [%s] format ND failed", node_ptr->GetName().c_str()); - return FAILED; - } if (node_ptr->GetType() == DATA) { - if (ProcessDataNode(node_ptr) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Process data node failed"); + uint32_t index = 0; + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const GeTensorDescPtr input = op_desc->MutableInputDesc(index); + Format storage_format = FORMAT_RESERVED; + vector dst_shape_dims; + if (GetStorageFormatAndShape(op_desc, input, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get storage format for input failed"); + return FAILED; + } + + if (storage_format == FORMAT_RESERVED) { + continue; + } + + if (ModifyDataNetOutputFormatAndShape(op_desc, index, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Modify format and shape for inputfailed"); return FAILED; } } if (node_ptr->GetType() == ge::NETOUTPUT) { - if (ProcessNetoutputNode(node_ptr, options_.output_datatype) != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Process netoutput node failed"); - return FAILED; + auto op_desc = node_ptr->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + for (uint32_t index = 0; index < op_desc->GetOutputsSize(); index++) { + const GeTensorDescPtr output = op_desc->MutableOutputDesc(index); + Format storage_format = FORMAT_RESERVED; + vector dst_shape_dims; + if (GetStorageFormatAndShape(op_desc, output, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Get storage format from output failed"); + return FAILED; + } + if (storage_format == FORMAT_RESERVED) { + continue; + } + if (ModifyDataNetOutputFormatAndShape(op_desc, index, storage_format, dst_shape_dims) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Modify format and shape for output failed"); + return FAILED; + } } } } @@ -1908,12 +1541,6 @@ Status GraphPrepare::Preprocess(const std::vector &user_input) { ProcessCCEFormat(); - ret = OptimizeAfterInfershapeByAtcParams(); - if (ret != SUCCESS) { - GELOGE(ret, "Optimize for input if set inputfp16 failed."); - return ret; - } - SaveOriginalGraphToOmModel(); GE_TIMESTAMP_START(OptimizeForPreprocess); @@ -1955,9 +1582,7 @@ Status GraphPrepare::PrepareDynShape(ConstGraphPtr graph, const std::vector(options_.framework_type); const Graph &const_graph = *graph; @@ -1972,7 +1597,6 @@ Status GraphPrepare::PrepareDynShape(ConstGraphPtr graph, const std::vectorInferOriginFormat(); + GE_DUMP(compute_graph_, "after_inferformat"); + if (ret != SUCCESS) { + GELOGE(ret, "Prepare Graph inferformat failed"); + return ret; + } InferShapePass infer_shape_pass; + NamesToPass names_to_passes; names_to_passes.emplace_back("InferShapePass", &infer_shape_pass); + GEPass ge_passes(compute_graph_); ret = ge_passes.Run(names_to_passes); + GE_DUMP(compute_graph_, "after_infershape"); if (ret != SUCCESS) { GELOGE(ret, "Run ge_passes infershape for preprocess failed, ret:%u.", ret); return ret; } + ShapeRefiner::ClearContextMap(); return SUCCESS; } Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &user_input, ge::ComputeGraphPtr &compute_graph, VarAccelerateCtrl &var_acc_ctrl, uint64_t session_id) { - // train graph flag - if (options_.train_graph_flag) { - domi::GetContext().train_flag = true; - } domi::GetContext().type = static_cast(options_.framework_type); if (graph == nullptr) { @@ -2071,7 +1699,7 @@ Status GraphPrepare::Prepare(ConstGraphPtr graph, const std::vector &u } GraphOptimize graph_optimize; - if (!domi::GetContext().train_flag) { + if (!options_.train_graph_flag && !domi::GetContext().train_flag) { GE_DUMP(compute_graph_, "BeforeOriginalGraphForQuantize"); GE_TIMESTAMP_START(OptimizeOriginalGraphForQuantize); ret = graph_optimize.OptimizeOriginalGraphForQuantize(compute_graph_); @@ -2273,6 +1901,7 @@ Status GraphPrepare::InferShapeForPreprocess() { } } } + ShapeRefiner::ClearContextMap(); if (ret != SUCCESS) { GELOGE(ret, "Run ge_passes infershape for preprocess failed, ret:%u.", ret); return ret; @@ -2281,6 +1910,14 @@ Status GraphPrepare::InferShapeForPreprocess() { } Status GraphPrepare::PrepareOptimize() { GELOGI("Start optimize for preprocess."); + // check rw type + GraphOptimize graph_optimize; + bool has_conflict = false; + graph_optimize.CheckRWConflict(compute_graph_, has_conflict); + if (has_conflict) { + GELOGE(GRAPH_PARAM_INVALID, "There has rw conflict.Stop optimize."); + return FAILED; + } PassManager original_graph_passes; // Graph pass try { @@ -2302,10 +1939,10 @@ Status GraphPrepare::PrepareOptimize() { GEPass ge_passes(compute_graph_); NamesToPass names_to_passes; EnterPass enter_pass; - PrintOpPass print_pass; names_to_passes.emplace_back("EnterPass", &enter_pass); CondPass cond_pass; names_to_passes.emplace_back("CondPass", &cond_pass); + PrintOpPass print_pass; if (options_.enable_print_op_pass) { names_to_passes.emplace_back("PrintOpPass", &print_pass); } @@ -2478,7 +2115,9 @@ Status GraphPrepare::OptimizeForPreprocess() { (void)graph_pass.AddPass("OptimizeForPreprocess::PrunePass", new PrunePass); (void)graph_pass.AddPass("OptimizeForPreprocess::NextIterationPass", new NextIterationPass); (void)graph_pass.AddPass("OptimizeForPreprocess::ControlTriggerPass", new ControlTriggerPass); - (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchOpPass", new SwitchOpPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::MergeToStreamMergePass", new MergeToStreamMergePass); + (void)graph_pass.AddPass("OptimizeForPreprocess::SwitchToStreamSwitchPass", new SwitchToStreamSwitchPass); + (void)graph_pass.AddPass("OptimizeForPreprocess::AttachStreamLabelPass", new AttachStreamLabelPass); (void)graph_pass.AddPass("OptimizeForPreprocess::HcclMemcpyPass", new HcclMemcpyPass); GE_IF_BOOL_EXEC(options_.train_graph_flag, (void)graph_pass.AddPass("OptimizeForPreprocess::FlowCtrlPass", new FlowCtrlPass);); @@ -2560,8 +2199,6 @@ Status GraphPrepare::NewOptimizeGraphBeforeSubGraph(VarAccelerateCtrl &var_acc_c GEPass ge_passes_for_shape(compute_graph_); NamesToPass names_to_passes_for_shape; - IdentifyReferencePass identify_reference_pass; - names_to_passes_for_shape.emplace_back("IdentifyReferencePass", &identify_reference_pass); CastRemovePass cast_remove_pass; names_to_passes_for_shape.emplace_back("CastRemovePass", &cast_remove_pass); TransposeTransDataPass transpose_transdata_pass; @@ -2693,6 +2330,12 @@ Status GraphPrepare::CheckAndUpdateInput(const std::vector &user_input return SUCCESS; } Status GraphPrepare::UpdateInputOutputByOptions() { + auto ret = UpdateDataNetOutputByStorageFormat(); + if (ret != SUCCESS) { + GELOGE(ret, "Update format acoording to storage format failed."); + return ret; + } + if (options_.train_graph_flag) { GELOGI("This is train mode, no need to do this schedule."); return SUCCESS; @@ -2736,6 +2379,21 @@ bool GraphPrepare::IsBroadCastOpData(const ge::NodePtr &var_node) { return false; } +bool GraphPrepare::IsTansDataOpData(const ge::NodePtr &var_node) { + for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(out_anchor); + for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { + GE_RT_FALSE_CHECK_NOTNULL(in_anchor); + ge::NodePtr dst_node = in_anchor->GetOwnerNode(); + GE_RT_FALSE_CHECK_NOTNULL(dst_node); + if (dst_node->GetType() == TRANSDATA) { + return true; + } + } + } + return false; +} + bool GraphPrepare::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, const map> &confirm_ops, ge::NodePtr &use_node) { GE_RT_FALSE_CHECK_NOTNULL(in_anchor); diff --git a/src/ge/graph/preprocess/graph_preprocess.h b/src/ge/graph/preprocess/graph_preprocess.h index b90caa86..343791bd 100644 --- a/src/ge/graph/preprocess/graph_preprocess.h +++ b/src/ge/graph/preprocess/graph_preprocess.h @@ -59,8 +59,7 @@ class GraphPrepare { Status Init(const ge::Graph &graph, uint64_t session_id = 0); Status Preprocess(const std::vector &user_input); Status CheckGraph(); - Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, - const std::unordered_set &ref_nodes); + Status CheckRefInputNode(const NodePtr &node, const std::string &input_name, const std::set &ref_nodes); Status CheckRefOp(); Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode); Status AdjustDataOpOutput(const NodePtr &node); @@ -69,11 +68,11 @@ class GraphPrepare { Status CheckConstOp(); Status VerifyConstOp(const NodePtr &node); Status CheckUserInput(const std::vector &user_input); + Status UpdateDataNetOutputByStorageFormat(); Status OptimizeForPreprocess(); Status PrepareOptimize(); Status InferShapeForPreprocess(); Status TryDoAipp(); - Status OptimizeAfterInfershapeByAtcParams(); Status UpdateVariableFormats(ComputeGraphPtr &graph); Status UpdateVariableFormatsDynShape(ComputeGraphPtr &graph); Status FormatAndShapeProcess(); @@ -88,6 +87,8 @@ class GraphPrepare { Status UpdateInputOutputByOptions(); bool IsBroadCastOpData(const ge::NodePtr &var_node); + bool IsTansDataOpData(const ge::NodePtr &var_node); + void AdjustBroadCastOpData(const ge::NodePtr &var_node); bool IsAssignOpData(const ge::NodePtr &var_node); @@ -104,6 +105,7 @@ class GraphPrepare { ge::ComputeGraphPtr compute_graph_; GraphManagerOptions options_; + uint64_t session_id_ = 0; }; } // namespace ge #endif // GE_GRAPH_PREPROCESS_GRAPH_PREPROCESS_H_ diff --git a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc index f35b6d3a..55c7b427 100644 --- a/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/ge_aipp_op.cc @@ -24,6 +24,7 @@ #include "common/dynamic_aipp.h" #include "common/ge/ge_util.h" #include "common/util.h" +#include "common/util/error_manager/error_manager.h" #include "external/graph/operator_factory.h" #include "framework/common/debug/ge_log.h" #include "framework/common/ge_inner_error_codes.h" @@ -51,6 +52,16 @@ } \ } while (0) +#define AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(expr, _status, errormsg) \ + do { \ + bool b = (expr); \ + if (!b) { \ + GELOGE(_status, errormsg); \ + ErrorManager::GetInstance().ATCReportErrMessage("E10043", {"reason"}, {errormsg}); \ + return _status; \ + } \ + } while (0) + namespace { const int32_t DEFAULT_MATRIX_R0C0_YUV2RGB = 298; const int32_t DEFAULT_MATRIX_R0C1_YUV2RGB = 0; @@ -411,86 +422,87 @@ Status AippOp::SetDefaultParams() { Status AippOp::ValidateParams() { GE_CHECK_NOTNULL(aipp_params_); - GE_CHK_BOOL_RET_STATUS(aipp_params_->aipp_mode() != domi::AippOpParams::undefined, PARAM_INVALID, - "when insert AIPP op, aipp_mode must be configured as static or dynamic "); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->var_reci_chn_0_size() <= 1, PARAM_INVALID, - "The parameter var_reci_chn_0 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->var_reci_chn_1_size() <= 1, PARAM_INVALID, - "The parameter var_reci_chn_1 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->var_reci_chn_2_size() <= 1, PARAM_INVALID, - "The parameter var_reci_chn_2 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->var_reci_chn_3_size() <= 1, PARAM_INVALID, - "The parameter var_reci_chn_3 can not be configed repeatedly"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r0c0_size() <= 1, PARAM_INVALID, - "The parameter matrix_r0c0 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r0c1_size() <= 1, PARAM_INVALID, - "The parameter matrix_r0c1 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r0c2_size() <= 1, PARAM_INVALID, - "The parameter matrix_r0c2 can not be configed repeatedly"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r1c0_size() <= 1, PARAM_INVALID, - "The parameter matrix_r1c0 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r1c1_size() <= 1, PARAM_INVALID, - "The parameter matrix_r1c1 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r1c2_size() <= 1, PARAM_INVALID, - "The parameter matrix_r1c2 can not be configed repeatedly"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r2c0_size() <= 1, PARAM_INVALID, - "The parameter matrix_r2c0 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r2c1_size() <= 1, PARAM_INVALID, - "The parameter matrix_r2c1 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->matrix_r2c2_size() <= 1, PARAM_INVALID, - "The parameter matrix_r2c2 can not be configed repeatedly"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->output_bias_0_size() <= 1, PARAM_INVALID, - "The parameter output_bias_0 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->output_bias_1_size() <= 1, PARAM_INVALID, - "The parameter output_bias_1 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->output_bias_2_size() <= 1, PARAM_INVALID, - "The parameter output_bias_2 can not be configed repeatedly"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->input_bias_0_size() <= 1, PARAM_INVALID, - "The parameter input_bias_0 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->input_bias_1_size() <= 1, PARAM_INVALID, - "The parameter input_bias_1 can not be configed repeatedly"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->input_bias_2_size() <= 1, PARAM_INVALID, - "The parameter input_bias_2 can not be configed repeatedly"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->input_edge_idx_size() <= 1, PARAM_INVALID, - "The parameter input_edge_idx can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->aipp_mode() != domi::AippOpParams::undefined, PARAM_INVALID, + "When insert AIPP op, aipp_mode must be configured as static or dynamic "); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_0_size() <= 1, PARAM_INVALID, + "The parameter var_reci_chn_0 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_1_size() <= 1, PARAM_INVALID, + "The parameter var_reci_chn_1 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_2_size() <= 1, PARAM_INVALID, + "The parameter var_reci_chn_2 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->var_reci_chn_3_size() <= 1, PARAM_INVALID, + "The parameter var_reci_chn_3 can not be configed repeatedly"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r0c0_size() <= 1, PARAM_INVALID, + "The parameter matrix_r0c0 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r0c1_size() <= 1, PARAM_INVALID, + "The parameter matrix_r0c1 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r0c2_size() <= 1, PARAM_INVALID, + "The parameter matrix_r0c2 can not be configed repeatedly"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r1c0_size() <= 1, PARAM_INVALID, + "The parameter matrix_r1c0 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r1c1_size() <= 1, PARAM_INVALID, + "The parameter matrix_r1c1 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r1c2_size() <= 1, PARAM_INVALID, + "The parameter matrix_r1c2 can not be configed repeatedly"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r2c0_size() <= 1, PARAM_INVALID, + "The parameter matrix_r2c0 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r2c1_size() <= 1, PARAM_INVALID, + "The parameter matrix_r2c1 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->matrix_r2c2_size() <= 1, PARAM_INVALID, + "The parameter matrix_r2c2 can not be configed repeatedly"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->output_bias_0_size() <= 1, PARAM_INVALID, + "The parameter output_bias_0 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->output_bias_1_size() <= 1, PARAM_INVALID, + "The parameter output_bias_1 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->output_bias_2_size() <= 1, PARAM_INVALID, + "The parameter output_bias_2 can not be configed repeatedly"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_bias_0_size() <= 1, PARAM_INVALID, + "The parameter input_bias_0 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_bias_1_size() <= 1, PARAM_INVALID, + "The parameter input_bias_1 can not be configed repeatedly"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_bias_2_size() <= 1, PARAM_INVALID, + "The parameter input_bias_2 can not be configed repeatedly"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_edge_idx_size() <= 1, PARAM_INVALID, + "The parameter input_edge_idx can not be configed repeatedly"); const domi::AippOpParams::AippMode aipp_mode = aipp_params_->aipp_mode(); if (aipp_mode == domi::AippOpParams::dynamic) { - GE_CHK_BOOL_RET_STATUS(aipp_params_->max_src_image_size() > 0, PARAM_INVALID, - "for dynamic AIPP params, max_src_image_size must greater than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG( + aipp_params_->max_src_image_size() > 0, PARAM_INVALID, + "For dynamic AIPP params, max_src_image_size must be set which number should be greater than 0"); } else { - GE_CHK_BOOL_RET_STATUS(aipp_params_->input_format() != domi::AippOpParams::UNDEFINED, PARAM_INVALID, - "Input format of AIPP conf is undefined"); - - GE_CHK_BOOL_RET_STATUS(aipp_params_->src_image_size_w() >= 0, PARAM_INVALID, - "src_image_size_w must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->src_image_size_h() >= 0, PARAM_INVALID, - "src_image_size_h must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->load_start_pos_w() >= 0, PARAM_INVALID, - "load_start_pos_w must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->load_start_pos_h() >= 0, PARAM_INVALID, - "load_start_pos_h must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->crop_size_w() >= 0, PARAM_INVALID, - "crop_size_w must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->resize_output_w() >= 0, PARAM_INVALID, - "resize_output_w must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->resize_output_h() >= 0, PARAM_INVALID, - "resize_output_h must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->left_padding_size() >= 0, PARAM_INVALID, - "left_padding_size must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->right_padding_size() >= 0, PARAM_INVALID, - "right_padding_size must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->top_padding_size() >= 0, PARAM_INVALID, - "top_padding_size must not be configed smaller than 0"); - GE_CHK_BOOL_RET_STATUS(aipp_params_->bottom_padding_size() >= 0, PARAM_INVALID, - "bottom_padding_size must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->input_format() != domi::AippOpParams::UNDEFINED, PARAM_INVALID, + "Input format of AIPP conf is undefined"); + + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->src_image_size_w() >= 0, PARAM_INVALID, + "Src_image_size_w must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->src_image_size_h() >= 0, PARAM_INVALID, + "Src_image_size_h must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->load_start_pos_w() >= 0, PARAM_INVALID, + "Load_start_pos_w must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->load_start_pos_h() >= 0, PARAM_INVALID, + "Load_start_pos_h must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->crop_size_w() >= 0, PARAM_INVALID, + "Crop_size_w must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->resize_output_w() >= 0, PARAM_INVALID, + "Resize_output_w must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->resize_output_h() >= 0, PARAM_INVALID, + "Resize_output_h must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->left_padding_size() >= 0, PARAM_INVALID, + "Left_padding_size must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->right_padding_size() >= 0, PARAM_INVALID, + "Right_padding_size must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->top_padding_size() >= 0, PARAM_INVALID, + "Top_padding_size must not be configed smaller than 0"); + AIPP_RETURN_STATUS_AND_REPROT_ERRORMSG(aipp_params_->bottom_padding_size() >= 0, PARAM_INVALID, + "Bottom_padding_size must not be configed smaller than 0"); } return SUCCESS; diff --git a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc index 5fe19869..8bb0c6c4 100644 --- a/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc +++ b/src/ge/graph/preprocess/insert_op/util_insert_aipp_op.cc @@ -45,7 +45,7 @@ static void ConvertShape2Nhwc(Format &format, vector &shape_vec) { return; } if (format != FORMAT_NCHW) { - GELOGW("The format is not NCHW, current format is %s", TypeUtils::FormatToSerialString(format).c_str()); + GELOGW("The format is not NCHW, current format is %s.", TypeUtils::FormatToSerialString(format).c_str()); return; } vector shape_vec_tmp; @@ -245,7 +245,6 @@ Status InsertNewOpUtil::UpdatePrevNodeByAipp(NodePtr &node, std::set &s GELOGE(FAILED, "Can not get size from aipp [%s]", aipp_op_desc->GetName().c_str()); return FAILED; } - // Save the input size of aipp node, which will be used in dumping aipp node or fused aipp node (void)AttrUtils::SetInt(aipp_input, ATTR_NAME_INPUT_ORIGIN_SIZE, size); auto in_data_anchor = node->GetInDataAnchor(0); diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.cc b/src/ge/graph/preprocess/multi_batch_copy_graph.cc index fbe935ec..d1e9fe62 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.cc +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.cc @@ -35,6 +35,9 @@ #include "graph/utils/graph_utils.h" #include "graph/utils/node_utils.h" +using std::string; +using std::vector; + namespace ge { namespace multibatch { namespace { @@ -45,6 +48,7 @@ const int kDataOutIndex = 0; const int kDataInIndex = 0; const int kMergeDataOutIndex = 0; const int kStaticOutput = -1; +const int kDecimal = 10; const size_t kMaxShapesCount = 100; const size_t kMinShapesCount = 2; @@ -126,8 +130,12 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { for (size_t i = 0; i < data_shape.GetDimNum(); ++i) { if (data_shape.GetDim(i) < 0) { if (batch_shape_index >= batch_shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + + " does not match the data shape " + data_shape.ToString()}); GELOGE(PARAM_INVALID, - "Failed to calc tensor shape, the batch shape count %zu, doees not match the data shape %s", + "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", batch_shape.size(), data_shape.ToString().c_str()); return PARAM_INVALID; } @@ -135,6 +143,10 @@ Status CalcShape(const std::vector &batch_shape, GeShape &data_shape) { } } if (batch_shape_index != batch_shape.size()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E19012", {"function", "reason"}, + {"CalcShape", "the batch shape count " + std::to_string(batch_shape.size()) + " does not match the data shape " + + data_shape.ToString()}); GELOGE(PARAM_INVALID, "Failed to calc tensor shape, the batch shape count %zu, does not match the data shape %s", batch_shape.size(), data_shape.ToString().c_str()); return PARAM_INVALID; @@ -199,9 +211,9 @@ Status CheckDataShape(const std::vector &nodes) { } } if (unknown_shape_count == 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10055"); + ErrorManager::GetInstance().ATCReportErrMessage("E10040"); GELOGE(PARAM_INVALID, - "Need unknow shape data when user set --dynamic_batch_size or --dynamic_image_size, please check."); + "Need unknow shape data when user set --dynamic_batch_size, --dynamic_image_size or --dynamic_dims"); return PARAM_INVALID; } @@ -279,6 +291,8 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { case kNodeOutBatchBranch: ret = InsertMergeForEdgeNode(node); break; + case kNodeNotSupportNode: + break; default: GELOGE(INTERNAL_ERROR, "Unexpected status %d on node %s", static_cast(branch_status), node->GetName().c_str()); @@ -291,7 +305,13 @@ Status MultiBatchGraphCopyer::CreateNewNodes() { } return SUCCESS; } + NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { + // node with subgraph is not supported + if (!(node->GetOpDesc()->GetSubgraphInstanceNames().empty())) { + return kNodeNotSupportNode; + } + if (node->GetType() == NETOUTPUT) { return kNodeOutBatchBranch; } @@ -305,6 +325,7 @@ NodeStatus MultiBatchGraphCopyer::GetNodeStatus(const NodePtr &node) { } return kNodeOutBatchBranch; } + NodePtr MultiBatchGraphCopyer::InsertMergeNode(const NodePtr &node, int index) { if (index < 0) { // the merge node must has data inputs, if origin connection is a control @@ -477,38 +498,38 @@ Status MultiBatchGraphCopyer::CheckArguments() { return PARAM_INVALID; } if (shapes_.size() < kMinShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage("E10050", {"shapesize", "minshapesize"}, - {std::to_string(shapes_.size()), std::to_string(kMinShapesCount)}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10035", {"shapesize", "minshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMinShapesCount - 1)}); GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " "value size [%zu] must be greater than [%zu].", - shapes_.size(), kMinShapesCount); + shapes_.size(), kMinShapesCount - 1); return PARAM_INVALID; } if (shapes_.size() > kMaxShapesCount) { - ErrorManager::GetInstance().ATCReportErrMessage("E10051", {"shapesize", "maxshapesize"}, - {std::to_string(shapes_.size()), std::to_string(kMaxShapesCount)}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10036", {"shapesize", "maxshapesize"}, {std::to_string(shapes_.size()), std::to_string(kMaxShapesCount + 1)}); GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " "value size [%zu] must be less than [%zu].", - shapes_.size(), kMaxShapesCount); + shapes_.size(), kMaxShapesCount + 1); return PARAM_INVALID; } std::set> shapes_set; size_t shape_size = shapes_.at(0).size(); for (auto &shape : shapes_) { if (shape_size != shape.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10052", {"shapesize1", "shapesize2"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"shapesize1", "shapesize2"}, {std::to_string(shape_size), std::to_string(shape.size())}); GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size or --dynamic_image_size]'s " + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims]'s " "value size must be same, first group's size is %zu and another's is %zu.", shape_size, shape.size()); return PARAM_INVALID; } for (auto dim : shape) { if (dim <= 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10053", {"dim"}, {std::to_string(dim)}); + ErrorManager::GetInstance().ATCReportErrMessage("E10038", {"dim"}, {std::to_string(dim)}); GELOGE(PARAM_INVALID, "Invalid dim %ld, all dims must be greater than 0", dim); return PARAM_INVALID; } @@ -516,9 +537,9 @@ Status MultiBatchGraphCopyer::CheckArguments() { shapes_set.insert(shape); } if (shapes_set.size() != shapes_.size()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10054"); + ErrorManager::GetInstance().ATCReportErrMessage("E10039"); GELOGE(PARAM_INVALID, - "Input parameter[--dynamic_batch_size or --dynamic_image_size] exist duplicate shapes, please check"); + "Input parameter[--dynamic_batch_size, --dynamic_image_size or --dynamic_dims] exist duplicate shapes."); return PARAM_INVALID; } return SUCCESS; @@ -673,6 +694,10 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { GELOGE(INTERNAL_ERROR, "Failed to add attr value on output %zu tensor", i); return INTERNAL_ERROR; } + if (!AttrUtils::SetListInt(tensor, ATTR_NAME_COMBINED_DYNAMIC_DIMS, shape.GetDims())) { + GELOGE(INTERNAL_ERROR, "Failed to add attr ATTR_NAME_COMBINED_DYNAMIC_DIMS on output %zu tensor", i); + return INTERNAL_ERROR; + } if (switchn_desc->AddOutputDesc("output" + std::to_string(i), tensor) != GRAPH_SUCCESS) { GELOGE(GRAPH_FAILED, "Opdesc AddOutputDesc failed"); return GRAPH_FAILED; @@ -688,6 +713,10 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { GELOGE(INTERNAL_ERROR, "Failed to add switchn attr on data node %s", data->GetName().c_str()); return INTERNAL_ERROR; } + if (StampDynamicTypeForSwitchN(switchn_desc) != SUCCESS) { + GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr on switchn node %s", switchn_desc->GetName().c_str()); + return INTERNAL_ERROR; + } auto switchn = graph_->AddNode(switchn_desc); if (switchn == nullptr) { @@ -697,6 +726,26 @@ Status MultiBatchGraphCopyer::InsertSwitchNForData(const NodePtr &data) { data_nodes_to_switchn_[data.get()] = switchn; return SUCCESS; } + +Status MultiBatchGraphCopyer::StampDynamicTypeForSwitchN(OpDescPtr &switchn_desc) { + GE_CHECK_NOTNULL(switchn_desc); + int32_t dynamic_type = static_cast(FIXED); + if (!domi::GetContext().dynamic_batch_size.empty()) { + dynamic_type = static_cast(DYNAMIC_BATCH); + } + if (!domi::GetContext().dynamic_image_size.empty()) { + dynamic_type = static_cast(DYNAMIC_IMAGE); + } + if (!domi::GetContext().dynamic_dims.empty()) { + dynamic_type = static_cast(DYNAMIC_DIMS); + } + if (!AttrUtils::SetInt(switchn_desc, ATTR_DYNAMIC_TYPE, dynamic_type)) { + GELOGE(INTERNAL_ERROR, "Failed to add dynamic type attr of switchn node %s", switchn_desc->GetName().c_str()); + return INTERNAL_ERROR; + } + return SUCCESS; +} + Status MultiBatchGraphCopyer::InsertMergeForEdgeNode(const NodePtr &node) { for (auto &in_data_anchor : node->GetAllInDataAnchors()) { auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); @@ -896,7 +945,6 @@ Status MultiBatchGraphCopyer::LinkToNodeOutBranch(const NodePtr &node) { } Status ProcessMultiBatch(ComputeGraphPtr &graph) { - const int kDecimal = 10; std::vector> shapes; if (!domi::GetContext().dynamic_batch_size.empty()) { GELOGD("Found dynamic batch option, value %s", domi::GetContext().dynamic_batch_size.c_str()); @@ -909,25 +957,25 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { GELOGI("Found dynamic batch, shape %s", formats::JoinToString(*shapes.rbegin()).c_str()); } } + if (!domi::GetContext().dynamic_image_size.empty()) { GELOGD("Found dynamic image size option, value %s", domi::GetContext().dynamic_image_size.c_str()); - std::vector shape_strs = ge::StringUtils::Split(domi::GetContext().dynamic_image_size, ';'); - for (const auto &shape_str : shape_strs) { - if (shape_str.empty()) { - continue; - } - std::vector shape; - std::vector dims = ge::StringUtils::Split(shape_str, ','); - for (const auto &dim : dims) { - if (dim.empty()) { - continue; - } - shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal)); - } - shapes.emplace_back(shape); + ParseDynamicSize(domi::GetContext().dynamic_image_size, shapes); + + for (const auto &shape : shapes) { GELOGI("Found dynamic image size, shape %s", formats::JoinToString(shape).c_str()); } } + + if (!domi::GetContext().dynamic_dims.empty()) { + GELOGD("Found dynamic dims option, value %s", domi::GetContext().dynamic_dims.c_str()); + ParseDynamicSize(domi::GetContext().dynamic_dims, shapes); + + for (const auto &shape : shapes) { + GELOGI("Found dynamic dims, shape %s", formats::JoinToString(shape).c_str()); + } + } + if (shapes.empty()) { GELOGD("There is no multi-batch options, no need to process multi-batch copy"); return SUCCESS; @@ -941,6 +989,26 @@ Status ProcessMultiBatch(ComputeGraphPtr &graph) { return copyer.CopyGraph(); } +void ParseDynamicSize(string dynamic_size, vector> &shapes) { + std::vector shape_strs = ge::StringUtils::Split(dynamic_size, ';'); + for (const auto &shape_str : shape_strs) { + if (shape_str.empty()) { + continue; + } + std::vector shape; + std::vector dims = ge::StringUtils::Split(shape_str, ','); + for (const auto &dim : dims) { + if (dim.empty()) { + continue; + } + shape.emplace_back(std::strtol(dim.c_str(), nullptr, kDecimal)); + } + if (!shape.empty()) { + shapes.emplace_back(shape); + } + } +} + Status GetDynamicOutputShape(ComputeGraphPtr &graph) { GELOGI("Start to get dynamic output dynamic batch shape msg"); std::vector dynamic_output_dims; diff --git a/src/ge/graph/preprocess/multi_batch_copy_graph.h b/src/ge/graph/preprocess/multi_batch_copy_graph.h index 2500645f..7e317cb0 100644 --- a/src/ge/graph/preprocess/multi_batch_copy_graph.h +++ b/src/ge/graph/preprocess/multi_batch_copy_graph.h @@ -27,12 +27,15 @@ namespace ge { namespace multibatch { Status ProcessMultiBatch(ComputeGraphPtr &graph); +void ParseDynamicSize(std::string dynamic_size, std::vector> &shapes); + Status GetDynamicOutputShape(ComputeGraphPtr &graph); enum NodeStatus { kNodeInBatchBranch, kNodeOutBatchBranch, kNodeStartNode, + kNodeNotSupportNode, }; class MultiBatchGraphCopyer { @@ -53,6 +56,7 @@ class MultiBatchGraphCopyer { NodePtr InsertShapeDataNode(); Status InsertSwitchNForData(const NodePtr &data); + Status StampDynamicTypeForSwitchN(OpDescPtr &switchn_desc); Status UpdateMaxShapeToData(const NodePtr &data); Status InsertMergeForEdgeNode(const NodePtr &node); diff --git a/src/ge/host_aicpu_engine/common/constant/constant.h b/src/ge/host_aicpu_engine/common/constant/constant.h new file mode 100644 index 00000000..998dc7eb --- /dev/null +++ b/src/ge/host_aicpu_engine/common/constant/constant.h @@ -0,0 +1,30 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ +#define GE_HOST_AICPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ + +#include + +namespace ge { +namespace host_aicpu { +// engine name +const char kHostAiCpuEngineName[] = "DNN_VM_HOST_AICPU"; +const char kHostAiCpuOpKernelLibName[] = "DNN_VM_HOST_AICPU_OP_STORE"; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_COMMON_CONSTANT_CONSTANT_H_ diff --git a/src/ge/host_aicpu_engine/engine/host_aicpu_engine.cc b/src/ge/host_aicpu_engine/engine/host_aicpu_engine.cc new file mode 100644 index 00000000..12ec5ede --- /dev/null +++ b/src/ge/host_aicpu_engine/engine/host_aicpu_engine.cc @@ -0,0 +1,74 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/engine/host_aicpu_engine.h" +#include +#include +#include +#include "framework/common/debug/ge_log.h" +#include "common/ge/ge_util.h" +#include "host_aicpu_engine/common/constant/constant.h" +#include "host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h" + +namespace ge { +namespace host_aicpu { +HostAiCpuEngine &HostAiCpuEngine::Instance() { + static HostAiCpuEngine instance; + return instance; +} + +Status HostAiCpuEngine::Initialize(const std::map &options) { + if (ops_kernel_store_ == nullptr) { + ops_kernel_store_ = MakeShared(); + if (ops_kernel_store_ == nullptr) { + GELOGE(FAILED, "Make HostAiCpuOpsKernelInfoStore failed."); + return FAILED; + } + } + return SUCCESS; +} + +void HostAiCpuEngine::GetOpsKernelInfoStores(std::map &ops_kernel_map) { + if (ops_kernel_store_ != nullptr) { + // add buildin opsKernel to opsKernelInfoMap + ops_kernel_map[kHostAiCpuOpKernelLibName] = ops_kernel_store_; + } +} + +void HostAiCpuEngine::GetGraphOptimizerObjs(std::map &) { + // no optimizer for host aicpu engine +} + +Status HostAiCpuEngine::Finalize() { + ops_kernel_store_ = nullptr; + return SUCCESS; +} +} // namespace host_aicpu +} // namespace ge + +ge::Status Initialize(const std::map &options) { + return ge::host_aicpu::HostAiCpuEngine::Instance().Initialize(options); +} + +void GetOpsKernelInfoStores(std::map &ops_kernel_map) { + ge::host_aicpu::HostAiCpuEngine::Instance().GetOpsKernelInfoStores(ops_kernel_map); +} + +void GetGraphOptimizerObjs(std::map &graph_optimizers) { + ge::host_aicpu::HostAiCpuEngine::Instance().GetGraphOptimizerObjs(graph_optimizers); +} + +ge::Status Finalize() { return ge::host_aicpu::HostAiCpuEngine::Instance().Finalize(); } diff --git a/src/ge/host_aicpu_engine/engine/host_aicpu_engine.h b/src/ge/host_aicpu_engine/engine/host_aicpu_engine.h new file mode 100644 index 00000000..f8ad71b1 --- /dev/null +++ b/src/ge/host_aicpu_engine/engine/host_aicpu_engine.h @@ -0,0 +1,111 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_ENGINE_HOST_AICPU_ENGINE_H_ +#define GE_HOST_AICPU_ENGINE_ENGINE_HOST_AICPU_ENGINE_H_ + +#include +#include +#include +#include "common/opskernel/ops_kernel_info_store.h" +#include "common/optimizer/graph_optimizer.h" + +using OpsKernelInfoStorePtr = std::shared_ptr; +using GraphOptimizerPtr = std::shared_ptr; + +namespace ge { +namespace host_aicpu { +/** + * host aicpu engine. + * Used for the ops which executes on host. + */ +class HostAiCpuEngine { + public: + /** + * get HostAiCpuEngine instance. + * @return HostAiCpuEngine instance. + */ + static HostAiCpuEngine &Instance(); + + virtual ~HostAiCpuEngine() = default; + + /** + * When Ge start, GE will invoke this interface + * @return The status whether initialize successfully + */ + Status Initialize(const std::map &options); + + /** + * After the initialize, GE will invoke this interface + * to get the Ops kernel Store. + * @param ops_kernel_map The host aicpu's ops kernel info + */ + void GetOpsKernelInfoStores(std::map &ops_kernel_map); + + /** + * After the initialize, GE will invoke this interface + * to get the Graph Optimizer. + * @param graph_optimizers The host aicpu's Graph Optimizer objs + */ + void GetGraphOptimizerObjs(std::map &graph_optimizers); + + /** + * When the graph finished, GE will invoke this interface + * @return The status whether initialize successfully + */ + Status Finalize(); + + HostAiCpuEngine(const HostAiCpuEngine &HostAiCpuEngine) = delete; + HostAiCpuEngine(const HostAiCpuEngine &&HostAiCpuEngine) = delete; + HostAiCpuEngine &operator=(const HostAiCpuEngine &HostAiCpuEngine) = delete; + HostAiCpuEngine &operator=(HostAiCpuEngine &&HostAiCpuEngine) = delete; + + private: + HostAiCpuEngine() = default; + + OpsKernelInfoStorePtr ops_kernel_store_ = nullptr; +}; +} // namespace host_aicpu +} // namespace ge + +extern "C" { + +/** + * When Ge start, GE will invoke this interface + * @return The status whether initialize successfully + */ +ge::Status Initialize(const map &options); + +/** + * After the initialize, GE will invoke this interface to get the Ops kernel Store + * @param ops_kernel_map The host aicpu's ops kernel info + */ +void GetOpsKernelInfoStores(std::map &ops_kernel_map); + +/** + * After the initialize, GE will invoke this interface to get the Graph Optimizer + * @param graph_optimizers The host aicpu's Graph Optimizer objs + */ +void GetGraphOptimizerObjs(std::map &graph_optimizers); + +/** + * When the graph finished, GE will invoke this interface + * @return The status whether initialize successfully + */ +ge::Status Finalize(); +} + +#endif // GE_HOST_AICPU_ENGINE_ENGINE_HOST_AICPU_ENGINE_H_ diff --git a/src/ge/host_aicpu_engine/module.mk b/src/ge/host_aicpu_engine/module.mk new file mode 100644 index 00000000..48dd6a87 --- /dev/null +++ b/src/ge/host_aicpu_engine/module.mk @@ -0,0 +1,57 @@ +LOCAL_PATH := $(call my-dir) + + +local_lib_src_files := engine/host_aicpu_engine.cc \ + ops_kernel_store/host_aicpu_ops_kernel_info.cc \ + ops_kernel_store/op/op_factory.cc \ + ops_kernel_store/op/host_op.cc \ + +local_lib_inc_path := proto/task.proto \ + ${LOCAL_PATH} \ + ${TOPDIR}inc \ + ${TOPDIR}inc/external \ + ${TOPDIR}inc/external/graph \ + $(TOPDIR)libc_sec/include \ + ${TOPDIR}third_party/protobuf/include \ + ${TOPDIR}inc/framework \ + $(TOPDIR)framework/domi \ + +#compiler for host +include $(CLEAR_VARS) +LOCAL_MODULE := libhost_aicpu_engine +LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -std=c++11 +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := +LOCAL_SHARED_LIBRARIES := libprotobuf \ + libc_sec \ + libslog \ + libgraph \ + libregister \ + libruntime + +LOCAL_SRC_FILES := $(local_lib_src_files) +LOCAL_C_INCLUDES := $(local_lib_inc_path) + +include ${BUILD_HOST_SHARED_LIBRARY} + +#compiler for atc +include $(CLEAR_VARS) +LOCAL_MODULE := atclib/libhost_aicpu_engine +LOCAL_CFLAGS += -Werror +LOCAL_CFLAGS += -std=c++11 +LOCAL_LDFLAGS := + +LOCAL_STATIC_LIBRARIES := +LOCAL_SHARED_LIBRARIES := libprotobuf \ + libc_sec \ + libslog \ + libgraph \ + libregister \ + libruntime_compile + +LOCAL_SRC_FILES := $(local_lib_src_files) +LOCAL_C_INCLUDES := $(local_lib_inc_path) + +include ${BUILD_HOST_SHARED_LIBRARY} diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc b/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc new file mode 100644 index 00000000..4dbedab1 --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.cc @@ -0,0 +1,132 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h" +#include +#include "common/constant/constant.h" +#include "ge/ge_api_types.h" +#include "common/ge/ge_util.h" +#include "common/ge_inner_error_codes.h" +#include "framework/common/debug/ge_log.h" +#include "graph/utils/node_utils.h" +#include "graph/utils/tensor_utils.h" +#include "graph/utils/type_utils.h" +#include "op/op_factory.h" +#include "proto/task.pb.h" + +namespace ge { +namespace host_aicpu { +using domi::TaskDef; +using std::map; +using std::string; +using std::vector; + +Status HostAiCpuOpsKernelInfoStore::Initialize(const map &options) { + GELOGI("HostAiCpuOpsKernelInfoStore init start."); + OpInfo default_op_info = {.engine = kHostAiCpuEngineName, + .opKernelLib = kHostAiCpuOpKernelLibName, + .computeCost = 0, + .flagPartial = false, + .flagAsync = false, + .isAtomic = false}; + // Init op_info_map_ + auto all_ops = OpFactory::Instance().GetAllOps(); + for (auto &op : all_ops) { + op_info_map_[op] = default_op_info; + } + + GELOGI("HostAiCpuOpsKernelInfoStore inited success. op num=%zu", op_info_map_.size()); + + return SUCCESS; +} + +Status HostAiCpuOpsKernelInfoStore::Finalize() { + op_info_map_.clear(); + return SUCCESS; +} + +Status HostAiCpuOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) { + OpDescPtr op_desc = ge_node.GetOpDesc(); + if (op_desc == nullptr) { + GELOGE(FAILED, "CalcOpRunningParam failed, as op desc is null"); + return FAILED; + } + + bool is_shape_unknown = false; + if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) { + if (is_shape_unknown) { + GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str()); + return SUCCESS; + } + } + + const string name = ge_node.GetName(); + const string type = ge_node.GetType(); + GELOGD("Calc op[%s:%s] running param, output size=%zu.", name.c_str(), type.c_str(), op_desc->GetOutputsSize()); + + for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { + GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast(i)); + Format format = output_tensor.GetFormat(); + DataType data_type = output_tensor.GetDataType(); + + int64_t mem_size = 0; + // If mem size has been set, no need reset. + if ((TensorUtils::GetSize(output_tensor, mem_size) == GRAPH_SUCCESS) && (mem_size > 0)) { + GELOGD("Op[%s:%s] out[%zu] mem size has been set, no need calc again, format=%s, data_type=%s, mem_size=%ld.", + name.c_str(), type.c_str(), i, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str(), mem_size); + continue; + } + + int64_t output_mem_size = 0; + GeShape output_shape = output_tensor.GetShape(); + if ((TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size) != GRAPH_SUCCESS) || + (output_mem_size < 0)) { + GELOGE(FAILED, "Calc op[%s:%s] out[%zu] mem size failed, mem_size=%ld, format=%s, data_type=%s.", name.c_str(), + type.c_str(), i, output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return FAILED; + } + GELOGI("Calc op[%s:%s] out[%zu] mem size is %ld, format=%s, data_type=%s.", name.c_str(), type.c_str(), i, + output_mem_size, TypeUtils::FormatToSerialString(format).c_str(), + TypeUtils::DataTypeToSerialString(data_type).c_str()); + + TensorUtils::SetSize(output_tensor, output_mem_size); + if (op_desc->UpdateOutputDesc(static_cast(i), output_tensor) != GRAPH_SUCCESS) { + GELOGE(FAILED, "Update op[%s:%s] out[%zu] desc failed, format=%s, data_type=%s.", name.c_str(), type.c_str(), i, + TypeUtils::FormatToSerialString(format).c_str(), TypeUtils::DataTypeToSerialString(data_type).c_str()); + return FAILED; + } + } + GELOGD("Calc op[%s:%s] running param success.", name.c_str(), type.c_str()); + return SUCCESS; +} + +void HostAiCpuOpsKernelInfoStore::GetAllOpsKernelInfo(map &infos) const { infos = op_info_map_; } + +Status HostAiCpuOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector &tasks) { + // no need to generate device task + return SUCCESS; +} + +bool HostAiCpuOpsKernelInfoStore::CheckSupported(const OpDescPtr &op_desc, std::string &) const { + if (op_desc == nullptr) { + return false; + } + return op_info_map_.count(op_desc->GetType()) > 0; +} +} // namespace host_aicpu +} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h b/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h new file mode 100644 index 00000000..a4051b9b --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/host_aicpu_ops_kernel_info.h @@ -0,0 +1,88 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_HOST_AICPU_OPS_KERNEL_INFO_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_HOST_AICPU_OPS_KERNEL_INFO_H_ + +#include +#include +#include + +#include "common/opskernel/ops_kernel_info_store.h" + +namespace ge { +namespace host_aicpu { +class HostAiCpuOpsKernelInfoStore : public OpsKernelInfoStore { + public: + HostAiCpuOpsKernelInfoStore() {} + ~HostAiCpuOpsKernelInfoStore() override = default; + + /** + * Initialize related resources of the host aicpu kernelinfo store + * @return status whether this operation success + */ + Status Initialize(const std::map &options) override; + + /** + * Release related resources of the host aicpu kernel info store + * @return status whether this operation success + */ + Status Finalize() override; + + /** + * Check to see if an operator is fully supported or partially supported. + * @param op_desc OpDesc information + * @param reason unsupported reason + * @return bool value indicate whether the operator is fully supported + */ + bool CheckSupported(const OpDescPtr &op_desc, std::string &reason) const override; + + /** + * Returns the full operator information. + * @param infos reference of a map, + * contain operator's name and detailed information + */ + void GetAllOpsKernelInfo(std::map &infos) const override; + + /** + * Calc the running size of Operator, + * then GE will alloc the mem size from runtime + * @param ge_node Node information + * @return status whether this operation success + */ + Status CalcOpRunningParam(ge::Node &ge_node) override; + + /** + * call the runtime's interface to generate the task + * @param node Node information + * @param context run context info + * @return status whether this operation success + */ + Status GenerateTask(const ge::Node &ge_node, ge::RunContext &context, std::vector &tasks) override; + + HostAiCpuOpsKernelInfoStore(const HostAiCpuOpsKernelInfoStore &ops_kernel_store) = delete; + HostAiCpuOpsKernelInfoStore(const HostAiCpuOpsKernelInfoStore &&ops_kernel_store) = delete; + HostAiCpuOpsKernelInfoStore &operator=(const HostAiCpuOpsKernelInfoStore &ops_kernel_store) = delete; + HostAiCpuOpsKernelInfoStore &operator=(HostAiCpuOpsKernelInfoStore &&ops_kernel_store) = delete; + + private: + // store op name and OpInfo key-value pair + std::map op_info_map_; +}; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_HOST_AICPU_OPS_KERNEL_INFO_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc new file mode 100644 index 00000000..32f8ec24 --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.cc @@ -0,0 +1,51 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/ops_kernel_store/op/assign_op.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" + +namespace { +const size_t kAssignInputNum = 2; +const size_t kAssignRefInputIndex = 0; +const size_t kAssignValueInputIndex = 1; +const size_t kAssignRefOutputIndex = 0; +} // namespace + +namespace ge { +namespace host_aicpu { +Status AssignOp::Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, + std::vector &outputs) { + GELOGI("AssignOp [%s, %s] compute begin.", node_.GetName().c_str(), node_.GetType().c_str()); + if (inputs.size() != kAssignInputNum) { + GELOGE(PARAM_INVALID, "Number of input for AssignOp must be %zu.", kAssignInputNum); + return PARAM_INVALID; + } + auto &ref_input = inputs[kAssignRefInputIndex]; + const auto &value_input = inputs[kAssignValueInputIndex]; + ref_input->SetData(value_input->GetData().GetData(), value_input->GetData().GetSize()); + GeTensorPtr output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(kAssignRefOutputIndex), + value_input->GetData().GetData(), value_input->GetData().GetSize()); + GE_CHECK_NOTNULL(output_ptr); + outputs.push_back(output_ptr); + GELOGI("AssignOp [%s, %s] compute success.", node_.GetName().c_str(), node_.GetType().c_str()); + return SUCCESS; +} + +REGISTER_OP_CREATOR(Assign, AssignOp); +} // namespace host_aicpu +} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h new file mode 100644 index 00000000..caf9d4c9 --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/assign_op.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_ASSIGN_OP_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_ASSIGN_OP_H_ + +#include "host_aicpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_aicpu { +class AssignOp : public Op { + public: + AssignOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} + ~AssignOp() override = default; + AssignOp &operator=(const AssignOp &op) = delete; + AssignOp(const AssignOp &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, + std::vector &outputs) override; +}; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_ASSIGN_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.cc new file mode 100644 index 00000000..9dbd80e0 --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.cc @@ -0,0 +1,34 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/ops_kernel_store/op/host_op.h" +#include "framework/common/util.h" +#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" + +namespace ge { +namespace host_aicpu { +Status HostOp::Run() { + // no need to generate device task + return SUCCESS; +} + +REGISTER_OP_CREATOR(NoOp, HostOp); +REGISTER_OP_CREATOR(Variable, HostOp); +REGISTER_OP_CREATOR(Constant, HostOp); +REGISTER_OP_CREATOR(Assign, HostOp); +REGISTER_OP_CREATOR(RandomUniform, HostOp); +} // namespace host_aicpu +} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.h new file mode 100644 index 00000000..6655e620 --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/host_op.h @@ -0,0 +1,36 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ + +#include "host_aicpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_aicpu { +class HostOp : public Op { + public: + HostOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} + ~HostOp() override = default; + HostOp &operator=(const HostOp &op) = delete; + HostOp(const HostOp &op) = delete; + + Status Run() override; +}; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_HOST_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/op.h new file mode 100644 index 00000000..87c7993e --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ + +#include +#include +#include +#include "common/ge_inner_error_codes.h" +#include "common/opskernel/ops_kernel_info_types.h" +#include "graph/node.h" + +namespace ge { +namespace host_aicpu { +/** + * The base class for all op. + */ +class Op { + public: + Op(const Node &node, RunContext &run_context) : run_context_(run_context), node_(node) {} + virtual ~Op() = default; + virtual Status Run() = 0; + + protected: + const RunContext &run_context_; + const Node &node_; +}; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.cc new file mode 100644 index 00000000..ec376d8a --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" +#include "framework/common/debug/ge_log.h" +#include "common/ge_inner_error_codes.h" +#include "graph/op_desc.h" + +namespace ge { +namespace host_aicpu { +OpFactory &OpFactory::Instance() { + static OpFactory instance; + return instance; +} + +std::shared_ptr OpFactory::CreateOp(const Node &node, RunContext &run_context) { + auto iter = op_creator_map_.find(node.GetType()); + if (iter != op_creator_map_.end()) { + return iter->second(node, run_context); + } + + GELOGE(FAILED, "Not supported OP, type = %s, name = %s", node.GetType().c_str(), node.GetName().c_str()); + return nullptr; +} + +void OpFactory::RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func) { + if (func == nullptr) { + GELOGW("Func is NULL."); + return; + } + + auto iter = op_creator_map_.find(type); + if (iter != op_creator_map_.end()) { + GELOGW("%s creator already exist", type.c_str()); + return; + } + + op_creator_map_[type] = func; + all_ops_.emplace_back(type); +} +} // namespace host_aicpu +} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.h new file mode 100644 index 00000000..007bceaa --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/op_factory.h @@ -0,0 +1,94 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ + +#include +#include +#include +#include +#include +#include "common/ge/ge_util.h" +#include "host_aicpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_aicpu { +using OP_CREATOR_FUNC = std::function(const Node &, RunContext &)>; + +/** + * manage all the op, support create op. + */ +class OpFactory { + public: + static OpFactory &Instance(); + + /** + * @brief create Op. + * @param [in] node share ptr of node + * @param [in] run_context run context + * @return not nullptr success + * @return nullptr fail + */ + std::shared_ptr CreateOp(const Node &node, RunContext &run_context); + + /** + * @brief Register Op create function. + * @param [in] type Op type + * @param [in] func Op create func + */ + void RegisterCreator(const std::string &type, const OP_CREATOR_FUNC &func); + + const std::vector &GetAllOps() const { return all_ops_; } + + bool CheckSupported(const std::string &type) { return op_creator_map_.find(type) != op_creator_map_.end(); } + + OpFactory(const OpFactory &) = delete; + OpFactory &operator=(const OpFactory &) = delete; + OpFactory(OpFactory &&) = delete; + OpFactory &operator=(OpFactory &&) = delete; + + private: + OpFactory() = default; + ~OpFactory() = default; + + // the op creator function map + std::map op_creator_map_; + std::vector all_ops_; +}; + +class OpRegistrar { + public: + OpRegistrar(const std::string &type, const OP_CREATOR_FUNC &func) { + OpFactory::Instance().RegisterCreator(type, func); + } + ~OpRegistrar() = default; + + OpRegistrar(const OpRegistrar &) = delete; + OpRegistrar &operator=(const OpRegistrar &) = delete; + OpRegistrar(OpRegistrar &&) = delete; + OpRegistrar &operator=(OpRegistrar &&) = delete; +}; + +#define REGISTER_OP_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##Op(const Node &node, RunContext &run_context) { \ + return MakeShared(node, run_context); \ + } \ + OpRegistrar g_##type##Op_creator(#type, Creator_##type##Op) +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_OP_FACTORY_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc new file mode 100644 index 00000000..81768f7a --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.cc @@ -0,0 +1,104 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h" +#include +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/type_utils.h" +#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" + +namespace ge { +namespace host_aicpu { +Status RandomUniformOp::Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, + std::vector &outputs) { + GELOGI("RandomUniformOp [%s, %s] compute begin.", node_.GetName().c_str(), node_.GetType().c_str()); + int64_t seed = 0; + int64_t seed2 = 0; + (void)AttrUtils::GetInt(op_desc_ptr, "seed", seed); + (void)AttrUtils::GetInt(op_desc_ptr, "seed2", seed2); + DataType data_type = DT_UNDEFINED; + if (AttrUtils::GetDataType(op_desc_ptr, VAR_ATTR_DTYPE, data_type) != GRAPH_SUCCESS) { + GELOGE(PARAM_INVALID, "get attr VAR_ATTR_DTYPE failed"); + return PARAM_INVALID; + } + + switch (data_type) { + case DT_FLOAT16: + break; + case DT_FLOAT: + if (Generate(op_desc_ptr, seed, seed2, outputs) != SUCCESS) { + GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_FLOAT"); + return FAILED; + } + break; + case DT_DOUBLE: + if (Generate(op_desc_ptr, seed, seed2, outputs) != SUCCESS) { + GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_DOUBLE"); + return FAILED; + } + break; + default: + GELOGE(UNSUPPORTED, "Supported DataType for RandomUniformOp is DT_FLOAT16 / DT_FLOAT / DT_DOUBLE, but dtype=%s", + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return UNSUPPORTED; + } + + GELOGI("RandomUniformOp [%s, %s] compute success.", node_.GetName().c_str(), node_.GetType().c_str()); + return SUCCESS; +} + +template +Status RandomUniformOp::Generate(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, + std::vector &outputs) { + GE_CHECK_NOTNULL(op_desc_ptr); + // RandomUniformOp has and only has one output + int64_t data_num = op_desc_ptr->GetOutputDesc(0).GetShape().GetShapeSize(); + std::unique_ptr buf(new (std::nothrow) T[data_num]()); + if (buf == nullptr) { + GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast(sizeof(T) * data_num)); + return MEMALLOC_FAILED; + } + + int64_t final_seed; + if (seed == 0) { + if (seed2 == 0) { + std::random_device rd; + final_seed = rd(); + } else { + final_seed = seed2; + } + } else { + final_seed = seed; + } + std::mt19937_64 gen(final_seed); + std::uniform_real_distribution distribution(0, 1); + for (int64_t i = 0; i < data_num; i++) { + *(buf.get() + i) = distribution(gen); + } + + GeTensorPtr output = + MakeShared(op_desc_ptr->GetOutputDesc(0), reinterpret_cast(buf.get()), data_num * sizeof(T)); + GE_CHECK_NOTNULL(output); + outputs.emplace_back(output); + + return SUCCESS; +} + +REGISTER_OP_CREATOR(RandomUniform, RandomUniformOp); +} // namespace host_aicpu +} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h new file mode 100644 index 00000000..dfb2485f --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/random_uniform_op.h @@ -0,0 +1,45 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_RANDOM_UNIFORM_OP_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_RANDOM_UNIFORM_OP_H_ + +#include "host_aicpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_aicpu { +class RandomUniformOp : public Op { + public: + RandomUniformOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} + ~RandomUniformOp() override = default; + RandomUniformOp &operator=(const RandomUniformOp &op) = delete; + RandomUniformOp(const RandomUniformOp &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, + std::vector &outputs) override; + + private: + template + Status Generate(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, std::vector &outputs); +}; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_RANDOM_UNIFORM_OP_H_ diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc b/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc new file mode 100644 index 00000000..effa346b --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.cc @@ -0,0 +1,46 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "host_aicpu_engine/ops_kernel_store/op/variable_op.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "host_aicpu_engine/ops_kernel_store/op/op_factory.h" + +namespace { +const size_t kInputSize = 1; +} + +namespace ge { +namespace host_aicpu { +Status VariableOp::Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, + std::vector &outputs) { + GELOGI("VariableOp [%s, %s] compute begin.", node_.GetName().c_str(), node_.GetType().c_str()); + if (inputs.size() != kInputSize) { + GELOGE(PARAM_INVALID, "Number of input for VariableOp must be %zu.", kInputSize); + return PARAM_INVALID; + } + GeTensorPtr output_ptr = + MakeShared(op_desc_ptr->GetOutputDesc(0), inputs[0]->GetData().GetData(), inputs[0]->GetData().GetSize()); + GE_CHECK_NOTNULL(output_ptr); + outputs.push_back(output_ptr); + GELOGI("VariableOp [%s, %s] compute success.", node_.GetName().c_str(), node_.GetType().c_str()); + return SUCCESS; +} + +REGISTER_OP_CREATOR(Variable, VariableOp); +REGISTER_OP_CREATOR(Constant, VariableOp); +} // namespace host_aicpu +} // namespace ge diff --git a/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h b/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h new file mode 100644 index 00000000..b6570557 --- /dev/null +++ b/src/ge/host_aicpu_engine/ops_kernel_store/op/variable_op.h @@ -0,0 +1,41 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_VARIABLE_OP_H_ +#define GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_VARIABLE_OP_H_ + +#include "host_aicpu_engine/ops_kernel_store/op/op.h" + +namespace ge { +namespace host_aicpu { +class VariableOp : public Op { + public: + VariableOp(const Node &node, RunContext &run_context) : Op(node, run_context) {} + ~VariableOp() override = default; + VariableOp &operator=(const VariableOp &op) = delete; + VariableOp(const VariableOp &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(const ge::OpDescPtr &op_desc_ptr, const std::vector &inputs, + std::vector &outputs) override; +}; +} // namespace host_aicpu +} // namespace ge + +#endif // GE_HOST_AICPU_ENGINE_OPS_KERNEL_STORE_OP_VARIABLE_OP_H_ diff --git a/src/ge/host_kernels/add_kernel.cc b/src/ge/host_kernels/add_kernel.cc index 6d6a049c..afef1c37 100644 --- a/src/ge/host_kernels/add_kernel.cc +++ b/src/ge/host_kernels/add_kernel.cc @@ -133,25 +133,24 @@ Status AddKernel::BCastAdd(const OpDescPtr &op_desc_ptr, const std::vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Op_desc_ptr must not be null."); + GELOGW("Op_desc_ptr must not be null."); return PARAM_INVALID; } // check how many inputs if ((input.size() != kAddInputSize) || (op_desc_ptr->GetOutputsSize() != kAddOutputSize)) { - GELOGE(PARAM_INVALID, "The number of input for add must be %zu, output number must be %zu.", kAddInputSize, - kAddOutputSize); + GELOGW("The number of input for add must be %zu, output number must be %zu.", kAddInputSize, kAddOutputSize); return PARAM_INVALID; } // input vector elements must not be null if ((input[kAddFirstInput] == nullptr) || (input[kAddSecondInput] == nullptr)) { - GELOGE(PARAM_INVALID, "Input vector elements must not be null."); + GELOGW("Input vector elements must not be null."); return PARAM_INVALID; } // Inputs must have the same datatype. DataType data_type_0 = input[kAddFirstInput]->GetTensorDesc().GetDataType(); DataType data_type_1 = input[kAddSecondInput]->GetTensorDesc().GetDataType(); if (data_type_0 != data_type_1) { - GELOGE(PARAM_INVALID, "Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s", + GELOGW("Data type of inputs for add not matched, data_type_0:%s, data_type_1:%s", TypeUtils::DataTypeToSerialString(data_type_0).c_str(), TypeUtils::DataTypeToSerialString(data_type_1).c_str()); return PARAM_INVALID; @@ -192,7 +191,7 @@ Status AddKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector x2_dims; const auto &op_in_desc = op_desc_ptr->MutableInputDesc(0); GE_CHECK_NOTNULL(op_in_desc); - ; DataType data_type = op_in_desc->GetDataType(); bool result = (OpUtils::GetShapeDataFromConstTensor(input[0], data_type, x1_dims) == SUCCESS) && (OpUtils::GetShapeDataFromConstTensor(input[1], data_type, x2_dims) == SUCCESS); diff --git a/src/ge/host_kernels/concat_offset_kernel.cc b/src/ge/host_kernels/concat_offset_kernel.cc index 2e609d68..0a870949 100644 --- a/src/ge/host_kernels/concat_offset_kernel.cc +++ b/src/ge/host_kernels/concat_offset_kernel.cc @@ -41,7 +41,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vector(reinterpret_cast(input_0->GetData().data()))); // validate inputs if (static_cast(input.size()) != (N + kNumOne) || input.size() <= kConcatOffsetInputIndexOne) { - GELOGE(PARAM_INVALID, "The number of input for concat offset must be equal with %d, and must be more than one.", - (N + kNumOne)); + GELOGW("The number of input for concat offset must be equal with %d, and must be more than one.", (N + kNumOne)); return NOT_CHANGED; } @@ -59,7 +58,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetTensorDesc().GetShape(); int64_t output_size = output_shape.GetShapeSize(); if (concat_dim >= output_size) { - GELOGE(PARAM_INVALID, "Concat dim is biger than the size of output_shape."); + GELOGW("Concat dim is biger than the size of output_shape."); return NOT_CHANGED; } GELOGI("Output shape size is %ld", output_size); @@ -79,7 +78,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memeory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memeory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } @@ -87,7 +86,7 @@ Status ConcatOffsetKernel::Compute(const OpDescPtr op_desc_ptr, const vectorMutableTensorDesc().SetShape(output_shape); GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(buf.get()), static_cast(sizeof(DT_INT32) * output_size)) != GRAPH_SUCCESS, - GELOGE(INTERNAL_ERROR, "set data failed"); + GELOGW("set data failed"); return NOT_CHANGED); v_output.push_back(output_ptr); // caculate offset diff --git a/src/ge/host_kernels/dynamic_stitch_kernel.cc b/src/ge/host_kernels/dynamic_stitch_kernel.cc index c8a19e44..c1245535 100644 --- a/src/ge/host_kernels/dynamic_stitch_kernel.cc +++ b/src/ge/host_kernels/dynamic_stitch_kernel.cc @@ -63,11 +63,11 @@ Status DynamicStitchKernel::Compute(const OpDescPtr op_desc_ptr, const vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Input op_desc is nullptr."); + GELOGW("Input op_desc is nullptr."); return PARAM_INVALID; } if (op_desc_ptr->GetOutputsSize() == 0) { - GELOGE(PARAM_INVALID, "Current output_desc is empty."); + GELOGW("Current output_desc is empty."); return PARAM_INVALID; } // validate input @@ -78,7 +78,7 @@ Status DynamicStitchKernel::ValidateParams(const OpDescPtr &op_desc_ptr, const s } for (const auto &in : input) { if (in == nullptr) { - GELOGE(PARAM_INVALID, "input is nullptr."); + GELOGW("input is nullptr."); return PARAM_INVALID; } } @@ -150,7 +150,7 @@ Status DynamicStitchKernel::GenData(const vector &input, GeTen // 2.allocate memery for output std::unique_ptr buf(new (std::nothrow) uint8_t[allowance]); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buffer failed"); + GELOGW("new buffer failed"); return INTERNAL_ERROR; } // 3.copy data from input_data along with the sequence of input_indices @@ -164,7 +164,7 @@ Status DynamicStitchKernel::GenData(const vector &input, GeTen output_ptr->MutableTensorDesc().SetShape(merged_shape); Status ret = output_ptr->SetData(buf.get(), allowance); if (ret != GRAPH_SUCCESS) { - GELOGE(INTERNAL_ERROR, "set data failed"); + GELOGW("set data failed"); return NOT_CHANGED; } return SUCCESS; diff --git a/src/ge/host_kernels/empty_kernel.cc b/src/ge/host_kernels/empty_kernel.cc index 856caf50..a5e5fbcf 100644 --- a/src/ge/host_kernels/empty_kernel.cc +++ b/src/ge/host_kernels/empty_kernel.cc @@ -38,7 +38,7 @@ const size_t kShapeMaxDims = 1; } // namespace Status EmptyKernel::EmptyCheck(const OpDescPtr &op_desc_ptr, const std::vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); + GELOGW("Parameter's invalid, Input opDescPtr is nullptr."); return PARAM_INVALID; } // check input size @@ -46,20 +46,19 @@ Status EmptyKernel::EmptyCheck(const OpDescPtr &op_desc_ptr, const std::vectorGetAllInputsDesc().size() != kEmptyInputsSize) || (input.size() != kEmptyInputsSize) || (op_desc_ptr->GetAllOutputsDesc().size() != kEmptyOutputsSize)); if (size_check) { - GELOGE(PARAM_INVALID, "Input/Output size error. InDesc size:%zu, OutDesc size:%zu, in size:%zu ", + GELOGW("Input/Output size error. InDesc size:%zu, OutDesc size:%zu, in size:%zu ", op_desc_ptr->GetAllInputsDesc().size(), op_desc_ptr->GetAllOutputsDesc().size(), input.size()); return PARAM_INVALID; } if (input.at(kEmptyFirstInput) == nullptr) { - GELOGE(PARAM_INVALID, "Parameter's invalid, first input is nullptr."); + GELOGW("Parameter's invalid, first input is nullptr."); return PARAM_INVALID; } ConstGeTensorPtr shape = input.at(kEmptyFirstInput); // Check if the dimension is 1-D if (shape->GetTensorDesc().GetShape().GetDimNum() > kShapeMaxDims) { - GELOGE(PARAM_INVALID, "Check if the dimension is 1-D failed, dims:%zu", - shape->GetTensorDesc().GetShape().GetDimNum()); + GELOGW("Check if the dimension is 1-D failed, dims:%zu", shape->GetTensorDesc().GetShape().GetDimNum()); return PARAM_INVALID; } return SUCCESS; @@ -84,7 +83,7 @@ Status EmptyKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector(shape, shape_vec, total_data_size); } else { - GELOGE(PARAM_INVALID, "shape type must be DT_INT32 or DT_INT64."); + GELOGW("shape type must be DT_INT32 or DT_INT64."); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/expanddims_kernel.cc b/src/ge/host_kernels/expanddims_kernel.cc index 1d17ad48..15648573 100644 --- a/src/ge/host_kernels/expanddims_kernel.cc +++ b/src/ge/host_kernels/expanddims_kernel.cc @@ -66,7 +66,7 @@ Status ExpanddimsKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec auto output_tensor_desc = op_desc_ptr->GetOutputDesc(kExpandDimsIndexZero); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/floordiv_kernel.cc b/src/ge/host_kernels/floordiv_kernel.cc index 4175df92..05eded80 100644 --- a/src/ge/host_kernels/floordiv_kernel.cc +++ b/src/ge/host_kernels/floordiv_kernel.cc @@ -260,7 +260,7 @@ Status FloorDivKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/floormod_kernel.cc b/src/ge/host_kernels/floormod_kernel.cc index a8c16c9d..7ad746de 100644 --- a/src/ge/host_kernels/floormod_kernel.cc +++ b/src/ge/host_kernels/floormod_kernel.cc @@ -122,7 +122,7 @@ Status FloorModKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector(op_desc_ptr->GetOutputDesc(kFloorModFirstOutput)); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/gather_v2_kernel.cc b/src/ge/host_kernels/gather_v2_kernel.cc index c8cc3006..7413395a 100644 --- a/src/ge/host_kernels/gather_v2_kernel.cc +++ b/src/ge/host_kernels/gather_v2_kernel.cc @@ -274,7 +274,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr auto indices_ptr = const_cast(reinterpret_cast(indices_tensor_ptr->GetData().data())); for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) { if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) { - GELOGE(NOT_CHANGED, "indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); + GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); return NOT_CHANGED; } indicates_.push_back(*(indices_ptr + i)); @@ -284,7 +284,7 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr auto indices_ptr = const_cast(reinterpret_cast(indices_tensor_ptr->GetData().data())); for (int64_t i = 0; i < indices_shape.GetShapeSize(); i++) { if (*(indices_ptr + i) < 0 || *(indices_ptr + i) >= x_shape.GetDim(axis)) { - GELOGE(NOT_CHANGED, "indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); + GELOGW("indices %ld value is not in range [0, %ld)", i, x_shape.GetDim(axis)); return NOT_CHANGED; } indicates_.push_back(*(indices_ptr + i)); @@ -296,19 +296,19 @@ Status GatherV2Kernel::SaveIndicesByDataType(ConstGeTensorPtr indices_tensor_ptr Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vector &input, vector &v_output) const { if (op_desc_ptr == nullptr) { - GELOGE(NOT_CHANGED, "input opdesc is nullptr."); + GELOGW("input opdesc is nullptr."); return NOT_CHANGED; } if (input.size() != kGatherV2InpotNum) { - GELOGE(NOT_CHANGED, "The number of input for GatherV2 must be %zu.", kGatherV2InpotNum); + GELOGW("The number of input for GatherV2 must be %zu.", kGatherV2InpotNum); return NOT_CHANGED; } bool is_null = (input[kGatherV2InputIndexZero] == nullptr || input[kGatherV2InputIndexOne] == nullptr || input[kGatherV2InputIndexTwo] == nullptr); if (is_null) { - GELOGE(NOT_CHANGED, "some input is nullptr."); + GELOGW("some input is nullptr."); return NOT_CHANGED; } ConstGeTensorPtr tensor0 = input.at(kGatherV2InputIndexZero); @@ -318,7 +318,7 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vectorGetData().size() == 0) || (tensor1->GetData().size() == 0) || (tensor2->GetData().size() == 0)); if (size_is_zero) { - GELOGE(NOT_CHANGED, "some input size is zero."); + GELOGW("some input size is zero."); return NOT_CHANGED; } @@ -326,13 +326,13 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vectorGetTensorDesc().GetShape(); // axis must be scalar if (axis_shape.GetDimNum() != 0) { - GELOGE(NOT_CHANGED, "axis must be scalar but its shape is %zu", axis_shape.GetDimNum()); + GELOGW("axis must be scalar but its shape is %zu", axis_shape.GetDimNum()); return NOT_CHANGED; } auto axis_data_type = tensor2->GetTensorDesc().GetDataType(); bool is_valid_axis_data_type = axis_data_type == DT_INT32 || axis_data_type == DT_INT64; if (!is_valid_axis_data_type) { - GELOGE(NOT_CHANGED, "axis datatype must be DT_INT32 or DT_INT64"); + GELOGW("axis datatype must be DT_INT32 or DT_INT64"); return NOT_CHANGED; } @@ -340,11 +340,11 @@ Status GatherV2Kernel::Check(const OpDescPtr &op_desc_ptr, const vectorGetTensorDesc().GetDataType(); bool is_valid_indices_data_type = indices_data_type == DT_INT32 || indices_data_type == DT_INT64; if (!is_valid_indices_data_type) { - GELOGE(NOT_CHANGED, "indices datatype must be DT_INT32 or DT_INT64"); + GELOGW("indices datatype must be DT_INT32 or DT_INT64"); return NOT_CHANGED; } if (indices_shape.GetDimNum() > kMaxIndicatesDims) { - GELOGE(NOT_CHANGED, "indices input only support 0 or 1 dims"); + GELOGW("indices input only support 0 or 1 dims"); return NOT_CHANGED; } return SUCCESS; @@ -372,7 +372,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vectorGetName().c_str()); @@ -390,13 +390,13 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector= 0 ? axis : axis + x_shape.GetDimNum(); // check axis value if (axis < 0 || (axis + 1) > static_cast(x_shape.GetDimNum())) { - GELOGE(NOT_CHANGED, "axis is invalid"); + GELOGW("axis is invalid"); return NOT_CHANGED; } auto indices_data_type = tensor1->GetTensorDesc().GetDataType(); ret = SaveIndicesByDataType(tensor1, x_shape, indices_shape, indices_data_type, static_cast(axis)); if (ret != SUCCESS) { - GELOGE(NOT_CHANGED, "Save indeices by data type failed!"); + GELOGW("Save indeices by data type failed!"); return ret; } @@ -420,7 +420,7 @@ Status GatherV2Kernel::Compute(const OpDescPtr op_desc_ptr, const vector(op_desc_ptr->GetOutputDesc(0)); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } output_ptr->MutableTensorDesc().SetShape(GeShape(y_shape)); diff --git a/src/ge/host_kernels/identity_kernel.cc b/src/ge/host_kernels/identity_kernel.cc new file mode 100644 index 00000000..16bd3138 --- /dev/null +++ b/src/ge/host_kernels/identity_kernel.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "identity_kernel.h" +#include "inc/kernel_factory.h" + +namespace { +constexpr uint32_t kInputDescIndex = 0; +constexpr uint32_t kOutputDescIndex = 0; +} // namespace + +namespace ge { +Status IdentityKernel::Compute(const ge::OpDescPtr op_desc, const std::vector &input, + std::vector &v_output) { + if (op_desc == nullptr) { + GELOGE(PARAM_INVALID, "IdentityKernel op_desc is null."); + return NOT_CHANGED; + } + if (input.empty()) { + GELOGE(PARAM_INVALID, "Node [%s] inputs is empty.", op_desc->GetName().c_str()); + return NOT_CHANGED; + } + if (op_desc->GetOutputsSize() < 1) { + GELOGE(PARAM_INVALID, "Node [%s] output is empty.", op_desc->GetName().c_str()); + return NOT_CHANGED; + } + GELOGD("IdentityKernel in: node[%s]", op_desc->GetName().c_str()); + + auto out_tensor_desc = op_desc->GetOutputDesc(kOutputDescIndex); + GeTensorPtr output_ptr = MakeShared(out_tensor_desc); + if (output_ptr == nullptr) { + GELOGE(OUT_OF_MEMORY, "Node [%s] make shared failed.", op_desc->GetName().c_str()); + return OUT_OF_MEMORY; + } + auto input_tensor_ptr = input.at(kInputDescIndex); + if (input_tensor_ptr == nullptr) { + GELOGE(PARAM_INVALID, "Node [%s] get input failed.", op_desc->GetName().c_str()); + return NOT_CHANGED; + } + if (output_ptr->SetData(input_tensor_ptr->GetData()) != GRAPH_SUCCESS) { + GELOGW("Compute: SetData failed"); + return NOT_CHANGED; + } + v_output.emplace_back(output_ptr); + GELOGD("IdentityKernel success: node[%s]", op_desc->GetName().c_str()); + + return SUCCESS; +} +REGISTER_KERNEL(IDENTITY, IdentityKernel); +} // namespace ge diff --git a/src/ge/host_kernels/identity_kernel.h b/src/ge/host_kernels/identity_kernel.h new file mode 100644 index 00000000..2164d880 --- /dev/null +++ b/src/ge/host_kernels/identity_kernel.h @@ -0,0 +1,31 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ +#define GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ + +#include "inc/kernel.h" +#include + +namespace ge { +class IdentityKernel : public Kernel { + public: + Status Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input, + std::vector &v_output) override; +}; +} // namespace ge + +#endif // GE_GRAPH_PASSES_FOLDING_KERNEL_IDENTITY_KERNEL_H_ diff --git a/src/ge/host_kernels/pack_kernel.cc b/src/ge/host_kernels/pack_kernel.cc index f3f64a6c..9b62a582 100644 --- a/src/ge/host_kernels/pack_kernel.cc +++ b/src/ge/host_kernels/pack_kernel.cc @@ -63,7 +63,7 @@ Status PackKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector &input) { if (op_desc_ptr == nullptr) { - GELOGE(PARAM_INVALID, "input opdesc is nullptr."); + GELOGW("input opdesc is nullptr."); return PARAM_INVALID; } if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) { @@ -71,16 +71,15 @@ Status PackKernel::ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr, GELOGD("Attr %s is not set, default value %ld is used.", PACK_ATTR_NAME_NUM.c_str(), n_); } if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) { - GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); + GELOGW("Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); return PARAM_INVALID; } if (input.empty()) { - GELOGE(PARAM_INVALID, "The number of input for Pack should be %ld, in fact it is %zu ", n_, input.size()); + GELOGW("The number of input for Pack should be %ld, in fact it is %zu ", n_, input.size()); return NOT_CHANGED; } if (input.size() != static_cast(n_)) { - GELOGE(PARAM_INVALID, "The number of input for Pack should be %d, in fact it is %ld ", static_cast(n_), - input.size()); + GELOGW("The number of input for Pack should be %d, in fact it is %ld ", static_cast(n_), input.size()); return PARAM_INVALID; } data_type_ = op_desc_ptr->GetInputDesc(0).GetDataType(); diff --git a/src/ge/host_kernels/permute_kernel.cc b/src/ge/host_kernels/permute_kernel.cc index 8263d19f..24bed54d 100644 --- a/src/ge/host_kernels/permute_kernel.cc +++ b/src/ge/host_kernels/permute_kernel.cc @@ -110,14 +110,14 @@ Status PermuteKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().data(); formats::TransResult trans_result; auto ret = formats::TransposeWithShapeCheck(src_data, src_shape, data_shape, src_data_type, perm_list, trans_result); if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", + GELOGW("Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), formats::ShapeToString(perm_list).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); diff --git a/src/ge/host_kernels/rank_kernel.cc b/src/ge/host_kernels/rank_kernel.cc index faaf16b8..7fb92039 100644 --- a/src/ge/host_kernels/rank_kernel.cc +++ b/src/ge/host_kernels/rank_kernel.cc @@ -19,6 +19,7 @@ #include #include +#include "graph/types.h" #include "common/ge_inner_error_codes.h" #include "common/op/ge_op_utils.h" #include "framework/common/debug/ge_log.h" @@ -46,10 +47,13 @@ Status RankKernel::Compute(const NodePtr &node, std::vector &v_outp const auto &input_shape = op_desc->MutableInputDesc(kRankDataInputIndex); GE_CHECK_NOTNULL(input_shape); + if (input_shape->GetShape().GetDims() == UNKNOWN_RANK) { + return NOT_CHANGED; + } auto ndims = input_shape->GetShape().GetDimNum(); GeTensorDesc tensor_desc(op_desc->GetOutputDesc(0)); GeTensorPtr output_ptr; - output_ptr = MakeShared(tensor_desc, reinterpret_cast(&ndims), sizeof(ndims)); + output_ptr = MakeShared(tensor_desc, reinterpret_cast(&ndims), GetSizeByDataType(DT_INT32)); if (output_ptr == nullptr) { GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed"); return MEMALLOC_FAILED; diff --git a/src/ge/host_kernels/reduce_prod_kernel.cc b/src/ge/host_kernels/reduce_prod_kernel.cc index 479b50ab..739d4b9f 100644 --- a/src/ge/host_kernels/reduce_prod_kernel.cc +++ b/src/ge/host_kernels/reduce_prod_kernel.cc @@ -51,7 +51,7 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } - GELOGE(PARAM_INVALID, "Unexpected ReduceProd node, node input size: %zu, node name: %s", input.size(), + GELOGW("Unexpected ReduceProd node, node input size: %zu, node name: %s", input.size(), op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } @@ -60,13 +60,13 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, GE_CHECK_NOTNULL(data_tensor); GE_CHECK_NOTNULL(axis_tensor); if (axis_tensor->GetTensorDesc().GetShape().GetDimNum() > kReduceProdMaxAxisRank) { - GELOGE(PARAM_INVALID, "Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); + GELOGW("Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } DataType data_type = data_tensor->GetTensorDesc().GetDataType(); if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", + GELOGW("ReduceProdKernel data type %s not support, node name: %s", TypeUtils::DataTypeToSerialString(data_type).c_str(), op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } @@ -83,7 +83,7 @@ Status ReduceProdKernel::AxisCal(const std::vector &input) int32_t *axis = const_cast(reinterpret_cast(axis_tensor->GetData().GetData())); GE_CHECK_NOTNULL(axis); if (static_cast(*axis) >= data_dim_size) { - GELOGE(PARAM_INVALID, "axis is out of rank of data_dims, axis is %d.", *axis); + GELOGW("axis is out of rank of data_dims, axis is %d.", *axis); return PARAM_INVALID; } axis_dim_ = data_dims[static_cast(*axis)]; @@ -98,13 +98,13 @@ Status ReduceProdKernel::AxisCal(const std::vector &input) // data_dims is the vector of dims, element in data_dims isn't negative. if (axis_appear) { if (data_dims[i] != 0 && end_dim_ > (INT64_MAX / data_dims[i])) { - GELOGE(INTERNAL_ERROR, "Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", end_dim_, data_dims[i]); + GELOGW("Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", end_dim_, data_dims[i]); return INTERNAL_ERROR; } end_dim_ *= data_dims[i]; } else { if (data_dims[i] != 0 && head_dim_ > (INT64_MAX / data_dims[i])) { - GELOGE(INTERNAL_ERROR, "Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", head_dim_, data_dims[i]); + GELOGW("Product is overflow. multiplier 1: %ld. multiplier 2: %ld.", head_dim_, data_dims[i]); return INTERNAL_ERROR; } head_dim_ *= data_dims[i]; @@ -122,7 +122,7 @@ Status ReduceProdKernel::DataCal(const std::vector &input, size_t data_num = data_tensor->GetData().size() / sizeof(int32_t); unique_ptr buf(new (std::nothrow) int32_t[data_num]()); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buf failed"); + GELOGW("new buf failed"); return INTERNAL_ERROR; } @@ -190,12 +190,12 @@ Status ReduceProdKernel::ComputeNoAxis(const ge::OpDescPtr &op_desc_ptr, const s ConstGeTensorPtr data_tensor = input.at(kReduceProdDataIndex); GE_CHECK_NOTNULL(data_tensor); if (data_tensor->GetData().size() == 0) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); + GELOGW("ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } DataType data_type = data_tensor->GetTensorDesc().GetDataType(); if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { - GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", + GELOGW("ReduceProdKernel data type %s not support, node name: %s", TypeUtils::DataTypeToSerialString(data_type).c_str(), op_desc_ptr->GetName().c_str()); return PARAM_INVALID; } @@ -206,7 +206,7 @@ Status ReduceProdKernel::ComputeNoAxis(const ge::OpDescPtr &op_desc_ptr, const s size_t data_num = data_tensor->GetData().size() / sizeof(int32_t); unique_ptr buf(new (std::nothrow) int32_t[data_num]()); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buf failed"); + GELOGW("new buf failed"); return INTERNAL_ERROR; } @@ -235,7 +235,7 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec GELOGI("ReduceProdKernel in."); Status ret = ReduceProdCheck(op_desc_ptr, input); if (ret != SUCCESS && ret != NOT_CHANGED) { - GELOGE(PARAM_INVALID, "ReduceProdKernel input is invalid, failed to fold node."); + GELOGW("ReduceProdKernel input is invalid, failed to fold node."); return NOT_CHANGED; } @@ -243,7 +243,7 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec auto output_tensor_desc = op_desc_ptr->GetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/reformat_kernel.cc b/src/ge/host_kernels/reformat_kernel.cc index 33a13599..c2dd1e17 100644 --- a/src/ge/host_kernels/reformat_kernel.cc +++ b/src/ge/host_kernels/reformat_kernel.cc @@ -56,7 +56,7 @@ Status ReFormatKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetTensorDesc().GetShape()).c_str()); return NOT_CHANGED; } GeTensorPtr output_ptr = MakeShared(op_desc_ptr->GetOutputDesc(kReformatFirstOutput)); if (output_ptr == nullptr) { - GELOGE(INTERNAL_ERROR, "Create shared ptr for GeTensor failed"); + GELOGW("Create shared ptr for GeTensor failed"); return NOT_CHANGED; } - GE_IF_BOOL_EXEC(output_ptr->SetData(input.at(0)->GetData()) != GRAPH_SUCCESS, - GELOGE(INTERNAL_ERROR, "set data failed"); + GE_IF_BOOL_EXEC(output_ptr->SetData(input.at(0)->GetData()) != GRAPH_SUCCESS, GELOGW("set data failed"); return NOT_CHANGED); v_output.emplace_back(output_ptr); GELOGD("ReFormatKernel success."); diff --git a/src/ge/host_kernels/reshape_kernel.cc b/src/ge/host_kernels/reshape_kernel.cc index 906624d2..dc7e4bb8 100644 --- a/src/ge/host_kernels/reshape_kernel.cc +++ b/src/ge/host_kernels/reshape_kernel.cc @@ -67,7 +67,7 @@ Status ReshapeKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vector auto output_tensor_desc = op_desc_ptr->GetOutputDesc(kOutputDescFirstIndex); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/rsqrt_kernel.cc b/src/ge/host_kernels/rsqrt_kernel.cc index 3e14fd5f..56972d23 100644 --- a/src/ge/host_kernels/rsqrt_kernel.cc +++ b/src/ge/host_kernels/rsqrt_kernel.cc @@ -64,7 +64,7 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector 0) { unique_ptr buf(new (std::nothrow) float[data_count]()); if (buf == nullptr) { - GELOGE(MEMALLOC_FAILED, "new buf failed"); + GELOGW("new buf failed"); return NOT_CHANGED; } @@ -81,13 +81,13 @@ Status RsqrtKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("MakeShared GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } output_ptr->MutableTensorDesc().SetDataType(DT_FLOAT); GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(buf.get()), data_size) != GRAPH_SUCCESS, - GELOGE(INTERNAL_ERROR, "set data failed"); + GELOGW("set data failed"); return NOT_CHANGED); output_ptr->MutableTensorDesc().SetShape(x_shape); v_output.push_back(output_ptr); diff --git a/src/ge/host_kernels/slice_d_kernel.cc b/src/ge/host_kernels/slice_d_kernel.cc index ad0a1675..3b8fd0a0 100644 --- a/src/ge/host_kernels/slice_d_kernel.cc +++ b/src/ge/host_kernels/slice_d_kernel.cc @@ -129,7 +129,7 @@ Status SliceDKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); + GELOGW("Failed to fold node %s, out of memory", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } @@ -143,8 +143,14 @@ Status SliceDKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector(const_cast(x_tensor->GetData().data())); int64_t x_data_size = x_tensor->GetTensorDesc().GetShape().GetShapeSize(); - Status ret = OpUtils::SetOutputSliceData(data, x_data_size, x_data_type, x_dims, begin_list, size_list, - output_ptr.get(), stride_list); + + Status ret = CheckOutputDims(size_list, op_desc_ptr); + if (ret != SUCCESS) { + return ret; + } + + ret = OpUtils::SetOutputSliceData(data, x_data_size, x_data_type, x_dims, begin_list, size_list, output_ptr.get(), + stride_list); if (ret != SUCCESS) { GELOGW("Set output data of SliceD failed."); return NOT_CHANGED; @@ -155,5 +161,16 @@ Status SliceDKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector &output_dims, const OpDescPtr attr) { + // check dim not all less than 0 + for (auto dim : output_dims) { + if (dim > 0) { + return SUCCESS; + } + } + GELOGW("all output dim <=0, can't be processed. op_name : %s", attr->GetName().c_str()); + return NOT_CHANGED; +} + REGISTER_KERNEL(SLICED, SliceDKernel); } // namespace ge diff --git a/src/ge/host_kernels/slice_d_kernel.h b/src/ge/host_kernels/slice_d_kernel.h index 9fe35352..90ef9b8b 100644 --- a/src/ge/host_kernels/slice_d_kernel.h +++ b/src/ge/host_kernels/slice_d_kernel.h @@ -29,6 +29,7 @@ class SliceDKernel : public Kernel { private: Status SliceDCheck(const OpDescPtr &op_desc_ptr, const std::vector &input, std::vector &begin_list, std::vector &size_list); + Status CheckOutputDims(const std::vector &output_dims, const OpDescPtr attr); }; } // namespace ge diff --git a/src/ge/host_kernels/slice_kernel.cc b/src/ge/host_kernels/slice_kernel.cc index 1d7d90c2..5f72fc49 100644 --- a/src/ge/host_kernels/slice_kernel.cc +++ b/src/ge/host_kernels/slice_kernel.cc @@ -21,8 +21,8 @@ #include "common/types.h" #include "common/util.h" #include "framework/common/debug/ge_log.h" -#include "host_kernels/kernel_utils.h" #include "graph/utils/type_utils.h" +#include "host_kernels/kernel_utils.h" #include "inc/kernel_factory.h" namespace ge { diff --git a/src/ge/host_kernels/ssd_prior_box_kernel.cc b/src/ge/host_kernels/ssd_prior_box_kernel.cc index c874d732..9de5a08d 100644 --- a/src/ge/host_kernels/ssd_prior_box_kernel.cc +++ b/src/ge/host_kernels/ssd_prior_box_kernel.cc @@ -365,7 +365,7 @@ Status SsdPriorboxKernel::Compute(const NodePtr &node, std::vector // make TensorDesc GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(INTERNAL_ERROR, "Create shared ptr for GeTensor failed"); + GELOGW("Create shared ptr for GeTensor failed"); return NOT_CHANGED; } GE_IF_BOOL_EXEC(output_ptr->SetData(reinterpret_cast(output_data.get()), diff --git a/src/ge/host_kernels/strided_slice_kernel.cc b/src/ge/host_kernels/strided_slice_kernel.cc index 0d70a36a..6a9a558c 100644 --- a/src/ge/host_kernels/strided_slice_kernel.cc +++ b/src/ge/host_kernels/strided_slice_kernel.cc @@ -46,31 +46,31 @@ Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr, const std::vec int64_t shrink_axis_mask = 0; if (attr == nullptr) { - GELOGE(PARAM_INVALID, "input opdescptr is nullptr."); + GELOGW("input opdescptr is nullptr."); return PARAM_INVALID; } if (input.size() != kStridedSliceInputSize) { - GELOGE(PARAM_INVALID, "The number of input for strided slice must be %zu.", kStridedSliceInputSize); + GELOGW("The number of input for strided slice must be %zu.", kStridedSliceInputSize); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_BEGIN_MASK, begin_mask)) { - GELOGE(PARAM_INVALID, "get begin_mask attr failed."); + GELOGW("get begin_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_END_MASK, end_mask)) { - GELOGE(PARAM_INVALID, "get end_mask attr failed."); + GELOGW("get end_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_ELLIPSIS_MASK, ellipsis_mask)) { - GELOGE(PARAM_INVALID, "get ellipsis_mask attr failed."); + GELOGW("get ellipsis_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_NEW_AXIS_MASK, new_axis_mask)) { - GELOGE(PARAM_INVALID, "get new_axis_mask attr failed."); + GELOGW("get new_axis_mask attr failed."); return PARAM_INVALID; } if (!AttrUtils::GetInt(attr, STRIDE_SLICE_ATTR_SHRINK_AXIS_MASK, shrink_axis_mask)) { - GELOGE(PARAM_INVALID, "get shrink_axis_mask attr failed."); + GELOGW("get shrink_axis_mask attr failed."); return PARAM_INVALID; } if ((ellipsis_mask != 0) || (new_axis_mask != 0)) { @@ -98,7 +98,7 @@ Status StridedSliceKernel::CheckAndGetAttr(const OpDescPtr &attr, const std::vec ConstGeTensorPtr weight2 = input[kStridedSliceInputIndex2]; ConstGeTensorPtr weight3 = input[kStridedSliceInputIndex3]; if (CheckWeight(weight0, weight1, weight2, weight3) != SUCCESS) { - GELOGE(PARAM_INVALID, "Check And Get Attr failed."); + GELOGW("Check And Get Attr failed."); return PARAM_INVALID; } @@ -168,6 +168,17 @@ void StridedSliceKernel::GetOutputDims(uint32_t dims_size, const std::vector &output_dims, const OpDescPtr attr) { + // check dim not all less than 0 + for (auto dim : output_dims) { + if (dim > 0) { + return SUCCESS; + } + } + GELOGW("all output dim <=0, can't be processed. op_name : %s", attr->GetName().c_str()); + return NOT_CHANGED; +} + Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector &input, vector &v_output) { GELOGI("StridedSliceKernel in."); @@ -191,7 +202,7 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vector(weight2->GetData().data()); const int32_t *stride = reinterpret_cast(weight3->GetData().data()); if ((begin == nullptr) || (end == nullptr) || (stride == nullptr)) { - GELOGE(PARAM_INVALID, "input weight tensor is nullptr."); + GELOGW("input weight tensor is nullptr."); return NOT_CHANGED; } @@ -237,16 +248,22 @@ Status StridedSliceKernel::Compute(const ge::OpDescPtr attr, const std::vectorGetOutputDesc(0); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "MakeShared GeTensor failed, node name %s.", attr->GetName().c_str()); + GELOGW("MakeShared GeTensor failed, node name %s.", attr->GetName().c_str()); return NOT_CHANGED; } void *data = reinterpret_cast(const_cast(weight0->GetData().data())); GE_CHECK_NOTNULL(data); + + ret = CheckOutputDims(output_dims, attr); + if (ret != SUCCESS) { + return ret; + } + ret = OpUtils::SetOutputSliceData(data, static_cast(data_size), args.data_type, input_dims, begin_vec, output_dims, output_ptr.get(), stride_vec); if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "SetOutputSliceData failed."); + GELOGW("SetOutputSliceData failed."); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/strided_slice_kernel.h b/src/ge/host_kernels/strided_slice_kernel.h index e569b2d0..0ba3afbd 100644 --- a/src/ge/host_kernels/strided_slice_kernel.h +++ b/src/ge/host_kernels/strided_slice_kernel.h @@ -44,6 +44,7 @@ class StridedSliceKernel : public Kernel { int32_t &end_i, int32_t &dim_i) const; void GetOutputDims(uint32_t dims_size, const std::vector &output_dims, const Attr &args, vector &v_dims); + Status CheckOutputDims(const std::vector &output_dims, const OpDescPtr attr); }; } // namespace ge #endif // GE_GRAPH_PASSES_FOLDING_KERNEL_STRIDED_SLICE_KERNEL_H_ diff --git a/src/ge/host_kernels/sub_kernel.cc b/src/ge/host_kernels/sub_kernel.cc index ed1e5808..70a14c9f 100644 --- a/src/ge/host_kernels/sub_kernel.cc +++ b/src/ge/host_kernels/sub_kernel.cc @@ -162,7 +162,7 @@ Status SubKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vectorGetOutputDesc(kSubFirstOutput); GeTensorPtr output_ptr = MakeShared(output_tensor_desc); if (output_ptr == nullptr) { - GELOGE(MEMALLOC_FAILED, "make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); + GELOGW("make_shared ge::GeTensor failed, node name %s.", op_desc_ptr->GetName().c_str()); return NOT_CHANGED; } diff --git a/src/ge/host_kernels/transdata_kernel.cc b/src/ge/host_kernels/transdata_kernel.cc index 5fe44fe4..c5c9da6e 100644 --- a/src/ge/host_kernels/transdata_kernel.cc +++ b/src/ge/host_kernels/transdata_kernel.cc @@ -113,7 +113,7 @@ Status TransdataKernel::Compute(const OpDescPtr op_desc_ptr, const std::vectorGetData().data(); formats::TransResult trans_result; auto ret = formats::TransposeWithShapeCheck(src_data, src_shape, data_shape, src_data_type, perm_list, trans_result); if (ret != SUCCESS) { - GELOGE(INTERNAL_ERROR, "Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", + GELOGW("Failed to Transpose from %s to %s, shape %s to %s, perm_list %s, data type %s", TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(data_shape).c_str(), formats::ShapeToString(perm_list).c_str(), TypeUtils::DataTypeToSerialString(src_data_type).c_str()); diff --git a/src/ge/hybrid/common/npu_memory_allocator.cc b/src/ge/hybrid/common/npu_memory_allocator.cc index f432318b..1908725f 100644 --- a/src/ge/hybrid/common/npu_memory_allocator.cc +++ b/src/ge/hybrid/common/npu_memory_allocator.cc @@ -25,6 +25,11 @@ namespace hybrid { std::map> NpuMemoryAllocator::allocators_; std::mutex NpuMemoryAllocator::mu_; +AllocationAttr::AllocationAttr(int padding, void *try_reuse_addr) + : padding_(padding), try_reuse_addr_(try_reuse_addr) {} +AllocationAttr::AllocationAttr(int padding) : AllocationAttr(padding, nullptr) {} +AllocationAttr::AllocationAttr(void *try_reuse_addr) : AllocationAttr(0, try_reuse_addr) {} + NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { int32_t device_id = 0; if (rtGetDevice(&device_id) != RT_ERROR_NONE) { @@ -38,15 +43,26 @@ NpuMemoryAllocator *NpuMemoryAllocator::GetAllocator() { NpuMemoryAllocator::NpuMemoryAllocator(uint32_t device_id) : device_id_(device_id) {} -void *NpuMemoryAllocator::Allocate(std::size_t size, void *try_reuse_addr) { - void *buffer = - MemManager::CachingInstance(RT_MEMORY_HBM).Malloc(size, reinterpret_cast(try_reuse_addr), device_id_); +void *NpuMemoryAllocator::Allocate(std::size_t size, AllocationAttr *attr) { + void *try_reuse_addr = nullptr; + size_t allocate_size = size; + if (attr != nullptr) { + try_reuse_addr = attr->try_reuse_addr_; + if (attr->padding_ != 0) { + // padding up to multiple of attr->padding, and add extra attr->padding_ + allocate_size = (size + 2 * attr->padding_ - 1) / attr->padding_ * attr->padding_; + GELOGD("Padding size %ld by %d. final size = %zu.", size, attr->padding_, allocate_size); + } + } + + void *buffer = MemManager::CachingInstance(RT_MEMORY_HBM) + .Malloc(allocate_size, reinterpret_cast(try_reuse_addr), device_id_); if (buffer == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to malloc memory, device_id = %u, size = %zu", device_id_, size); + GELOGE(MEMALLOC_FAILED, "Failed to malloc memory, device_id = %u, size = %zu", device_id_, allocate_size); return nullptr; } - GELOGI("Allocating buffer of size %u successfully. device_id = %u, address = %p", size, device_id_, buffer); + GELOGI("Allocating buffer of size %zu successfully. device_id = %u, address = %p", allocate_size, device_id_, buffer); return buffer; } diff --git a/src/ge/hybrid/common/npu_memory_allocator.h b/src/ge/hybrid/common/npu_memory_allocator.h index 8cfeafa6..a9744540 100644 --- a/src/ge/hybrid/common/npu_memory_allocator.h +++ b/src/ge/hybrid/common/npu_memory_allocator.h @@ -26,16 +26,35 @@ namespace ge { namespace hybrid { +class AllocationAttr { + public: + explicit AllocationAttr(int padding); + explicit AllocationAttr(void *try_reuse_addr); + AllocationAttr(int padding, void *try_reuse_addr); + ~AllocationAttr() = default; + + private: + friend class NpuMemoryAllocator; + int padding_ = 0; + void *try_reuse_addr_ = nullptr; +}; + class NpuMemoryAllocator { public: ~NpuMemoryAllocator() = default; static NpuMemoryAllocator *GetAllocator(uint32_t device_id); static NpuMemoryAllocator *GetAllocator(); static void DestroyAllocator(); + static AllocationAttr *AttrWithDefaultPadding() { + static AllocationAttr attr(kDefaultPadding, nullptr); + return &attr; + } - void *Allocate(std::size_t size, void *try_reuse_addr = nullptr); + void *Allocate(std::size_t size, AllocationAttr *attr = nullptr); void Deallocate(void *data); + static constexpr int kDefaultPadding = 32; + private: explicit NpuMemoryAllocator(uint32_t device_id); uint32_t device_id_; diff --git a/src/ge/hybrid/common/tensor_value.cc b/src/ge/hybrid/common/tensor_value.cc index 9544e03a..929d3c87 100644 --- a/src/ge/hybrid/common/tensor_value.cc +++ b/src/ge/hybrid/common/tensor_value.cc @@ -24,7 +24,7 @@ namespace hybrid { TensorBuffer::TensorBuffer(NpuMemoryAllocator *allocator, void *buffer, size_t size) : allocator_(allocator), buffer_(buffer), size_(size) {} -std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator, size_t size) { +std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator, size_t size, AllocationAttr *attr) { void *buffer = nullptr; if (size == 0) { GELOGD("size is 0"); @@ -36,7 +36,7 @@ std::unique_ptr TensorBuffer::Create(NpuMemoryAllocator *allocator return nullptr; } - buffer = allocator->Allocate(size); + buffer = allocator->Allocate(size, attr); if (buffer == nullptr) { GELOGE(MEMALLOC_FAILED, "Failed to allocate memory. size = %zu", size); return nullptr; diff --git a/src/ge/hybrid/common/tensor_value.h b/src/ge/hybrid/common/tensor_value.h index 18e67534..db8df9e5 100644 --- a/src/ge/hybrid/common/tensor_value.h +++ b/src/ge/hybrid/common/tensor_value.h @@ -24,10 +24,12 @@ namespace ge { namespace hybrid { class NpuMemoryAllocator; +class AllocationAttr; class TensorBuffer { public: - static std::unique_ptr Create(NpuMemoryAllocator *allocator, size_t size); + static std::unique_ptr Create(NpuMemoryAllocator *allocator, size_t size, + AllocationAttr *attr = nullptr); static std::unique_ptr Create(void *buffer, size_t size); diff --git a/src/ge/hybrid/executor/hybrid_execution_context.cc b/src/ge/hybrid/executor/hybrid_execution_context.cc index bb8e0195..8144ba52 100644 --- a/src/ge/hybrid/executor/hybrid_execution_context.cc +++ b/src/ge/hybrid/executor/hybrid_execution_context.cc @@ -17,34 +17,5 @@ #include "hybrid_execution_context.h" namespace ge { -namespace hybrid { -NodeStatePtr GraphExecutionContext::GetOrCreateNodeState(const NodePtr &node) { - auto &node_state = node_states[node]; - if (node_state == nullptr) { - const NodeItem *node_item = model->GetNodeItem(node); - if (node_item == nullptr) { - return nullptr; - } - node_state.reset(new (std::nothrow) NodeState(*node_item)); - } - - return node_state; -} - -void GraphExecutionContext::OnError(Status error_code) { - GELOGE(error_code, "Error occurred while executing model"); - { - std::lock_guard lk(mu_); - this->status = error_code; - } - - compile_queue.Stop(); - execution_queue.Stop(); -} - -Status GraphExecutionContext::GetStatus() { - std::lock_guard lk(mu_); - return status; -} -} // namespace hybrid +namespace hybrid {} // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/hybrid_execution_context.h b/src/ge/hybrid/executor/hybrid_execution_context.h index 07a6fabf..96722fa9 100644 --- a/src/ge/hybrid/executor/hybrid_execution_context.h +++ b/src/ge/hybrid/executor/hybrid_execution_context.h @@ -20,6 +20,7 @@ #include #include #include "common/blocking_queue.h" +#include "framework/common/debug/ge_log.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/common/tensor_value.h" #include "hybrid/executor/hybrid_profiler.h" @@ -33,34 +34,26 @@ namespace hybrid { struct GraphExecutionContext { uint64_t session_id = 0; const HybridModel *model = nullptr; - NodeDoneManager cv_manager; - BlockingQueue compile_queue; - BlockingQueue execution_queue; - std::vector all_inputs; - std::vector all_outputs; - std::unordered_map node_states; rtStream_t stream = nullptr; + rtContext_t rt_context = nullptr; + rtContext_t rt_gen_context = nullptr; std::unique_ptr callback_manager; NpuMemoryAllocator *allocator = nullptr; mutable std::unique_ptr profiler; bool trace_enabled = false; - int profiling_level = 0; + long profiling_level = 0; bool dump_enabled = false; - Status status = SUCCESS; - std::mutex mu_; - - NodeStatePtr GetOrCreateNodeState(const NodePtr &node); - void OnError(Status status); - Status GetStatus(); + long iteration = 0; }; -#define RECORD_PROFILING_EVENT(context, event_type, fmt, category, node_name, ...) \ +#define RECORD_PROFILING_EVENT(context, evt_type, fmt, category, node_name, ...) \ do { \ if ((context)->profiler != nullptr) { \ if (node_name != nullptr) { \ - context->profiler->RecordEvent(event_type, "[%s] [%s] " fmt, node_name, category, ##__VA_ARGS__); \ + context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GetTid(), node_name, category, \ + ##__VA_ARGS__); \ } else { \ - context->profiler->RecordEvent(event_type, "[%s] " fmt, category, ##__VA_ARGS__); \ + context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GetTid(), category, ##__VA_ARGS__); \ } \ } \ } while (0) @@ -79,7 +72,6 @@ struct GraphExecutionContext { #define RECORD_CALLBACK_EVENT(context, name, fmt, ...) \ RECORD_PROFILING_EVENT((context), HybridProfiler::CALLBACK, fmt, "Callback", name, ##__VA_ARGS__) - } // namespace hybrid } // namespace ge #endif // GE_HYBRID_EXECUTOR_HYBRID_EXECUTION_CONTEXT_H_ diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.cc b/src/ge/hybrid/executor/hybrid_model_async_executor.cc index bd5d77f7..7f650017 100644 --- a/src/ge/hybrid/executor/hybrid_model_async_executor.cc +++ b/src/ge/hybrid/executor/hybrid_model_async_executor.cc @@ -77,19 +77,18 @@ Status HybridModelAsyncExecutor::Init() { GE_CHECK_NOTNULL(data_inputer_); GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); - engine_ = std::unique_ptr(new (std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); - GE_CHECK_NOTNULL(engine_); - GE_CHK_STATUS_RET(engine_->Init(), "Failed to init hybrid engine"); - + executor_ = std::unique_ptr(new (std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); + GE_CHECK_NOTNULL(executor_); + GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine"); GE_CHK_STATUS_RET(InitInputTensors(), "Failed to init input tensors"); return SUCCESS; } Status HybridModelAsyncExecutor::PreRun(InputData ¤t_data) { GE_CHK_STATUS_RET(SyncVarData(), "Failed to sync var data"); - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[SyncVarData] End"); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[SyncVarData] End"); GE_CHK_STATUS_RET(CopyInputData(current_data), "Failed to copy input data to model"); - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[CopyInputData] End"); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[CopyInputData] End"); return SUCCESS; } @@ -119,21 +118,21 @@ Status HybridModelAsyncExecutor::RunInternal() { args.inputs[it.first] = it.second; } - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[RunInternal] [iteration = %d] Start", iterator_count_); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] Start", iterator_count_); ret = PreRun(current_data); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - ret != SUCCESS, (void)HandleResult(ret, current_data.index, args.outputs, data_wrapper->GetOutput()); + ret != SUCCESS, (void)HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); continue, "PreRun failed."); // [No need to check value] - ret = engine_->Execute(args); - ret = HandleResult(ret, current_data.index, args.outputs, data_wrapper->GetOutput()); + ret = executor_->Execute(args); + ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); if (ret != SUCCESS) { CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); continue; } - RECORD_MODEL_EXECUTION_EVENT(engine_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); + RECORD_MODEL_EXECUTION_EVENT(executor_->GetContext(), "[RunInternal] [iteration = %d] End", iterator_count_); iterator_count_++; GELOGI("run iterator count is %lu", iterator_count_); } @@ -143,8 +142,8 @@ Status HybridModelAsyncExecutor::RunInternal() { return SUCCESS; } -Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, - const std::vector &output_tensors, OutputData *output_data) { +Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, HybridModelExecutor::ExecuteArgs &args, + OutputData *output_data) { GELOGD("Start to handle result. model id = %u, data index = %u, execution ret = %u", model_id_, data_id, exec_ret); std::vector output_tensor_info_list; if (exec_ret == END_OF_SEQUENCE) { @@ -158,7 +157,7 @@ Status HybridModelAsyncExecutor::HandleResult(Status exec_ret, uint32_t data_id, } GE_CHECK_NOTNULL(output_data); - auto ret = CopyOutputs(output_tensors, output_data, output_tensor_info_list); + auto ret = CopyOutputs(args, output_data, output_tensor_info_list); if (ret != SUCCESS) { OnComputeDone(data_id, INTERNAL_ERROR, output_tensor_info_list); return INTERNAL_ERROR; @@ -215,9 +214,8 @@ Status HybridModelAsyncExecutor::CopyInputData(const InputData ¤t_data) { Status HybridModelAsyncExecutor::InitInputTensors() { auto allocator = NpuMemoryAllocator::GetAllocator(device_id_); GE_CHECK_NOTNULL(allocator); - for (const auto &it : model_->input_nodes_) { - auto input_index = it.first; - auto input_node = it.second; + int input_index = 0; + for (const auto &input_node : model_->GetRootGraphItem()->GetInputNodes()) { GELOGD("Init input[%u], node = %s", input_index, input_node->NodeName().c_str()); auto output_desc = input_node->op_desc->GetOutputDescPtr(kDataOutputIndex); GE_CHECK_NOTNULL(output_desc); @@ -235,6 +233,7 @@ Status HybridModelAsyncExecutor::InitInputTensors() { TensorValue tensor(shared_ptr(buffer.release())); tensor.SetName("Input_" + input_node->NodeName()); input_tensors_.emplace(input_index, tensor); + input_index += 1; } return SUCCESS; @@ -250,35 +249,33 @@ Status HybridModelAsyncExecutor::OnComputeDone(uint32_t data_index, uint32_t res return result_code; } -Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &output_tensors, OutputData *output_data, +Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &args, OutputData *output_data, std::vector &outputs) { // copy output data from op to designated position - NodeItem *net_output_node = model_->net_output_node_; - GE_CHECK_NOTNULL(net_output_node); - auto all_input_desc = net_output_node->op_desc->GetAllInputsDescPtr(); - - if (all_input_desc.size() != output_tensors.size()) { + std::vector &output_tensor_desc_list = args.output_desc; + std::vector &output_tensors = args.outputs; + if (output_tensor_desc_list.size() != output_tensors.size()) { GELOGE(INTERNAL_ERROR, "Output sizes mismatch. From op_desc = %zu, and from output tensors = %zu", - all_input_desc.size(), output_tensors.size()); + output_tensor_desc_list.size(), output_tensors.size()); return INTERNAL_ERROR; } - GELOGD("Number of outputs = %zu", all_input_desc.size()); + GELOGD("Number of outputs = %zu", output_tensor_desc_list.size()); for (size_t i = 0; i < output_tensors.size(); ++i) { GELOGD("Start to process output[%zu]", i); auto &output_tensor = output_tensors[i]; - auto &tensor_desc = all_input_desc.at(i); + auto &tensor_desc = output_tensor_desc_list.at(i); GE_CHECK_NOTNULL(tensor_desc); int64_t output_size = -1; - GE_CHK_GRAPH_STATUS_RET(TensorUtils::CalcTensorMemSize(tensor_desc->MutableShape(), tensor_desc->GetFormat(), + GE_CHK_GRAPH_STATUS_RET(TensorUtils::CalcTensorMemSize(tensor_desc->GetShape(), tensor_desc->GetFormat(), tensor_desc->GetDataType(), output_size), "Failed to calc tensor size for output[%zu]. shape = [%s], type = %s, format = %s", i, - tensor_desc->MutableShape().ToString().c_str(), + tensor_desc->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str()); GELOGD("Got tensor size for output[%zu] successfully. shape = [%s], type = %s, format = %s, size = %ld", i, - tensor_desc->MutableShape().ToString().c_str(), + tensor_desc->GetShape().ToString().c_str(), TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), TypeUtils::FormatToSerialString(tensor_desc->GetFormat()).c_str(), output_size); @@ -286,7 +283,7 @@ Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &out GE_CHECK_LE(output_size, UINT32_MAX); if (output_tensor.GetSize() < static_cast(output_size)) { GELOGE(INTERNAL_ERROR, "output[%zu] tensor size(%zu) is not enough for output shape [%s]", i, - output_tensor.GetSize(), tensor_desc->MutableShape().ToString().c_str()); + output_tensor.GetSize(), tensor_desc->GetShape().ToString().c_str()); return INTERNAL_ERROR; } @@ -302,7 +299,7 @@ Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &out output.data = std::move(data_buf); output_data->blobs.emplace_back(data_buf.get(), static_cast(output_size), false); } else { - GELOGW("Output[%zu] is empty. shape = [%s]", i, tensor_desc->MutableShape().ToString().c_str()); + GELOGW("Output[%zu] is empty. shape = [%s]", i, tensor_desc->GetShape().ToString().c_str()); output.data = nullptr; output_data->blobs.emplace_back(nullptr, 0U, false); } @@ -310,7 +307,53 @@ Status HybridModelAsyncExecutor::CopyOutputs(const std::vector &out outputs.emplace_back(std::move(output)); GELOGD("Output[%zu] added, type = %s, shape = [%s], size = %ld", i, TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), - tensor_desc->MutableShape().ToString().c_str(), output_size); + tensor_desc->GetShape().ToString().c_str(), output_size); + } + + return SUCCESS; +} + +Status HybridModelAsyncExecutor::Execute(const vector &inputs, vector &outputs) { + GELOGD("Start to execute model."); + // prepare inputs + InputData input_data; + for (auto &tensor : inputs) { + DataBuffer buffer; + buffer.data = const_cast(tensor.GetData().GetData()); + buffer.length = tensor.GetData().size(); + input_data.blobs.emplace_back(buffer); + } + GE_CHK_STATUS_RET(CopyInputData(input_data), "Failed to copy input data to model"); + GELOGD("Done copying input data successfully."); + + HybridModelExecutor::ExecuteArgs args; + args.inputs.resize(input_tensors_.size()); + args.input_desc.resize(input_tensors_.size()); + for (auto &it : input_tensors_) { + args.inputs[it.first] = it.second; + args.input_desc[it.first] = MakeShared(inputs[it.first].GetTensorDesc()); + } + + GE_CHK_STATUS_RET(executor_->Execute(args), "Failed to execute model."); + + std::vector output_tensor_info_list; + OutputData output_data; + GE_CHK_STATUS_RET(CopyOutputs(args, &output_data, output_tensor_info_list), "Failed to copy outputs."); + GELOGD("Done copying output data successfully. output count = %zu", output_tensor_info_list.size()); + + int out_index = 0; + outputs.resize(output_tensor_info_list.size()); + for (auto &out_tensor_info : output_tensor_info_list) { + auto &ge_tensor = outputs[out_index]; + if (out_tensor_info.length > 0) { + GE_CHK_GRAPH_STATUS_RET(ge_tensor.SetData(out_tensor_info.data.get(), out_tensor_info.length), + "Failed to set output[%d].", out_index); + } + + ge_tensor.MutableTensorDesc() = *args.output_desc[out_index]; + GELOGD("Set output[%d], tensor size = %ld, shape = [%s]", out_index, out_tensor_info.length, + ge_tensor.MutableTensorDesc().MutableShape().ToString().c_str()); + ++out_index; } return SUCCESS; diff --git a/src/ge/hybrid/executor/hybrid_model_async_executor.h b/src/ge/hybrid/executor/hybrid_model_async_executor.h index cb440ba7..195f79a9 100644 --- a/src/ge/hybrid/executor/hybrid_model_async_executor.h +++ b/src/ge/hybrid/executor/hybrid_model_async_executor.h @@ -35,6 +35,8 @@ class HybridModelAsyncExecutor { Status Init(); + Status Execute(const vector &inputs, vector &outputs); + Status Start(const std::shared_ptr &listener); void SetDeviceId(uint32_t device_id); @@ -52,10 +54,10 @@ class HybridModelAsyncExecutor { Status SyncVarData(); - Status HandleResult(Status exec_ret, uint32_t data_id, const std::vector &output_tensors, + Status HandleResult(Status exec_ret, uint32_t data_id, HybridModelExecutor::ExecuteArgs &args, OutputData *output_data); - Status CopyOutputs(const std::vector &output_tensors, OutputData *output_data, + Status CopyOutputs(HybridModelExecutor::ExecuteArgs &args, OutputData *output_data, std::vector &outputs); Status OnComputeDone(uint32_t data_index, uint32_t result_code, std::vector &outputs); @@ -70,7 +72,7 @@ class HybridModelAsyncExecutor { uint32_t model_id_ = 0U; std::atomic_bool run_flag_; std::unique_ptr data_inputer_; - std::unique_ptr engine_; + std::unique_ptr executor_; std::future future_; uint64_t iterator_count_ = 0; diff --git a/src/ge/hybrid/executor/hybrid_model_executor.cc b/src/ge/hybrid/executor/hybrid_model_executor.cc index 856b4483..d62d7be3 100644 --- a/src/ge/hybrid/executor/hybrid_model_executor.cc +++ b/src/ge/hybrid/executor/hybrid_model_executor.cc @@ -26,17 +26,17 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, Status HybridModelExecutor::Init() { GELOGD("Start to init HybridGraphEngine."); GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); - infer_shape_engine_.reset(new (std::nothrow) ShapeInferenceEngine(&context_)); - compile_engine_.reset(new (std::nothrow) TaskCompileEngine(&context_)); - execute_engine_.reset(new (std::nothrow) ExecutionEngine(&context_, context_.callback_manager.get())); - GE_CHK_STATUS_RET_NOLOG(compile_engine_->Init()); GELOGD("HybridGraphEngine initialized successfully."); return SUCCESS; } Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { GELOGD("Start to execute model."); - auto ret = ExecuteGraphInternal(args); + auto root_graph_item = model_->GetRootGraphItem(); + GE_CHECK_NOTNULL(root_graph_item); + + SubgraphExecutor executor(model_->GetRootGraphItem(), &context_); + auto ret = ExecuteGraphInternal(executor, args); Cleanup(); RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); GE_CHK_STATUS_RET(ret, "Failed to execute model"); @@ -46,24 +46,22 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { context_.profiler->Reset(); } + context_.iteration += 1; return SUCCESS; } -Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArgs &args) { +Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, HybridModelExecutor::ExecuteArgs &args) { RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] Start"); GE_CHK_STATUS_RET_NOLOG(ResetExecutionContext(context_)); RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitContext] End"); - GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs(args, context_)); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[InitInputsAndOutputs] End"); - GE_CHK_STATUS_RET_NOLOG(compile_engine_->Start(pool_)); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[CompileProcess] Started"); - GE_CHK_STATUS_RET_NOLOG(infer_shape_engine_->Start(pool_)); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[InferShapeProcess] Started"); - GE_CHK_STATUS_RET(execute_engine_->Start(), "Run execution engine failed."); - RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecutionProcess] End"); - GE_CHK_STATUS_RET_NOLOG(Synchronize()); + + GE_CHK_STATUS_RET(executor.ExecuteAsync(args.inputs, args.input_desc), "Failed to execute partitioned call."); + RECORD_MODEL_EXECUTION_EVENT(&context_, "[ExecuteAsync] End"); + + GE_CHK_STATUS_RET(executor.Synchronize(), "Failed to sync root graph."); RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End"); - GE_CHK_STATUS_RET_NOLOG(GetOutput(args)); + + GE_CHK_STATUS_RET(executor.GetOutputs(args.outputs, args.output_desc), "Failed to get outputs"); RECORD_MODEL_EXECUTION_EVENT(&context_, "[GetOutput] End"); return SUCCESS; } @@ -71,18 +69,16 @@ Status HybridModelExecutor::ExecuteGraphInternal(HybridModelExecutor::ExecuteArg Status HybridModelExecutor::Cleanup() { GELOGD("Start to cleanup."); context_.callback_manager->Destroy(); - context_.cv_manager.Reset(); - context_.node_states.clear(); - context_.all_inputs.clear(); - context_.all_outputs.clear(); - context_.compile_queue.Clear(); - context_.execution_queue.Clear(); RuntimeInferenceContext::DestroyContext(to_string(context_.session_id)); GELOGD("Cleanup successfully."); return SUCCESS; } Status HybridModelExecutor::InitExecutionContext() { + GE_CHK_RT_RET(rtCtxGetCurrent(&context_.rt_context)); + GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); + GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); + context_.stream = stream_; context_.model = model_; context_.session_id = ::ge::GetContext().SessionId(); @@ -94,78 +90,15 @@ Status HybridModelExecutor::InitExecutionContext() { if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { context_.trace_enabled = true; } - return SUCCESS; } Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { - auto &model = *context.model; - context.all_inputs.resize(model.TotalInputs()); - context.all_outputs.resize(model.TotalOutputs()); - context.compile_queue.Restart(); - context.execution_queue.Restart(); GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); - - for (auto const_node : model.GetConstNodes()) { - auto weight_tensor = model.GetWeight(const_node); - GE_CHECK_NOTNULL(weight_tensor); - for (auto &dst_aid_and_nid : const_node->outputs[0]) { - auto *dst_node_item = dst_aid_and_nid.second; - auto input_offset = dst_node_item->input_start + dst_aid_and_nid.first; - context.all_inputs[input_offset] = *weight_tensor; - } - } - string ctx_id = std::to_string(context.session_id); RuntimeInferenceContext::DestroyContext(ctx_id); GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); return SUCCESS; } - -Status HybridModelExecutor::InitInputsAndOutputs(HybridModelExecutor::ExecuteArgs &args, - GraphExecutionContext &context) { - for (const auto &it : model_->GetInputNodes()) { - uint32_t input_index = it.first; - if (input_index >= args.inputs.size()) { - GELOGE(PARAM_INVALID, "Not enough inputs. NumInputs = %zu, but input index = %u", args.inputs.size(), - input_index); - return PARAM_INVALID; - } - - auto node_item = it.second; - auto &input_tensor = args.inputs[input_index]; - GELOGD("Set input tensor[%u] to inputs with index = %d, addr = %p, size = %zu", input_index, node_item->input_start, - input_tensor.GetData(), input_tensor.GetSize()); - context.all_inputs[node_item->input_start] = input_tensor; - } - - for (size_t i = 0; i < model_->GetOutputOffsets().size(); ++i) { - auto offset = model_->GetOutputOffsets()[i]; - if (i < args.outputs.size() && args.outputs[i].GetData() != nullptr) { - GELOGD("Use user allocated output memory. output index = %zu, output offset = %d", i, offset); - context.all_outputs[offset] = args.outputs[i]; - } - } - - return SUCCESS; -} - -Status HybridModelExecutor::Synchronize() { - GE_CHK_RT_RET(rtStreamSynchronize(stream_)); - return SUCCESS; -} - -Status HybridModelExecutor::GetOutput(HybridModelExecutor::ExecuteArgs &args) { - auto &net_output_input_offsets = model_->GetNetOutputInputOffsets(); - auto num_outputs = net_output_input_offsets.size(); - args.outputs.resize(num_outputs); - for (size_t i = 0; i < num_outputs; ++i) { - auto offset = net_output_input_offsets[i]; - GELOGI("Get output[%zu] from offset %d", i, offset); - args.outputs[i] = context_.all_inputs[offset]; - } - - return SUCCESS; -} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/hybrid_model_executor.h b/src/ge/hybrid/executor/hybrid_model_executor.h index 2bda6331..9996dbe0 100644 --- a/src/ge/hybrid/executor/hybrid_model_executor.h +++ b/src/ge/hybrid/executor/hybrid_model_executor.h @@ -20,9 +20,7 @@ #include "graph/load/new_model_manager/data_inputer.h" #include "hybrid/executor/hybrid_execution_context.h" #include "hybrid/executor/rt_callback_manager.h" -#include "hybrid/executor/worker/execution_engine.h" -#include "hybrid/executor/worker/shape_inference_engine.h" -#include "hybrid/executor/worker/task_compile_engine.h" +#include "hybrid/executor/subgraph_executor.h" namespace ge { namespace hybrid { @@ -30,7 +28,9 @@ class HybridModelExecutor { public: struct ExecuteArgs { std::vector inputs; + std::vector input_desc; std::vector outputs; + std::vector output_desc; }; HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); @@ -44,24 +44,15 @@ class HybridModelExecutor { Status Execute(ExecuteArgs &args); private: - Status ExecuteGraphInternal(ExecuteArgs &args); + Status ExecuteGraphInternal(SubgraphExecutor &executor, ExecuteArgs &args); Status Cleanup(); Status InitExecutionContext(); static Status ResetExecutionContext(GraphExecutionContext &context); - Status InitInputsAndOutputs(ExecuteArgs &args, GraphExecutionContext &context); - Status GetOutput(ExecuteArgs &args); - - Status Synchronize(); - - ThreadPool pool_; HybridModel *model_; uint32_t device_id_; rtStream_t stream_; GraphExecutionContext context_; - std::unique_ptr infer_shape_engine_; - std::unique_ptr compile_engine_; - std::unique_ptr execute_engine_; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/hybrid_profiler.cc b/src/ge/hybrid/executor/hybrid_profiler.cc index 1081a144..4c70e043 100644 --- a/src/ge/hybrid/executor/hybrid_profiler.cc +++ b/src/ge/hybrid/executor/hybrid_profiler.cc @@ -59,11 +59,10 @@ void HybridProfiler::Dump(std::ostream &output_stream) { auto first_evt = events_[0]; auto start = first_evt.timestamp; - output_stream << "Start " << first_evt.desc << std::endl; std::vector prev_timestamps; prev_timestamps.resize(kMaxEventTypes, start); - for (int i = 1; i < counter_; ++i) { + for (int i = 0; i < counter_; ++i) { auto &evt = events_[i]; auto elapsed = std::chrono::duration_cast(evt.timestamp - start).count(); auto &prev_ts = prev_timestamps[evt.event_type]; diff --git a/src/ge/hybrid/executor/node_done_manager.cc b/src/ge/hybrid/executor/node_done_manager.cc index dfeddb5b..3ec45339 100644 --- a/src/ge/hybrid/executor/node_done_manager.cc +++ b/src/ge/hybrid/executor/node_done_manager.cc @@ -15,35 +15,49 @@ */ #include "hybrid/executor/node_done_manager.h" +#include #include "framework/common/debug/ge_log.h" namespace ge { namespace hybrid { +namespace { +constexpr int kDefaultWaitTimeoutInSec = 10; +} bool NodeDoneManager::Cond::Await() { - std::unique_lock lk(mu_); - cv_.wait(lk, [&]() { return is_released_ || is_cancelled_; }); + std::unique_lock lk(cond_mu_); + if (!cv_.wait_for(lk, std::chrono::seconds(kDefaultWaitTimeoutInSec), + [&]() { return is_released_ || is_cancelled_; })) { + GELOGE(INTERNAL_ERROR, "Wait timed out."); + return false; + } + return is_released_; } void NodeDoneManager::Cond::Release() { - std::unique_lock lk(mu_); + std::unique_lock lk(cond_mu_); is_released_ = true; cv_.notify_all(); } void NodeDoneManager::Cond::Cancel() { - std::unique_lock lk(mu_); + std::unique_lock lk(cond_mu_); is_cancelled_ = true; cv_.notify_all(); } bool NodeDoneManager::Cond::IsRelease() { - std::unique_lock lk(mu_); + std::unique_lock lk(cond_mu_); return is_released_; } NodeDoneManager::Cond *NodeDoneManager::GetSubject(const NodePtr &node) { std::lock_guard lk(mu_); + if (destroyed_) { + GELOGD("Already destroyed."); + return nullptr; + } + auto it = subjects_.find(node); if (it == subjects_.end()) { return &subjects_[node]; @@ -52,8 +66,10 @@ NodeDoneManager::Cond *NodeDoneManager::GetSubject(const NodePtr &node) { return &it->second; } -void NodeDoneManager::Reset() { +void NodeDoneManager::Destroy() { + GELOGD("Start to reset NodeDoneManager."); std::lock_guard lk(mu_); + GELOGD("Cond size = %zu.", subjects_.size()); for (auto &sub : subjects_) { if (!sub.second.IsRelease()) { sub.second.Cancel(); @@ -62,15 +78,24 @@ void NodeDoneManager::Reset() { } subjects_.clear(); + destroyed_ = true; + GELOGD("Done resetting NodeDoneManager successfully."); } void NodeDoneManager::NodeDone(const NodePtr &node) { - GetSubject(node)->Release(); - GELOGD("[%s] Node released.", node->GetName().c_str()); + auto sub = GetSubject(node); + if (sub != nullptr) { + sub->Release(); + GELOGD("[%s] Node released.", node->GetName().c_str()); + } } bool NodeDoneManager::Await(const NodePtr &node) { auto sub = GetSubject(node); + if (sub == nullptr) { + return false; + } + GELOGD("[%s] Await start. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false"); bool ret = sub->Await(); GELOGD("[%s] Await ended. is_released = %s", node->GetName().c_str(), sub->IsRelease() ? "true" : "false"); diff --git a/src/ge/hybrid/executor/node_done_manager.h b/src/ge/hybrid/executor/node_done_manager.h index ccf263d1..f1fdfbec 100644 --- a/src/ge/hybrid/executor/node_done_manager.h +++ b/src/ge/hybrid/executor/node_done_manager.h @@ -31,7 +31,7 @@ class NodeDoneManager { bool Await(const NodePtr &node); - void Reset(); + void Destroy(); private: class Cond { @@ -42,7 +42,7 @@ class NodeDoneManager { bool Await(); private: - std::mutex mu_; + std::mutex cond_mu_; std::condition_variable cv_; bool is_released_ = false; bool is_cancelled_ = false; @@ -51,6 +51,7 @@ class NodeDoneManager { Cond *GetSubject(const NodePtr &node); std::mutex mu_; std::unordered_map subjects_; + bool destroyed_ = false; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/node_state.cc b/src/ge/hybrid/executor/node_state.cc index 6895f158..c78dd725 100644 --- a/src/ge/hybrid/executor/node_state.cc +++ b/src/ge/hybrid/executor/node_state.cc @@ -15,13 +15,136 @@ */ #include "hybrid/executor/node_state.h" +#include +#include "framework/common/debug/log.h" #include "graph/compute_graph.h" +#include "hybrid_execution_context.h" +#include "subgraph_context.h" namespace ge { namespace hybrid { -NodeState::NodeState(const NodeItem &node_item) { - this->node_item = &node_item; - this->op_desc = node_item.node->GetOpDesc(); +namespace { +constexpr auto kMaxWaitTimeInSec = 10; +} +ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item(node_item) { + this->num_pending_shapes_ = node_item.num_inputs - node_item.num_static_input_shapes; + GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", node_item.NodeName().c_str(), + this->num_pending_shapes_); +} + +void ShapeInferenceState::UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape) { + if (!node_item.is_dynamic || node_item.is_input_shape_static[idx]) { + GELOGD("[%s] Trying to update static shape, idx = %u. old shape = [%s], new shape = [%s]", + node_item.NodeName().c_str(), idx, node_item.op_desc->MutableInputDesc(idx)->GetShape().ToString().c_str(), + shape.ToString().c_str()); + return; + } + + GELOGD("[%s] Update input shape [%u] with Shape: [%s] and OriginalShape: [%s]", node_item.NodeName().c_str(), idx, + shape.ToString().c_str(), ori_shape.ToString().c_str()); + + std::lock_guard lk(mu_); + node_item.op_desc->MutableInputDesc(idx)->SetShape(shape); + node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); + if (--num_pending_shapes_ == 0) { + ready_cv_.notify_all(); + } +} + +void ShapeInferenceState::UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future) { + if (!node_item.is_dynamic || node_item.is_input_shape_static[idx]) { + GELOGD("[%s] Trying to update constant shape, idx = %u", node_item.NodeName().c_str(), idx); + return; + } + + GELOGD("[%s] Update input shape [%u] with ShapeFuture.", node_item.NodeName().c_str(), idx); + std::lock_guard lk(mu_); + shape_futures.emplace_back(idx, std::move(future)); + if (--num_pending_shapes_ == 0) { + ready_cv_.notify_all(); + } +} + +Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { + if (!node_item.is_dynamic) { + return SUCCESS; + } + std::unique_lock lk(mu_); + if (num_pending_shapes_ > 0) { + GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); + if (!ready_cv_.wait_for(lk, std::chrono::seconds(kMaxWaitTimeInSec), [&]() { return num_pending_shapes_ == 0; })) { + GELOGE(INTERNAL_ERROR, "[%s] Wait for shape timeout.", node_item.NodeName().c_str()); + return INTERNAL_ERROR; + } + GELOGD("[%s] Await pending shape or shape future end.", node_item.NodeName().c_str()); + } + + for (auto &p : shape_futures) { + auto idx = p.first; + auto &future = p.second; + GeShape shape; + GeShape ori_shape; + RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); + GE_CHK_STATUS_RET(future.Get(ori_shape, shape), "[%s] Get shape failed. index = %u", node_item.NodeName().c_str(), + idx); + RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); + + GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", node_item.NodeName().c_str(), idx, + shape.ToString().c_str(), ori_shape.ToString().c_str()); + node_item.op_desc->MutableInputDesc(idx)->SetShape(std::move(shape)); + node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); + } + + return SUCCESS; +} + +ShapeFuture::ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context) + : src_node_(std::move(src_node)), src_index_(src_index), subgraph_context_(subgraph_context) {} + +NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) + : node_item_(&node_item), shape_inference_state_(node_item), subgraph_context_(subgraph_context) { + this->op_desc_ = node_item.node->GetOpDesc(); +} + +Status NodeState::AwaitInputTensors(GraphExecutionContext &context) const { + for (auto &src_node : node_item_->dependents_for_execution) { + GELOGI("[%s] Start to wait for data dependent node: [%s]", node_item_->NodeName().c_str(), + src_node->GetName().c_str()); + RECORD_EXECUTION_EVENT(&context, node_item_->NodeName().c_str(), "[AwaitNodeDone] [%s] Start", + src_node->GetName().c_str()); + if (!subgraph_context_->Await(src_node)) { + GELOGE(INTERNAL_ERROR, "[%s] Await node [%s] failed.", GetName().c_str(), src_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + RECORD_EXECUTION_EVENT(&context, node_item_->NodeName().c_str(), "[AwaitNodeDone] [%s] End", + src_node->GetName().c_str()); + GELOGI("[%s] Done waiting node.", src_node->GetName().c_str()); + } + + return SUCCESS; +} + +Status NodeState::WaitForPrepareDone() { + if (prepare_future_.valid()) { + GELOGD("[%s] Start to wait for prepare future.", GetName().c_str()); + GE_CHK_STATUS_RET(prepare_future_.get(), "[%s] PreRun failed.", GetName().c_str()); + } + + return SUCCESS; +} + +Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { + GELOGI("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); + if (!subgraph_context_->Await(src_node_)) { + GELOGE(INTERNAL_ERROR, "cancelled"); + return INTERNAL_ERROR; + } + + shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); + ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); + GELOGI("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); + return SUCCESS; } } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/executor/node_state.h b/src/ge/hybrid/executor/node_state.h index b2811bcb..73e0f75c 100644 --- a/src/ge/hybrid/executor/node_state.h +++ b/src/ge/hybrid/executor/node_state.h @@ -17,38 +17,83 @@ #ifndef GE_HYBRID_EXECUTOR_NODE_STATE_H_ #define GE_HYBRID_EXECUTOR_NODE_STATE_H_ +#include +#include +#include +#include "external/ge/ge_api_error_codes.h" #include "hybrid/model/node_item.h" +#include "node_done_manager.h" namespace ge { namespace hybrid { - class NodeTask; +class GraphExecutionContext; +class SubgraphContext; -// 存放一些会å˜åŒ–çš„ä¿¡æ¯... -class NodeState { +class ShapeFuture { public: - NodeState() = default; - explicit NodeState(const NodeItem &node_item); + ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); + ~ShapeFuture() = default; + Status Get(GeShape &ori_shape, GeShape &shape); + + private: + NodePtr src_node_; + uint32_t src_index_; + SubgraphContext *subgraph_context_; +}; + +struct ShapeInferenceState { + explicit ShapeInferenceState(const NodeItem &node_item); + + void UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape); + + void UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future); + + Status AwaitShapesReady(const GraphExecutionContext &context); + + const NodeItem &node_item; + + private: + std::vector> shape_futures; + int num_pending_shapes_ = 0; + std::condition_variable ready_cv_; + std::mutex mu_; +}; + +// saving sth. dynamic during execution +struct NodeState { + public: + NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context); ~NodeState() = default; - inline int NodeId() const { return node_item->node_id; } + OpDesc *GetOpDesc() const { return op_desc_.get(); } + + inline const NodeItem *GetNodeItem() const { return node_item_; } + + inline const string &GetName() const { return node_item_->NodeName(); } + + inline const string &GetType() const { return node_item_->NodeType(); } - inline Node *GetNode() const { return node_item->node.get(); } + ShapeInferenceState &GetShapeInferenceState() { return shape_inference_state_; } - OpDesc *GetOpDesc() const { return op_desc.get(); } + const shared_ptr &GetKernelTask() const { return kernel_task_; } - inline const NodeItem *GetNodeItem() const { return node_item; } + void SetKernelTask(const shared_ptr &kernel_task) { kernel_task_ = kernel_task; } - inline const string &GetName() const { return node_item->NodeName(); } + Status WaitForPrepareDone(); - inline const string &GetType() const { return node_item->NodeType(); } + void SetPrepareFuture(std::future &&prepare_future) { this->prepare_future_ = std::move(prepare_future); } - // private: - const NodeItem *node_item = nullptr; - std::shared_ptr kernel_task = nullptr; + Status AwaitInputTensors(GraphExecutionContext &context) const; - bool is_compiled = false; - OpDescPtr op_desc; + private: + const NodeItem *node_item_ = nullptr; + std::shared_ptr kernel_task_ = nullptr; + std::future prepare_future_; + OpDescPtr op_desc_; + ShapeInferenceState shape_inference_state_; + SubgraphContext *subgraph_context_; + std::mutex mu_; }; using NodeStatePtr = std::shared_ptr; diff --git a/src/ge/hybrid/executor/rt_callback_manager.cc b/src/ge/hybrid/executor/rt_callback_manager.cc index 6be8da31..c1c98f73 100644 --- a/src/ge/hybrid/executor/rt_callback_manager.cc +++ b/src/ge/hybrid/executor/rt_callback_manager.cc @@ -42,7 +42,6 @@ Status CallbackManager::Init() { rtContext_t ctx = nullptr; GE_CHK_RT_RET(rtCtxGetCurrent(&ctx)); ret_future_ = std::async([&](rtContext_t context) -> Status { return CallbackProcess(context); }, ctx); - if (!ret_future_.valid()) { GELOGE(INTERNAL_ERROR, "Failed to init callback manager."); return INTERNAL_ERROR; diff --git a/src/ge/hybrid/executor/subgraph_context.cc b/src/ge/hybrid/executor/subgraph_context.cc new file mode 100644 index 00000000..395e75de --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_context.cc @@ -0,0 +1,113 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "subgraph_context.h" + +#include "common/debug/log.h" + +namespace ge { +namespace hybrid { +SubgraphContext::SubgraphContext(const GraphItem *graph_item) : graph_item_(graph_item) {} + +Status SubgraphContext::Init() { + GE_CHECK_NOTNULL(graph_item_); + GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d", graph_item_->GetName().c_str(), + graph_item_->TotalInputs(), graph_item_->TotalOutputs()); + all_inputs_.resize(graph_item_->TotalInputs()); + all_outputs_.resize(graph_item_->TotalOutputs()); + + return SUCCESS; +} + +NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { + std::lock_guard lk(mu_); + auto &node_state = node_states_[node_item]; + if (node_state == nullptr) { + node_state.reset(new (std::nothrow) NodeState(*node_item, this)); + } + + return node_state; +} + +Status SubgraphContext::SetInput(int index, const TensorValue &tensor) { + if (static_cast(index) >= all_inputs_.size()) { + GELOGE(INTERNAL_ERROR, "output index output range. all input num = %zu, input index = %d", all_inputs_.size(), + index); + return INTERNAL_ERROR; + } + + all_inputs_[index] = tensor; + return SUCCESS; +} + +Status SubgraphContext::SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor) { + auto index = node_item.input_start + input_index; + return SetInput(index, tensor); +} + +Status SubgraphContext::SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor) { + auto index = node_item.output_start + output_index; + if (output_index >= node_item.num_outputs || static_cast(index) >= all_outputs_.size()) { + GELOGE(INTERNAL_ERROR, "output index output range. all output num = %zu, node_item = %s, output index = %d", + all_outputs_.size(), node_item.DebugString().c_str(), output_index); + return INTERNAL_ERROR; + } + + all_outputs_[index] = tensor; + return SUCCESS; +} + +Status SubgraphContext::GetInput(int index, TensorValue &tensor) { + GE_CHECK_GE(all_inputs_.size(), index + 1U); + tensor = all_inputs_[index]; + return SUCCESS; +} + +Status SubgraphContext::GetOutputs(std::vector &outputs) { + if (graph_item_->IsDynamic()) { + GELOGD("[%s] graph is dynamic, get outputs from net output input tensors", graph_item_->GetName().c_str()); + // get from net output inputs + auto output_node = graph_item_->GetOutputNode(); + if (output_node != nullptr) { + for (int i = 0; i < output_node->num_inputs; ++i) { + TensorValue tensor; + GE_CHK_STATUS_RET_NOLOG(GetInput(output_node->input_start + i, tensor)); + GELOGD("[%s] Adding output tensor by input index [%d], tensor = %s", graph_item_->GetName().c_str(), + output_node->input_start + i, tensor.DebugString().c_str()); + outputs.emplace_back(std::move(tensor)); + } + } + } else { + GELOGD("[%s] graph is non-dynamic, get outputs from subgraph outputs", graph_item_->GetName().c_str()); + for (auto &tensor : all_outputs_) { + GELOGD("[%s] Adding output tensor: %s", graph_item_->GetName().c_str(), tensor.DebugString().c_str()); + outputs.emplace_back(tensor); + } + } + + return SUCCESS; +} + +bool SubgraphContext::Await(const NodePtr &node) { return node_done_manager_.Await(node); } + +void SubgraphContext::OnError(Status error) { + GELOGE(error, "[%s] Error occurred while executing graph.", graph_item_->GetName().c_str()); + node_done_manager_.Destroy(); +} + +void SubgraphContext::NodeDone(const NodePtr &node) { node_done_manager_.NodeDone(node); } +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/executor/subgraph_context.h b/src/ge/hybrid/executor/subgraph_context.h new file mode 100644 index 00000000..fd934d80 --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_context.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ +#define GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ + +#include + +#include "hybrid/common/tensor_value.h" +#include "hybrid/executor/node_state.h" +#include "hybrid/executor/node_done_manager.h" +#include "hybrid/model/graph_item.h" +#include "hybrid/model/node_item.h" + +namespace ge { +namespace hybrid { +class SubgraphContext { + public: + explicit SubgraphContext(const GraphItem *graph_item); + ~SubgraphContext() = default; + + Status Init(); + NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); + + void OnError(Status error); + + Status SetInput(const NodeItem &node_item, int input_index, const TensorValue &tensor); + Status SetOutput(const NodeItem &node_item, int output_index, const TensorValue &tensor); + Status SetInput(int index, const TensorValue &tensor); + Status GetInput(int index, TensorValue &tensor); + Status GetOutputs(std::vector &outputs); + + bool Await(const NodePtr &node); + void NodeDone(const NodePtr &node); + + private: + friend class TaskContext; + const GraphItem *graph_item_; + std::mutex mu_; + std::vector all_inputs_; + std::vector all_outputs_; + NodeDoneManager node_done_manager_; + std::unordered_map node_states_; +}; +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ diff --git a/src/ge/hybrid/executor/subgraph_executor.cc b/src/ge/hybrid/executor/subgraph_executor.cc new file mode 100644 index 00000000..7664e90d --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_executor.cc @@ -0,0 +1,373 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/executor/subgraph_executor.h" +#include "hybrid/executor/worker/task_compile_engine.h" +#include "hybrid/executor/worker/execution_engine.h" +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { +namespace { +constexpr int kDefaultThreadNum = 4; +constexpr int kDataInputIndex = 0; +} // namespace + +SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape) + : graph_item_(graph_item), + context_(context), + force_infer_shape_(force_infer_shape), + pre_run_pool_(kDefaultThreadNum) {} + +SubgraphExecutor::~SubgraphExecutor() { GELOGD("[%s] SubgraphExecutor destroyed.", graph_item_->GetName().c_str()); } + +Status SubgraphExecutor::Init(const std::vector &inputs, + const std::vector &input_desc) { + subgraph_context_.reset(new (std::nothrow) SubgraphContext(graph_item_)); + GE_CHECK_NOTNULL(subgraph_context_); + GE_CHK_STATUS_RET(subgraph_context_->Init(), "[%s] Failed to init subgraph context.", graph_item_->GetName().c_str()); + + shape_inference_engine_.reset(new (std::nothrow) ShapeInferenceEngine(context_, subgraph_context_.get())); + GE_CHECK_NOTNULL(shape_inference_engine_); + + if (graph_item_->IsDynamic()) { + GE_CHK_STATUS_RET(InitInputsForUnknownShape(inputs, input_desc), "[%s] Failed to set inputs.", + graph_item_->GetName().c_str()); + } else { + GE_CHK_STATUS_RET(InitInputsForKnownShape(inputs), + "[%s] Failed to init subgraph executor for known shape subgraph.", + graph_item_->GetName().c_str()); + } + + return SUCCESS; +} + +Status SubgraphExecutor::InitInputsForUnknownShape(const std::vector &inputs, + const std::vector &input_desc) { + // Number of inputs of parent node should be greater or equal than that of subgraph + auto input_nodes = graph_item_->GetInputNodes(); + if (inputs.size() < input_nodes.size()) { + GELOGE(INTERNAL_ERROR, "[%s] Number of inputs [%zu] is not sufficient for subgraph which needs [%zu] inputs.", + graph_item_->GetName().c_str(), inputs.size(), input_nodes.size()); + return INTERNAL_ERROR; + } + + for (size_t i = 0; i < input_nodes.size(); ++i) { + auto &input_node = input_nodes[i]; + if (input_node == nullptr) { + GELOGD("[%s] Input[%zu] is not needed by subgraph, skip it.", graph_item_->GetName().c_str(), i); + continue; + } + + auto &input_tensor = inputs[i]; + GELOGD("[%s] Set input tensor[%zu] to inputs with index = %d, tensor = %s", graph_item_->GetName().c_str(), i, + input_node->input_start, input_tensor.DebugString().c_str()); + + GE_CHK_STATUS_RET(subgraph_context_->SetInput(*input_node, kDataInputIndex, input_tensor), + "[%s] Failed to set input tensor[%zu]", graph_item_->GetName().c_str(), i); + + if (force_infer_shape_ || input_node->is_dynamic) { + GELOGD("[%s] Start to update input[%zu] for subgraph data node.", graph_item_->GetName().c_str(), i); + GE_CHECK_LE(i + 1, input_desc.size()); + const auto &tensor_desc = input_desc[i]; + auto node_state = subgraph_context_->GetOrCreateNodeState(input_node); + GE_CHECK_NOTNULL(node_state); + node_state->GetShapeInferenceState().UpdateInputShape(0, tensor_desc->GetOriginShape(), tensor_desc->GetShape()); + } + } + + GELOGD("[%s] Done setting inputs.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::InitInputsForKnownShape(const std::vector &inputs) { + auto &input_index_mapping = graph_item_->GetInputIndexMapping(); + for (size_t i = 0; i < input_index_mapping.size(); ++i) { + auto &parent_input_index = input_index_mapping[i]; + if (static_cast(parent_input_index) >= inputs.size()) { + GELOGE(INTERNAL_ERROR, + "[%s] Number of inputs [%zu] is not sufficient for subgraph which needs at lease [%d] inputs", + graph_item_->GetName().c_str(), inputs.size(), parent_input_index + 1); + + return INTERNAL_ERROR; + } + + auto &input_tensor = inputs[parent_input_index]; + subgraph_context_->SetInput(i, input_tensor); + GELOGD("[%s] Set input tensor[%zu] with inputs with index = %d, tensor = %s", graph_item_->GetName().c_str(), i, + parent_input_index, input_tensor.DebugString().c_str()); + } + + return SUCCESS; +} + +Status SubgraphExecutor::ExecuteAsync(const std::vector &inputs, + const std::vector &input_desc) { + GELOGD("[%s] is dynamic = %s", graph_item_->GetName().c_str(), graph_item_->IsDynamic() ? "true" : "false"); + GE_CHK_STATUS_RET(Init(inputs, input_desc), "[%s] Failed to init executor.", graph_item_->GetName().c_str()); + + if (!graph_item_->IsDynamic()) { + return ExecuteAsyncForKnownShape(inputs); + } + + GE_CHK_STATUS_RET(ScheduleTasks(), "[%s] Failed to execute tasks.", graph_item_->GetName().c_str()); + GELOGD("[%s] Done executing subgraph successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector &inputs) { + GELOGD("[%s] subgraph is not dynamic.", graph_item_->GetName().c_str()); + if (graph_item_->GetAllNodes().size() != 1) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid known shape subgraph. node size = %zu", graph_item_->GetName().c_str(), + graph_item_->GetAllNodes().size()); + return INTERNAL_ERROR; + } + + auto node_item = graph_item_->GetAllNodes()[0]; + GE_CHECK_NOTNULL(node_item); + auto node_state = subgraph_context_->GetOrCreateNodeState(node_item); + GE_CHECK_NOTNULL(node_state); + node_state->SetKernelTask(node_item->kernel_task); + + known_shape_task_context_ = TaskContext::Create(*node_item, context_, subgraph_context_.get()); + GE_CHECK_NOTNULL(known_shape_task_context_); + + GE_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_), + "[%s] Failed to execute node [%s] for known subgraph.", graph_item_->GetName().c_str(), + known_shape_task_context_->GetNodeName()); + + GELOGD("[%s] Done execute non-dynamic subgraph successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::ExecuteAsync(TaskContext &task_context) { + std::vector inputs; + std::vector input_desc; + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(tensor); + inputs.emplace_back(*tensor); + input_desc.emplace_back(task_context.GetInputDesc(i)); + } + + GE_CHK_STATUS_RET(ExecuteAsync(inputs, input_desc), "[%s] Failed to execute subgraph.", + graph_item_->GetName().c_str()); + + GE_CHK_STATUS_RET(SetOutputsToParentNode(task_context), "[%s] Failed to set output shapes to parent node.", + graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::PrepareNodes() { + GELOGD("[%s] Start to prepare nodes. force infer shape = %s.", graph_item_->GetName().c_str(), + force_infer_shape_ ? "true" : "false"); + auto &all_nodes = graph_item_->GetAllNodes(); + for (size_t i = 0; i < all_nodes.size(); ++i) { + auto &node_item = *all_nodes[i]; + // for while op + if (force_infer_shape_ && !node_item.is_dynamic) { + GELOGD("[%s] Force infer shape is set, updating node to dynamic.", node_item.NodeName().c_str()); + auto &mutable_node_item = const_cast(node_item); + mutable_node_item.SetToDynamic(); + } + + GELOGD("[%s] Start to prepare node [%s].", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); + auto node_state = subgraph_context_->GetOrCreateNodeState(&node_item); + GE_CHECK_NOTNULL(node_state); + auto p_node_state = node_state.get(); + + if (node_item.node_type == NETOUTPUT) { + // Wait for all inputs become valid + // after PrepareNodes returned. all output tensors and shapes are valid + GE_CHK_STATUS_RET_NOLOG(p_node_state->GetShapeInferenceState().AwaitShapesReady(*context_)); + GE_CHK_STATUS_RET_NOLOG(p_node_state->AwaitInputTensors(*context_)); + continue; + } + + // only do shape inference and compilation for nodes with dynamic shapes. + if (node_item.is_dynamic) { + auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { + GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); + return PrepareForExecution(context_, *p_node_state); + }); + + p_node_state->SetPrepareFuture(std::move(prepare_future)); + } else { + GELOGD("[%s] Skipping shape inference and compilation for node with static shape.", node_item.NodeName().c_str()); + if (node_item.kernel_task == nullptr) { + GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str()); + GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_), "[%s] Failed to create task.", + p_node_state->GetName().c_str()); + } else { + node_state->SetKernelTask(node_item.kernel_task); + } + } + + if (!ready_queue_.Push(p_node_state)) { + GELOGE(INTERNAL_ERROR, "[%s] Error occurs while launching tasks. quit from preparing nodes.", + graph_item_->GetName().c_str()); + return INTERNAL_ERROR; + } + + GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); + } + + return SUCCESS; +} + +Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) { + const auto &node_item = *node_state.GetNodeItem(); + GE_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), "[%s] Failed to InferShape.", + node_state.GetName().c_str()); + GE_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_item), "[%s] Failed to PropagateOutputShapes.", + node_state.GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state) { + auto &node_item = *node_state.GetNodeItem(); + if (node_item.kernel_task == nullptr) { + GE_CHK_STATUS_RET(TaskCompileEngine::Compile(node_state, ctx), "Failed to create task for node[%s]", + node_state.GetName().c_str()); + } else { + node_state.SetKernelTask(node_item.kernel_task); + } + + GELOGD("[%s] Start to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); + RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); + GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), + "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); + RECORD_COMPILE_EVENT(ctx, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); + GELOGD("[%s] Done invoking CalcOpRunningParam successfully.", node_item.NodeName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::LaunchTasks() { + while (true) { + NodeState *node_state = nullptr; + if (!ready_queue_.Pop(node_state)) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to pop node.", graph_item_->GetName().c_str()); + return INTERNAL_ERROR; + } + + if (node_state == nullptr) { + GELOGD("[%s] Got EOF from queue.", graph_item_->GetName().c_str()); + return SUCCESS; + } + + GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone()); + + GELOGD("[%s] Start to execute.", node_state->GetName().c_str()); + auto task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get()); + GE_CHECK_NOTNULL(task_context); + task_context->SetForceInferShape(force_infer_shape_); + auto shared_task_context = std::shared_ptr(task_context.release()); + GE_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, shared_task_context, *context_), + "[%s] Execute node failed.", node_state->GetName().c_str()); + + GELOGD("[%s] Done executing node successfully.", node_state->GetName().c_str()); + } +} + +Status SubgraphExecutor::ScheduleTasks() { + GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); + auto prepare_future = std::async([&]() -> Status { + auto ret = PrepareNodes(); + ready_queue_.Push(nullptr); + return ret; + }); + + GELOGD("[%s] Start to execute subgraph.", graph_item_->GetName().c_str()); + auto ret = LaunchTasks(); + if (ret != SUCCESS) { + GELOGE(ret, "[%s] Failed to execute subgraph.", graph_item_->GetName().c_str()); + subgraph_context_->OnError(ret); + ready_queue_.Stop(); + prepare_future.wait(); + return ret; + } + + GE_CHK_STATUS_RET(prepare_future.get(), "[%s] Error occurred in task preparation.", graph_item_->GetName().c_str()); + + GELOGD("[%s] Done launching all tasks successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::GetOutputs(vector &outputs) { return subgraph_context_->GetOutputs(outputs); } + +Status SubgraphExecutor::GetOutputs(vector &outputs, std::vector &output_desc) { + GE_CHK_STATUS_RET(GetOutputs(outputs), "[%s] Failed to get output tensors.", graph_item_->GetName().c_str()); + + // copy output data from op to designated position + std::vector output_tensor_desc_list; + GE_CHK_STATUS_RET(graph_item_->GetOutputDescList(output_desc), "[%s] Failed to get output tensor desc.", + graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::Synchronize() { + GELOGD("[%s] Synchronize start.", graph_item_->GetName().c_str()); + GE_CHK_RT_RET(rtStreamSynchronize(context_->stream)); + GELOGD("[%s] Done synchronizing successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status SubgraphExecutor::SetOutputsToParentNode(TaskContext &task_context) { + // get output tensors and tensor desc list + std::vector outputs; + std::vector output_desc_list; + GE_CHK_STATUS_RET(subgraph_context_->GetOutputs(outputs), "[%s] Failed to get output tensors.", + graph_item_->GetName().c_str()); + GE_CHK_STATUS_RET(graph_item_->GetOutputDescList(output_desc_list), "[%s] Failed to get output tensor desc.", + graph_item_->GetName().c_str()); + + if (outputs.size() != output_desc_list.size()) { + GELOGE(INTERNAL_ERROR, "[%s] num output tensors = %zu, num output tensor desc = %zu", + graph_item_->GetName().c_str(), outputs.size(), output_desc_list.size()); + return INTERNAL_ERROR; + } + + // mapping to parent task context + for (size_t i = 0; i < outputs.size(); ++i) { + int parent_output_index = graph_item_->GetParentOutputIndex(i); + GE_CHECK_GE(parent_output_index, 0); + // update tensor + GELOGD("[%s] Updating output[%zu] to parent output[%d]", graph_item_->GetName().c_str(), i, parent_output_index); + + GELOGD("[%s] Updating output tensor, index = %d, tensor = %s", graph_item_->GetName().c_str(), parent_output_index, + outputs[i].DebugString().c_str()); + GE_CHK_STATUS_RET(task_context.SetOutput(parent_output_index, outputs[i])); + + // updating shapes. dynamic format/dtype is not supported. + // It should be noted that even the subgraph is of known shape, it is also necessary to update parent output desc, + // for instance, IfOp may have two known-shaped subgraphs of different output shapes + const auto &output_desc = output_desc_list[i]; + auto parent_output_desc = task_context.MutableOutputDesc(parent_output_index); + GE_CHECK_NOTNULL(parent_output_desc); + GELOGD("[%s] Updating output shape[%d] from [%s] to [%s]", graph_item_->GetName().c_str(), parent_output_index, + parent_output_desc->MutableShape().ToString().c_str(), output_desc->GetShape().ToString().c_str()); + parent_output_desc->SetShape(output_desc->GetShape()); + + GELOGD("[%s] Updating output original shape[%d] from [%s] to [%s]", graph_item_->GetName().c_str(), + parent_output_index, parent_output_desc->GetOriginShape().ToString().c_str(), + output_desc->GetOriginShape().ToString().c_str()); + parent_output_desc->SetOriginShape(output_desc->GetOriginShape()); + } + + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/executor/subgraph_executor.h b/src/ge/hybrid/executor/subgraph_executor.h new file mode 100644 index 00000000..7cdb2070 --- /dev/null +++ b/src/ge/hybrid/executor/subgraph_executor.h @@ -0,0 +1,101 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_EXECUTOR_EXECUTOR_SUBGRAPH_EXECUTOR_H_ +#define GE_HYBRID_EXECUTOR_EXECUTOR_SUBGRAPH_EXECUTOR_H_ + +#include + +#include "common/blocking_queue.h" +#include "common/thread_pool.h" +#include "hybrid/executor/subgraph_context.h" +#include "hybrid/executor/node_state.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/worker/shape_inference_engine.h" +#include "hybrid/model/graph_item.h" +#include "hybrid/node_executor/task_context.h" + +namespace ge { +namespace hybrid { +// Executor for executing a subgraph +class SubgraphExecutor { + public: + SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); + ~SubgraphExecutor(); + + /** + * Execute subgraph async, output tensor address(not data) and output tensor descriptions are + * valid after this method returned + * @param inputs input tensors + * @param input_desc input tensor descriptions + * @return SUCCESS on success, error code otherwise + */ + Status ExecuteAsync(const std::vector &inputs, const std::vector &input_desc); + + /** + * Execute subgraph async, output tensor address(not data) and output tensor descriptions are + * valid after this method returned + * @param task_context instance of TaskContext + * @return SUCCESS on success, error code otherwise + */ + Status ExecuteAsync(TaskContext &task_context); + + /** + * Synchronize all tasks in the subgraph. output tensor data are valid after this method returned + * @return SUCCESS on success, error code otherwise + */ + Status Synchronize(); + + /** + * Get output tensors + * @param outputs output tensors + * @return SUCCESS on success, error code otherwise + */ + Status GetOutputs(std::vector &outputs); + + /** + * Get output tensors and output tensor descriptions + * @param outputs output tensors + * @param output_desc output tensor descriptions + * @return SUCCESS on success, error code otherwise + */ + Status GetOutputs(std::vector &outputs, std::vector &output_desc); + + private: + static Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); + static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state); + Status Init(const std::vector &inputs, const std::vector &input_desc); + Status InitInputsForUnknownShape(const std::vector &inputs, + const std::vector &input_desc); + Status InitInputsForKnownShape(const std::vector &inputs); + Status ExecuteAsyncForKnownShape(const std::vector &inputs); + Status ScheduleTasks(); + Status PrepareNodes(); + Status LaunchTasks(); + Status SetOutputsToParentNode(TaskContext &task_context); + + const GraphItem *graph_item_; + GraphExecutionContext *context_; + std::unique_ptr subgraph_context_; + bool force_infer_shape_; + ThreadPool pre_run_pool_; + BlockingQueue ready_queue_; + std::unique_ptr shape_inference_engine_; + std::shared_ptr known_shape_task_context_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_EXECUTOR_EXECUTOR_SUBGRAPH_EXECUTOR_H_ diff --git a/src/ge/hybrid/executor/worker/execution_engine.cc b/src/ge/hybrid/executor/worker/execution_engine.cc index 9e656139..20da6378 100644 --- a/src/ge/hybrid/executor/worker/execution_engine.cc +++ b/src/ge/hybrid/executor/worker/execution_engine.cc @@ -15,7 +15,6 @@ */ #include "hybrid/executor/worker/execution_engine.h" -#include #include "graph/runtime_inference_context.h" #include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_adapter.h" @@ -23,9 +22,38 @@ namespace ge { namespace hybrid { +namespace { +constexpr int64_t kMaxPadding = 63; + +Status LogInputs(const NodeItem &node_item, const TaskContext &task_context) { + for (auto i = 0; i < task_context.NumInputs(); ++i) { + const auto &input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + const auto &tensor_desc = node_item.op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(tensor_desc); + GELOGD("[%s] Print task args. input[%d] = %s, shape = [%s]", node_item.NodeName().c_str(), i, + input_tensor->DebugString().c_str(), tensor_desc->MutableShape().ToString().c_str()); + } + + return SUCCESS; +} + +Status LogOutputs(const NodeItem &node_item, const TaskContext &task_context) { + for (auto i = 0; i < task_context.NumOutputs(); ++i) { + const auto &output_tensor = task_context.GetOutput(i); + GE_CHECK_NOTNULL(output_tensor); + const auto &tensor_desc = node_item.op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(tensor_desc); + GELOGD("[%s] Print task args. output[%d] = %s, shape = [%s]", node_item.NodeName().c_str(), i, + output_tensor->DebugString().c_str(), tensor_desc->MutableShape().ToString().c_str()); + } + + return SUCCESS; +} +} // namespace class NodeDoneCallback { public: - NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr &task_context); + NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr task_context); ~NodeDoneCallback() = default; Status OnNodeDone(); @@ -35,8 +63,8 @@ class NodeDoneCallback { std::shared_ptr context_; }; -NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr &task_context) - : graph_context_(graph_context), context_(task_context) {} +NodeDoneCallback::NodeDoneCallback(GraphExecutionContext *graph_context, std::shared_ptr task_context) + : graph_context_(graph_context), context_(std::move(task_context)) {} Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { for (auto output_idx : node_item.to_const_output_id_list) { @@ -46,17 +74,28 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { auto output_tensor = context_->GetOutput(output_idx); GE_CHECK_NOTNULL(output_tensor); - vector host_buffer(output_tensor->GetSize()); - GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx, - output_tensor->GetSize()); - GE_CHK_RT_RET(rtMemcpy(host_buffer.data(), host_buffer.size(), output_tensor->GetData(), output_tensor->GetSize(), - RT_MEMCPY_HOST_TO_DEVICE)); Tensor tensor; - tensor.SetData(host_buffer); auto ge_tensor_desc = node_item.op_desc->MutableOutputDesc(output_idx); GE_CHECK_NOTNULL(ge_tensor_desc); tensor.SetTensorDesc(TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc)); + int64_t tensor_size; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*ge_tensor_desc, tensor_size), + "Failed to invoke GetTensorSizeInBytes"); + + if (output_tensor->GetSize() < static_cast(tensor_size)) { + GELOGE(INTERNAL_ERROR, "[%s] Tensor size is not enough. output index = %d, required size = %zu, tensor = %s", + node_item.NodeName().c_str(), output_idx, tensor_size, output_tensor->DebugString().c_str()); + return INTERNAL_ERROR; + } + + vector host_buffer(tensor_size); + GELOGD("[%s] To cache output[%d] to host, size = %zu", node_item.NodeName().c_str(), output_idx, + output_tensor->GetSize()); + GE_CHK_RT_RET( + rtMemcpy(host_buffer.data(), tensor_size, output_tensor->GetData(), tensor_size, RT_MEMCPY_DEVICE_TO_HOST)); + tensor.SetData(host_buffer); + string session_id = std::to_string(context_->GetSessionId()); RuntimeInferenceContext *runtime_infer_ctx = nullptr; GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx), @@ -87,115 +126,118 @@ Status NodeDoneCallback::OnNodeDone() { GE_CHK_STATUS_RET_NOLOG(PrepareConstInputs(node_item)); // PropagateOutputs for type == DEPEND_COMPUTE if (node_item.shape_inference_type == DEPEND_COMPUTE) { + if (graph_context_->trace_enabled) { + (void)LogOutputs(node_item, *context_); + } + GE_CHK_STATUS_RET(context_->PropagateOutputs(), "[%s] Failed to propagate outputs failed", node_item.NodeName().c_str()); RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[PropagateOutputs] End"); } - // release + // release condition variable if (node_item.has_observer) { GELOGI("[%s] Notify observer. node_id = %d", node_item.NodeName().c_str(), node_item.node_id); - graph_context_->cv_manager.NodeDone(node_item.node); + context_->NodeDone(); } RECORD_CALLBACK_EVENT(graph_context_, context_->GetNodeName(), "[Callback] End"); return SUCCESS; } -ExecutionEngine::ExecutionEngine(GraphExecutionContext *context, CallbackManager *callback_manager) - : context_(context), callback_manager_(callback_manager) {} - -Status ExecutionEngine::Start() { - GE_CHK_STATUS_RET_NOLOG(ExecutionProcess()); - return SUCCESS; -} - -Status ExecutionEngine::ExecutionProcess() { - GELOGI("ExecutorEngine worker started"); - auto &ready_queue = context_->execution_queue; - while (true) { - NodeStatePtr node_state = nullptr; - if (!ready_queue.Pop(node_state)) { - GELOGE(FAILED, "Pop task failed"); - return FAILED; - } - - // EOF - if (node_state == nullptr) { - break; +Status ExecutionEngine::ExecuteAsync(NodeState &node_state, const std::shared_ptr &task_context, + GraphExecutionContext &execution_context) { + GELOGI("[%s] Node is ready for execution", task_context->GetNodeName()); + RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); + auto cb = std::shared_ptr(new (std::nothrow) NodeDoneCallback(&execution_context, task_context)); + GE_CHECK_NOTNULL(cb); + auto callback = [&, cb]() { + auto ret = cb->OnNodeDone(); + if (ret != SUCCESS) { + task_context->OnError(ret); } + }; - RECORD_EXECUTION_EVENT(context_, node_state->GetName().c_str(), "Start"); - GELOGI("[%s] Node is ready for execution", node_state->GetName().c_str()); - auto *node_item = node_state->node_item; - auto task_context = TaskContext::Create(*node_item, context_); - GE_CHECK_NOTNULL(task_context); - auto shared_task_context = shared_ptr(task_context.release()); - - auto cb = std::shared_ptr(new (std::nothrow) NodeDoneCallback(context_, shared_task_context)); - GE_CHECK_NOTNULL(cb); - auto callback = [&, cb]() { - auto ret = cb->OnNodeDone(); - if (ret != SUCCESS) { - context_->OnError(ret); - } - }; - - GE_CHK_STATUS_RET_NOLOG(ExecuteAsync(*node_state, *shared_task_context, callback)); - GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_item, *shared_task_context)); - } - - GELOGI("ExecutorEngine worker ended."); + GE_CHK_STATUS_RET_NOLOG(DoExecuteAsync(node_state, *task_context, execution_context, callback)); + GE_CHK_STATUS_RET_NOLOG(PropagateOutputs(*node_state.GetNodeItem(), *task_context, execution_context)); return SUCCESS; } -Status ExecutionEngine::ExecuteAsync(NodeState &node_state, TaskContext &task_context, - const std::function &callback) { - const auto &task = node_state.kernel_task; +Status ExecutionEngine::DoExecuteAsync(NodeState &node_state, TaskContext &task_context, GraphExecutionContext &context, + const std::function &callback) { + const auto &task = node_state.GetKernelTask(); if (task == nullptr) { GELOGE(INTERNAL_ERROR, "[%s] NodeTask is null.", node_state.GetName().c_str()); return INTERNAL_ERROR; } - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PrepareTask] Start"); - auto executor = node_state.node_item->node_executor; + // Wait for dependent nodes(DEPEND_COMPUTE), so that the input tensors are valid. + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[AwaitDependents] Start"); + GE_CHK_STATUS_RET(node_state.AwaitInputTensors(context), "[%s] Failed to wait for dependent nodes.", + node_state.GetName().c_str()); + + const auto &node_item = *node_state.GetNodeItem(); + auto executor = node_item.node_executor; + GE_CHECK_NOTNULL(executor); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start"); GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[%s] Failed to prepare task", node_state.GetName().c_str()); - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PrepareTask] End"); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End"); GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str()); - if (context_->trace_enabled) { - for (auto i = 0; i < task_context.NumInputs(); ++i) { - const auto &input_tensor = task_context.GetInput(i); - GE_CHECK_NOTNULL(input_tensor); - GELOGD("[%s] Tensor of input[%d] = %s", node_state.GetName().c_str(), i, input_tensor->DebugString().c_str()); - } - - for (auto i = 0; i < task_context.NumOutputs(); ++i) { - const auto &output_tensor = task_context.GetOutput(i); - GE_CHECK_NOTNULL(output_tensor); - GELOGD("[%s] Tensor of output[%d] = %s", node_state.GetName().c_str(), i, output_tensor->DebugString().c_str()); + if (context.trace_enabled) { + LogInputs(node_item, task_context); + if (node_item.shape_inference_type != DEPEND_COMPUTE) { + LogOutputs(node_item, task_context); } } - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[ExecuteTask] Start"); + GE_CHK_STATUS_RET(ValidateInputTensors(node_state, task_context), "Failed to validate input tensors."); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ValidateInputTensors] End"); + GE_CHK_STATUS_RET(executor->ExecuteTask(*task, task_context, callback), "[%s] Failed to execute task", node_state.GetName().c_str()); - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[ExecuteTask] End"); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[ExecuteTask] End"); GELOGD("[%s] Done task launch successfully.", node_state.GetName().c_str()); return SUCCESS; } -Status ExecutionEngine::PropagateOutputs(const NodeItem &node_item, TaskContext &task_context) { +Status ExecutionEngine::ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context) { + for (auto i = 0; i < task_context.NumInputs(); ++i) { + const auto &input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + const auto &tensor_desc = node_state.GetOpDesc()->MutableInputDesc(i); + GE_CHECK_NOTNULL(tensor_desc); + int64_t expected_size; + GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, expected_size)); + GELOGD("[%s] Input[%d] expects [%ld] bytes.", task_context.GetNodeName(), i, expected_size); + auto size_diff = expected_size - static_cast(input_tensor->GetSize()); + if (size_diff > 0) { + if (size_diff <= kMaxPadding) { + GELOGW("[%s] Input[%d]: tensor size mismatches. expected: %ld, but given %zu", task_context.GetNodeName(), i, + expected_size, input_tensor->GetSize()); + } else { + GELOGE(INTERNAL_ERROR, "[%s] Input[%d]: tensor size mismatches. expected: %ld, but given %zu", + task_context.GetNodeName(), i, expected_size, input_tensor->GetSize()); + return INTERNAL_ERROR; + } + } + } + + return SUCCESS; +} + +Status ExecutionEngine::PropagateOutputs(const NodeItem &node_item, TaskContext &task_context, + GraphExecutionContext &context) { if (node_item.shape_inference_type != DEPEND_COMPUTE) { GE_CHK_STATUS_RET(task_context.PropagateOutputs(), "[%s] Failed to propagate outputs.", node_item.NodeName().c_str()); - RECORD_EXECUTION_EVENT(context_, task_context.GetNodeName(), "[PropagateOutputs] End"); + RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PropagateOutputs] End"); + GELOGD("[%s] Done propagating outputs successfully.", node_item.NodeName().c_str()); } - GELOGD("[%s] Done propagating outputs successfully.", node_item.NodeName().c_str()); return SUCCESS; } } // namespace hybrid diff --git a/src/ge/hybrid/executor/worker/execution_engine.h b/src/ge/hybrid/executor/worker/execution_engine.h index f5f317af..56f1557d 100644 --- a/src/ge/hybrid/executor/worker/execution_engine.h +++ b/src/ge/hybrid/executor/worker/execution_engine.h @@ -17,30 +17,21 @@ #ifndef GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ #define GE_HYBRID_EXECUTOR_EXECUTOR_EXECUTION_ENGINE_H_ -#include "common/thread_pool.h" -#include "hybrid/common/npu_memory_allocator.h" #include "hybrid/executor/hybrid_execution_context.h" -#include "hybrid/executor/rt_callback_manager.h" #include "hybrid/node_executor/task_context.h" namespace ge { namespace hybrid { class ExecutionEngine { public: - explicit ExecutionEngine(GraphExecutionContext *context, CallbackManager *callback_manager); - ~ExecutionEngine() = default; - - Status Start(); + static Status ExecuteAsync(NodeState &node_state, const std::shared_ptr &task_context, + GraphExecutionContext &execution_context); private: - Status PropagateOutputs(const NodeItem &node_item, TaskContext &task_context); - - Status ExecutionProcess(); - - Status ExecuteAsync(NodeState &node_state, TaskContext &task_context, const std::function &callback); - - GraphExecutionContext *context_; - CallbackManager *callback_manager_; + static Status ValidateInputTensors(const NodeState &node_state, const TaskContext &task_context); + static Status PropagateOutputs(const NodeItem &node_item, TaskContext &task_context, GraphExecutionContext &context); + static Status DoExecuteAsync(NodeState &node_state, TaskContext &task_context, GraphExecutionContext &context, + const std::function &callback); }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.cc b/src/ge/hybrid/executor/worker/shape_inference_engine.cc index 90082fff..650bcc54 100644 --- a/src/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -15,117 +15,30 @@ */ #include "hybrid/executor/worker/shape_inference_engine.h" - #include "graph/shape_refiner.h" -#include "graph/runtime_inference_context.h" #include "graph/utils/node_utils.h" #include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { +ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context) + : execution_context_(execution_context), subgraph_context_(subgraph_context) {} -ShapeInferenceEngine::ShapeInferenceEngine(GraphExecutionContext *context) : context_(context) {} - -Status ShapeInferenceEngine::Start(ThreadPool &pool) { - GELOGI("RuntimeShapeInferenceEngine start."); - pool.commit([&]() { - auto ret = this->InferShapeProcess(); - InferenceDone(ret); - }); - - return SUCCESS; -} - -Status ShapeInferenceEngine::InferShapeProcess() { - GELOGI("RuntimeShapeInferenceEngine worker start."); - const auto &root_nodes = context_->model->RootNodes(); - auto &complete_queue = context_->compile_queue; - std::queue ready_nodes; - for (auto &node_item : root_nodes) { - auto infer_state = GetOrCreateEntry(*node_item); - GE_CHECK_NOTNULL(infer_state); - ready_nodes.emplace(infer_state); - } - - while (!ready_nodes.empty()) { - InferenceState *infer_state = ready_nodes.front(); - ready_nodes.pop(); - auto node_item = infer_state->node_item; - // even for non-dynamic shape node, it is still necessary to wait for pending shapes if got any. - // which indicates that the parent node is of type 4, in which case the inputs will be valid only - // when computing is done. - GE_CHK_STATUS_RET(infer_state->AwaitShapeFutures(context_), "Await shape failed."); - GELOGI("[%s] Node is ready for shape inference.", node_item.NodeName().c_str()); - if (node_item.is_dynamic) { - // may block - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "Start"); - GELOGI("[%s] Start to invoke InferShape", node_item.NodeName().c_str()); - auto ret = InferShape(*infer_state); - if (ret != SUCCESS) { - return ret; - } - - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] Start"); - GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().CalcOpRunningParam(*node_item.node), - "[%s] Failed to invoke CalcOpRunningParam.", node_item.NodeName().c_str()); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[CalcOpRunningParam] End"); - } else { - GELOGD("[%s] Skip static shape node", node_item.NodeName().c_str()); - } - - if (node_item.node_type != NETOUTPUT) { - GELOGI("[%s] Push to compile queue", node_item.NodeName().c_str()); - // may block if full - auto node_state = context_->GetOrCreateNodeState(node_item.node); - complete_queue.Push(node_state); - } - - // Propagate - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] Start"); - PropagateOutputShapes(*infer_state, ready_nodes); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[PropagateOutputShapes] End"); - } - - return SUCCESS; -} - -void ShapeInferenceEngine::InferenceDone(Status status) { - if (status != SUCCESS) { - GELOGE(status, "Error occurred while shape inference"); - context_->OnError(status); - } else { - context_->compile_queue.Push(nullptr); - } - inference_states_.clear(); - GELOGI("RuntimeShapeInferenceEngine worker END"); -} - -Status ShapeInferenceEngine::InferShape(InferenceState &entry) { - // input shapes are ready, wait for dependent data if has any - const auto &node_item = entry.node_item; - if (!node_item.dependent_node_list.empty()) { - for (auto &src_node : node_item.dependent_node_list) { - auto *src_node_item = context_->model->GetNodeItem(src_node); - GELOGI("[%s] Start to wait for data dependent node: %s", node_item.NodeName().c_str(), - src_node_item->NodeName().c_str()); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] Start", - src_node->GetName().c_str()); - if (!context_->cv_manager.Await(src_node)) { - GELOGE(INTERNAL_ERROR, "[%s] Await node failed.", src_node_item->NodeName().c_str()); - return INTERNAL_ERROR; - } +Status ShapeInferenceEngine::InferShape(NodeState &node_state) { + // Wait for all input shape become valid + GE_CHK_STATUS_RET_NOLOG(node_state.GetShapeInferenceState().AwaitShapesReady(*execution_context_)); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] End", - src_node->GetName().c_str()); - GELOGI("[%s] Done waiting node.", src_node_item->NodeName().c_str()); - } + auto &node_item = *node_state.GetNodeItem(); + if (node_item.is_output_shape_static) { + return SUCCESS; } - + // Skip shape inference for node of type DEPEND_COMPUTE if (node_item.shape_inference_type == DEPEND_COMPUTE) { - GELOGD("[%s] Skip node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str()); + GELOGD("[%s] Skipping node with unknown shape type DEPEND_COMPUTE", node_item.NodeName().c_str()); return SUCCESS; } + // Clear shape range in case shape inference func forgot to do it if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE) { // in case InferFunc forgot to reset output shape for (auto &output_desc : node_item.op_desc->GetAllOutputsDescPtr()) { @@ -133,13 +46,18 @@ Status ShapeInferenceEngine::InferShape(InferenceState &entry) { } } - // do shape inference - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[InferShape] Start"); - GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); - GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node_item.node), "Invoke InferShapeAndType failed."); - RECORD_SHAPE_INFERENCE_EVENT(context_, node_item.NodeName().c_str(), "[InferShape] End"); + // Wait for "const input nodes" if node's shape inference function requires any. + GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); - // Check shape + // Do shape inference + GELOGD("[%s] Start to invoke InferShapeAndType", node_item.NodeName().c_str()); + { + std::lock_guard lk(mu_); + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] Start"); + GE_CHK_STATUS_RET(ShapeRefiner::InferShapeAndType(node_item.node), "Invoke InferShapeAndType failed."); + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[InferShapeAndType] End"); + } + // Check again to make sure shape is valid after shape inference if (node_item.shape_inference_type != DEPEND_SHAPE_RANGE) { bool is_unknown_shape = false; GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node_item.node, is_unknown_shape), @@ -149,12 +67,37 @@ Status ShapeInferenceEngine::InferShape(InferenceState &entry) { node_item.NodeName().c_str()); } + GELOGD("[%s] [HybridTrace] After shape inference. Node = %s", node_item.NodeName().c_str(), + node_item.DebugString().c_str()); + GELOGD("[%s] InferShapeAndType finished successfully.", node_item.NodeName().c_str()); return SUCCESS; } -void ShapeInferenceEngine::PropagateOutputShapes(InferenceState &entry, std::queue &queue) { - auto &node_item = entry.node_item; +Status ShapeInferenceEngine::AwaitDependentNodes(NodeState &node_state) { + auto &node_item = *node_state.GetNodeItem(); + for (auto &src_node : node_item.dependents_for_shape_inference) { + GELOGI("[%s] Start to wait for data dependent node: %s", node_item.NodeName().c_str(), src_node->GetName().c_str()); + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] Start", + src_node->GetName().c_str()); + if (!subgraph_context_->Await(src_node)) { + GELOGE(INTERNAL_ERROR, "[%s] Await node failed.", src_node->GetName().c_str()); + return INTERNAL_ERROR; + } + + RECORD_SHAPE_INFERENCE_EVENT(execution_context_, node_item.NodeName().c_str(), "[AwaitNodeDone] [%s] End", + src_node->GetName().c_str()); + GELOGI("[%s] Done waiting node.", src_node->GetName().c_str()); + } + + return SUCCESS; +} + +Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { + if (node_item.is_output_shape_static) { + return SUCCESS; + } + // output shape will not be valid until compute is done. bool shape_is_future = node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE; @@ -171,88 +114,25 @@ void ShapeInferenceEngine::PropagateOutputShapes(InferenceState &entry, std::que // propagate output to all sub-inputs for (auto &dst_input_index_and_node : output_nodes) { auto &dst_node_item = dst_input_index_and_node.second; - auto inference_state = GetOrCreateEntry(*dst_node_item); + auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item); + GE_CHECK_NOTNULL(dst_node_state); + GELOGI("[%s] Update dst node [%s], input index = %d", node_item.NodeName().c_str(), dst_node_item->NodeName().c_str(), dst_input_index_and_node.first); - // in case type 3/4, shape will be valid after computing is done + // in case type 3 and 4, shape will be valid after computing is done if (shape_is_future) { - ShapeFuture future(node_item.node, i, &context_->cv_manager); - inference_state->UpdateInputShapeFuture(dst_input_index_and_node.first, std::move(future)); + ShapeFuture future(node_item.node, i, subgraph_context_); + dst_node_state->GetShapeInferenceState().UpdateInputShapeFuture(dst_input_index_and_node.first, + std::move(future)); } else { - inference_state->UpdateInputShape(dst_input_index_and_node.first, ori_shape, shape); - } - - if (inference_state->IsInputShapesReady()) { - GELOGI("[%s] Node input shape is ready, add to queue.", inference_state->node_item.NodeName().c_str()); - queue.emplace(inference_state); + dst_node_state->GetShapeInferenceState().UpdateInputShape(dst_input_index_and_node.first, ori_shape, shape); } } } GELOGD("[%s] Propagating output shapes finished successfully.", node_item.NodeName().c_str()); -} - -ShapeInferenceEngine::InferenceState *ShapeInferenceEngine::GetOrCreateEntry(const NodeItem &node_item) { - auto &node_state = inference_states_[node_item.node_id]; - if (node_state == nullptr) { - node_state.reset(new (std::nothrow) InferenceState(node_item)); - } - - return node_state.get(); -} - -ShapeInferenceEngine::InferenceState::InferenceState(const NodeItem &node_item) : node_item(node_item) { - this->num_pending_shapes = node_item.num_inputs; -} - -void ShapeInferenceEngine::InferenceState::UpdateInputShape(uint32_t idx, const GeShape &ori_shape, - const GeShape &shape) { - if (node_item.const_input_shapes.count(idx) != 0) { - GELOGD("[%s] Trying to update constant shape, idx = %u. old shape = [%s], new shape = [%s]", - node_item.NodeName().c_str(), idx, node_item.op_desc->MutableInputDesc(idx)->GetShape().ToString().c_str(), - shape.ToString().c_str()); - } - - GELOGD("[%s] Update input shape [%u] with Shape: [%s] and OriginalShape: [%s]", node_item.NodeName().c_str(), idx, - shape.ToString().c_str(), ori_shape.ToString().c_str()); - num_pending_shapes -= 1; - node_item.op_desc->MutableInputDesc(idx)->SetShape(shape); - node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); -} - -void ShapeInferenceEngine::InferenceState::UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future) { - if (node_item.const_input_shapes.count(idx) != 0) { - GELOGE(INTERNAL_ERROR, "[%s] Trying to update constant shape, idx = %u", node_item.NodeName().c_str(), idx); - return; - } - - GELOGD("[%s] Update input shape [%u] with ShapeFuture.", node_item.NodeName().c_str(), idx); - num_pending_shapes -= 1; - shape_futures.emplace_back(idx, std::move(future)); -} - -Status ShapeInferenceEngine::InferenceState::AwaitShapeFutures(GraphExecutionContext *context) { - for (auto &p : shape_futures) { - auto idx = p.first; - auto &future = p.second; - GeShape shape; - GeShape ori_shape; - RECORD_SHAPE_INFERENCE_EVENT(context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); - GE_CHK_STATUS_RET(future.Get(ori_shape, shape), "[%s] Get shape failed. index = %u", node_item.NodeName().c_str(), - idx); - RECORD_SHAPE_INFERENCE_EVENT(context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); - - GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s]", node_item.NodeName().c_str(), idx, - shape.ToString().c_str(), ori_shape.ToString().c_str()); - node_item.op_desc->MutableInputDesc(idx)->SetShape(std::move(shape)); - node_item.op_desc->MutableInputDesc(idx)->SetOriginShape(ori_shape); - } - return SUCCESS; } - -ShapeInferenceEngine::ShapeFuture::ShapeFuture(NodePtr src_node, uint32_t src_index, NodeDoneManager *node_done_manager) - : src_node_(std::move(src_node)), src_index_(src_index), node_done_manager_(node_done_manager) {} } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/executor/worker/shape_inference_engine.h b/src/ge/hybrid/executor/worker/shape_inference_engine.h index b1e1c879..65878818 100644 --- a/src/ge/hybrid/executor/worker/shape_inference_engine.h +++ b/src/ge/hybrid/executor/worker/shape_inference_engine.h @@ -17,75 +17,27 @@ #ifndef GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ #define GE_HYBRID_EXECUTOR_INFERSHAPE_SHAPE_INFERENCE_ENGINE_H_ -#include -#include -#include -#include "common/thread_pool.h" #include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_context.h" +#include namespace ge { namespace hybrid { class ShapeInferenceEngine { public: - explicit ShapeInferenceEngine(GraphExecutionContext *context); - + ShapeInferenceEngine(GraphExecutionContext *execution_context, SubgraphContext *subgraph_context); ~ShapeInferenceEngine() = default; - Status Start(ThreadPool &pool); - - private: - class ShapeFuture { - public: - ShapeFuture(NodePtr src_node, uint32_t src_index, NodeDoneManager *node_done_manager); - ~ShapeFuture() = default; - Status Get(GeShape &ori_shape, GeShape &shape) { - GELOGI("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); - if (!node_done_manager_->Await(src_node_)) { - GELOGE(INTERNAL_ERROR, "cancelled"); - return INTERNAL_ERROR; - } - - shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); - ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); - GELOGI("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); - return SUCCESS; - } - - private: - NodePtr src_node_; - uint32_t src_index_; - NodeDoneManager *node_done_manager_; - }; - - struct InferenceState { - explicit InferenceState(const NodeItem &node_item); - inline bool IsInputShapesReady() const { return num_pending_shapes == 0; } - - void UpdateInputShape(uint32_t idx, const GeShape &ori_shape, const GeShape &shape); - - Status AwaitShapeFutures(GraphExecutionContext *context); + Status InferShape(NodeState &node_state); - void UpdateInputShapeFuture(uint32_t idx, ShapeFuture &&future); + Status PropagateOutputShapes(const NodeItem &node_item); - const NodeItem &node_item; - - private: - std::vector> shape_futures; - int num_pending_shapes = 0; - }; - - InferenceState *GetOrCreateEntry(const NodeItem &node_item); - - Status InferShapeProcess(); - - void InferenceDone(Status status); - - Status InferShape(InferenceState &entry); - - void PropagateOutputShapes(InferenceState &entry, std::queue &queue); + private: + Status AwaitDependentNodes(NodeState &node_state); - GraphExecutionContext *context_; - std::unordered_map> inference_states_; + GraphExecutionContext *execution_context_; + SubgraphContext *subgraph_context_; + std::mutex mu_; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.cc b/src/ge/hybrid/executor/worker/task_compile_engine.cc index f6434ffa..57b19f5f 100644 --- a/src/ge/hybrid/executor/worker/task_compile_engine.cc +++ b/src/ge/hybrid/executor/worker/task_compile_engine.cc @@ -16,172 +16,22 @@ #include "hybrid/executor/worker/task_compile_engine.h" #include "init/gelib.h" -#include "framework/common/debug/log.h" #include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { -namespace { -uint32_t kDefaultWorkerCnt = 4; -uint32_t kDefaultDeviceId = 0; -} // namespace -TaskCompileEngine::TaskCompileEngine(GraphExecutionContext *context) : context_(context), pool_(kDefaultWorkerCnt) {} - -TaskCompileEngine::~TaskCompileEngine() { - if (rt_context_ != nullptr) { - GELOGD("To destroy compile context: %p.", rt_context_); - GE_CHK_RT(rtCtxDestroy(rt_context_)); - } -} - -Status TaskCompileEngine::Init() { - GELOGD("Start to init CompileEngine"); - rtContext_t current_ctx = nullptr; - GE_CHK_RT(rtCtxGetCurrent(¤t_ctx)); - GE_CHK_RT_RET(rtCtxCreate(&rt_context_, RT_CTX_GEN_MODE, kDefaultDeviceId)); - GELOGD("Context created for compiling. ctx = %p", rt_context_); - GE_CHK_RT_RET(rtCtxSetCurrent(current_ctx)); - return SUCCESS; -} - -void TaskCompileEngine::Reset() { - complete_queue_.Push(nullptr); // ensure iteration can stop - unique_ptr entry; - while (true) { - complete_queue_.Pop(entry); - if (entry == nullptr) { - break; - } - - if (entry->future != nullptr) { - entry->future->wait(); - } - } - - complete_queue_.Clear(); -} - -Status TaskCompileEngine::Start(ThreadPool &pool) { - pool.commit([&]() { (void)this->CompileProcess(); }); - - worker_future_ = pool_.commit([&]() -> Status { return this->DistributeCompiledTasks(); }); - - if (!worker_future_.valid()) { - GELOGE(INTERNAL_ERROR, "Failed to start worker thread"); - return INTERNAL_ERROR; - } - - return SUCCESS; -} - -Status TaskCompileEngine::CompileProcess() { - auto &compile_queue = context_->compile_queue; - while (true) { - NodeStatePtr node_state; - // Stop() will not be invoked, Pop won't failed - (void)compile_queue.Pop(node_state); - - // EOF - if (node_state == nullptr) { - GELOGD("Got EOF"); - complete_queue_.Push(unique_ptr()); - break; - } - - auto entry = unique_ptr(new (std::nothrow) ResultQueueEntry()); - GE_CHECK_NOTNULL(entry); - entry->node_state = node_state; - - auto node_item = *node_state->node_item; - if (node_item.kernel_task != nullptr) { - GELOGD("use precompiled task. node name = %s", node_item.NodeName().c_str()); - node_state->kernel_task = node_item.kernel_task; - complete_queue_.Push(std::move(entry)); - continue; - } - - auto ret = CompileAsync(*node_state->node_item, *entry); - if (ret == SUCCESS) { - complete_queue_.Push(std::move(entry)); - continue; - } - - // On Error - worker_future_.wait(); - Reset(); - return CompileDone(ret); - } - - Status ret = worker_future_.get(); - Reset(); - return CompileDone(ret); -} - -Status TaskCompileEngine::CompileDone(Status status) { - if (status != SUCCESS) { - GELOGE(status, "Error occurred while compiling node."); - context_->OnError(status); - } else { - context_->execution_queue.Push(nullptr); - } - GELOGI("CompileEngine worker END. ret = %u", status); - return status; -} - -Status TaskCompileEngine::DoCompile(const NodeItem &node_item, NodeState &node_state) { - RECORD_COMPILE_EVENT(context_, node_state.GetName().c_str(), "Start"); - GE_CHK_RT_RET(rtCtxSetCurrent(rt_context_)); - auto ret = node_item.node_executor->CompileTask(*context_->model, node_item.node, node_state.kernel_task); - RECORD_COMPILE_EVENT(context_, node_state.GetName().c_str(), "End"); +Status TaskCompileEngine::Compile(NodeState &node_state, GraphExecutionContext *context) { + const auto &node_item = *node_state.GetNodeItem(); + RECORD_COMPILE_EVENT(context, node_item.NodeName().c_str(), "Start"); + GE_CHK_RT_RET(rtCtxSetCurrent(context->rt_gen_context)); + + shared_ptr kernel_task; + auto ret = node_item.node_executor->CompileTask(*context->model, node_item.node, kernel_task); + RECORD_COMPILE_EVENT(context, node_state.GetName().c_str(), "End"); GE_CHK_STATUS_RET(ret, "Failed to create task for node: %s", node_item.NodeName().c_str()); + node_state.SetKernelTask(kernel_task); GELOGI("Compiling node %s successfully", node_state.GetName().c_str()); return SUCCESS; } - -Status TaskCompileEngine::CompileAsync(const NodeItem &node_item, ResultQueueEntry &entry) { - auto node_state = entry.node_state; - auto f = pool_.commit([this, node_item, node_state]() -> Status { return DoCompile(node_item, *node_state); }); - - if (!f.valid()) { - GELOGE(INTERNAL_ERROR, "Failed to commit compile task"); - return INTERNAL_ERROR; - } - - entry.future = unique_ptr>(new (std::nothrow) std::future(std::move(f))); - GE_CHECK_NOTNULL(entry.future); - return SUCCESS; -} - -Status TaskCompileEngine::DistributeCompiledTasks() { - GELOGD("DistributeCompiledTasks start."); - auto &execute_queue = context_->execution_queue; - unique_ptr entry; - bool ret = SUCCESS; - while (true) { - if (!complete_queue_.Pop(entry)) { - GELOGE(INTERNAL_ERROR, "Failed to pop item from queue"); - ret = INTERNAL_ERROR; - break; - } - - // EOF - if (entry == nullptr) { - break; - } - - // if has compile future - if (entry->future != nullptr) { - ret = entry->future->get(); - if (ret != SUCCESS) { - break; - } - } - - execute_queue.Push(entry->node_state); - } - - GELOGD("DistributeCompiledTasks out. ret = %u.", ret); - return ret; -} } // namespace hybrid -} // namespace ge +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/executor/worker/task_compile_engine.h b/src/ge/hybrid/executor/worker/task_compile_engine.h index 828a1d8c..a677cb2e 100644 --- a/src/ge/hybrid/executor/worker/task_compile_engine.h +++ b/src/ge/hybrid/executor/worker/task_compile_engine.h @@ -17,44 +17,13 @@ #ifndef GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ #define GE_HYBRID_EXECUTOR_COMPILE_TASK_COMPILE_ENGINE_H_ -#include -#include -#include "common/thread_pool.h" #include "hybrid/executor/hybrid_execution_context.h" namespace ge { namespace hybrid { class TaskCompileEngine { public: - explicit TaskCompileEngine(GraphExecutionContext *context); - - ~TaskCompileEngine(); - - Status Init(); - - Status Start(ThreadPool &pool); - - private: - struct ResultQueueEntry { - NodeStatePtr node_state; - std::unique_ptr> future; - }; - - Status CompileProcess(); - - Status CompileDone(Status status); - - private: - Status DoCompile(const NodeItem &node_item, NodeState &node_state); - Status CompileAsync(const NodeItem &node_item, ResultQueueEntry &entry); - Status DistributeCompiledTasks(); - void Reset(); - - rtContext_t rt_context_ = nullptr; - GraphExecutionContext *context_; - BlockingQueue> complete_queue_; - ThreadPool pool_; - std::future worker_future_; + static Status Compile(NodeState &node_state, GraphExecutionContext *context); }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/hybrid_davinci_model.cc b/src/ge/hybrid/hybrid_davinci_model.cc index 58c7d0e3..0454fa72 100644 --- a/src/ge/hybrid/hybrid_davinci_model.cc +++ b/src/ge/hybrid/hybrid_davinci_model.cc @@ -18,6 +18,7 @@ #include "hybrid_davinci_model.h" #include "hybrid/model/hybrid_model.h" #include "hybrid/executor/hybrid_model_async_executor.h" +#include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { @@ -25,14 +26,19 @@ class HybridDavinciModel::Impl { public: explicit Impl(GeRootModelPtr ge_model) : model_(std::move(ge_model)), executor_(&model_) {} - ~Impl() = default; + ~Impl() { NodeExecutorManager::GetInstance().FinalizeExecutors(); } Status Init() { + GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().EnsureInitialized(), "Failed to initialize executors"); GE_CHK_STATUS_RET(model_.Init(), "Failed to init model.") GE_CHK_STATUS_RET(executor_.Init(), "Failed to init model executor.") return SUCCESS; } + Status Execute(const vector &inputs, vector &outputs) { + return executor_.Execute(inputs, outputs); + } + Status ModelRunStart() { return executor_.Start(listener_); } Status ModelRunStop() { return executor_.Stop(); } @@ -76,6 +82,11 @@ Status HybridDavinciModel::Init() { return impl_->Init(); } +Status HybridDavinciModel::Execute(const vector &inputs, vector &outputs) { + GE_CHECK_NOTNULL(impl_); + return impl_->Execute(inputs, outputs); +} + Status HybridDavinciModel::ModelRunStart() { GE_CHECK_NOTNULL(impl_); return impl_->ModelRunStart(); diff --git a/src/ge/hybrid/hybrid_davinci_model.h b/src/ge/hybrid/hybrid_davinci_model.h index 866b756b..c286a222 100644 --- a/src/ge/hybrid/hybrid_davinci_model.h +++ b/src/ge/hybrid/hybrid_davinci_model.h @@ -37,6 +37,8 @@ class HybridDavinciModel { Status Init(); + Status Execute(const vector &inputs, vector &outputs); + Status ModelRunStart(); Status ModelRunStop(); diff --git a/src/ge/hybrid/hybrid_davinci_model_stub.cc b/src/ge/hybrid/hybrid_davinci_model_stub.cc index bca118f8..7bde98a3 100644 --- a/src/ge/hybrid/hybrid_davinci_model_stub.cc +++ b/src/ge/hybrid/hybrid_davinci_model_stub.cc @@ -26,6 +26,8 @@ std::unique_ptr HybridDavinciModel::Create(const GeRootModel Status HybridDavinciModel::Init() { return UNSUPPORTED; } +Status HybridDavinciModel::Execute(const vector &inputs, vector &outputs) { return UNSUPPORTED; } + Status HybridDavinciModel::ModelRunStart() { return UNSUPPORTED; } Status HybridDavinciModel::ModelRunStop() { return UNSUPPORTED; } diff --git a/src/ge/hybrid/model/graph_item.cc b/src/ge/hybrid/model/graph_item.cc new file mode 100644 index 00000000..120865ce --- /dev/null +++ b/src/ge/hybrid/model/graph_item.cc @@ -0,0 +1,63 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "framework/common/util.h" +#include "graph_item.h" + +namespace ge { +namespace hybrid { +namespace { +constexpr int kInvalidIndex = -1; +} // namespace +GraphItem::~GraphItem() { GELOGD("[%s] GraphItem destroyed.", name_.c_str()); } + +const vector &hybrid::GraphItem::GetAllNodes() const { return node_items_; } + +const vector &GraphItem::GetInputNodes() const { return input_nodes_; } + +Status GraphItem::GetOutputDescList(vector &output_desc_list) const { + if (output_node_ == nullptr) { + return SUCCESS; + } + + if (is_dynamic_) { + for (auto &tensor_desc : output_node_->op_desc->GetAllInputsDescPtr()) { + output_desc_list.emplace_back(tensor_desc); + } + } else { + for (auto &tensor_desc : output_node_->op_desc->GetAllOutputsDescPtr()) { + output_desc_list.emplace_back(tensor_desc); + } + } + + return SUCCESS; +} + +bool GraphItem::IsDynamic() const { return is_dynamic_; } + +const vector &GraphItem::GetInputIndexMapping() const { return input_index_mapping_; } + +int GraphItem::GetParentOutputIndex(size_t index) const { + if (index >= output_index_mapping_.size()) { + return kInvalidIndex; + } + + return output_index_mapping_[index]; +} + +const NodeItem *GraphItem::GetOutputNode() const { return output_node_; } +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/model/graph_item.h b/src/ge/hybrid/model/graph_item.h new file mode 100644 index 00000000..cb0fbbed --- /dev/null +++ b/src/ge/hybrid/model/graph_item.h @@ -0,0 +1,64 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_MODEL_SUBGRAPH_ITEM_H_ +#define GE_HYBRID_MODEL_SUBGRAPH_ITEM_H_ + +#include "external/ge/ge_api_error_codes.h" +#include "hybrid/model/node_item.h" + +namespace ge { +namespace hybrid { +class GraphItem { + public: + GraphItem() = default; + ~GraphItem(); + const vector &GetAllNodes() const; + const vector &GetInputNodes() const; + Status GetOutputDescList(std::vector &output_desc_list) const; + + int TotalInputs() const { return total_inputs_; } + + int TotalOutputs() const { return total_outputs_; } + + const std::string &GetName() const { return name_; } + + void SetName(const string &name) { name_ = name; } + + const NodeItem *GetOutputNode() const; + + bool IsDynamic() const; + int GetParentOutputIndex(size_t index) const; + const vector &GetInputIndexMapping() const; + + private: + friend class HybridModelBuilder; + std::string name_; + std::vector node_items_; + std::vector input_nodes_; + const NodeItem *output_node_ = nullptr; + // + std::vector> output_edges_; + int total_inputs_ = 0; + int total_outputs_ = 0; + + bool is_dynamic_ = true; + std::vector input_index_mapping_; + std::vector output_index_mapping_; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_MODEL_SUBGRAPH_ITEM_H_ diff --git a/src/ge/hybrid/model/hybrid_model.cc b/src/ge/hybrid/model/hybrid_model.cc index e3726aec..0cb81aa3 100644 --- a/src/ge/hybrid/model/hybrid_model.cc +++ b/src/ge/hybrid/model/hybrid_model.cc @@ -29,6 +29,8 @@ namespace ge { namespace hybrid { HybridModel::HybridModel(GeRootModelPtr ge_model) : ge_root_model_(std::move(ge_model)) {} +HybridModel::~HybridModel() { GELOGD("[%s] HybridModel destroyed.", model_name_.c_str()); } + Status HybridModel::Init() { GELOGD("Start to init hybrid model."); GE_CHK_STATUS_RET(HybridModelBuilder(*this).Build(), "Failed to build hybrid model."); @@ -36,22 +38,6 @@ Status HybridModel::Init() { return SUCCESS; } -void HybridModel::Print() const { - for (const auto &node : node_items_) { - GELOGD("%s", node->DebugString().c_str()); - } -} - -TensorValue *HybridModel::GetWeight(const NodeItem *const_node) const { - auto it = weights_.find(const_node->node_id); - if (it == weights_.end() || it->second == nullptr) { - GELOGE(INTERNAL_ERROR, "[%s] Failed to get weight", const_node->NodeName().c_str()); - return nullptr; - } - - return it->second.get(); -} - TensorValue *HybridModel::GetVariable(const string &name) const { auto it = variable_tensors_.find(name); if (it == variable_tensors_.end()) { @@ -83,26 +69,26 @@ const std::vector *HybridModel::GetTaskDefs(const NodePtr &node) } NodeItem *HybridModel::MutableNodeItem(const NodePtr &node) { - auto node_id = node->GetOpDesc()->GetId(); - if (node_id < 0 || static_cast(node_id) > node_items_.size()) { - GELOGE(INTERNAL_ERROR, "index out of range. node_id = %ld, num_nodes = %zu", node_id, node_items_.size()); + auto it = node_items_.find(node); + if (it == node_items_.end()) { return nullptr; } - return node_items_[node_id].get(); + + return it->second.get(); } const NodeItem *HybridModel::GetNodeItem(const NodePtr &node) const { - auto node_id = node->GetOpDesc()->GetId(); - if (node_id < 0 || static_cast(node_id) > node_items_.size()) { - GELOGE(INTERNAL_ERROR, "Index out of range. node_id = %ld, num_nodes = %zu.", node_id, node_items_.size()); + auto it = node_items_.find(node); + if (it == node_items_.end()) { return nullptr; } - return node_items_[node_id].get(); + + return it->second.get(); } GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { - auto it = known_shape_sub_graphs_.find(node); - if (it == known_shape_sub_graphs_.end()) { + auto it = known_shape_sub_models_.find(node); + if (it == known_shape_sub_models_.end()) { GELOGE(INTERNAL_ERROR, "[%s] Failed to get GeModel for subgraph node.", node->GetName().c_str()); return nullptr; } @@ -110,8 +96,27 @@ GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { return it->second; } -const vector &HybridModel::GetNetOutputInputOffsets() const { return net_output_input_offsets_; } +const GraphItem *HybridModel::GetRootGraphItem() const { return root_graph_item_.get(); } + +const GraphItem *HybridModel::GetSubgraphItem(const std::string &graph_name) const { + GELOGD("To find subgraph item by name = %s", graph_name.c_str()); + auto it = subgraph_items_.find(graph_name); + if (it == subgraph_items_.end()) { + GELOGD("Subgraph item not found by node = %s", graph_name.c_str()); + return nullptr; + } + + return it->second.get(); +} + +const GraphItem *HybridModel::GetSubgraphItem(const ComputeGraphPtr &subgraph) const { + if (subgraph == nullptr) { + GELOGE(PARAM_INVALID, "subgraph is nullptr"); + return nullptr; + } -void HybridModel::SetDeviceId(uint32_t device_id) { device_id_ = device_id; } + auto subgraph_name = subgraph->GetName(); + return GetSubgraphItem(subgraph_name); +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/model/hybrid_model.h b/src/ge/hybrid/model/hybrid_model.h index 007f76c6..f554752e 100644 --- a/src/ge/hybrid/model/hybrid_model.h +++ b/src/ge/hybrid/model/hybrid_model.h @@ -26,39 +26,23 @@ #include "graph/node.h" #include "hybrid/common/tensor_value.h" #include "hybrid/model/node_item.h" +#include "hybrid/model/graph_item.h" #include "model/ge_root_model.h" namespace ge { namespace hybrid { -class HybridModelAsyncExecutor; class HybridModel { public: explicit HybridModel(GeRootModelPtr ge_model); - ~HybridModel() = default; + ~HybridModel(); Status Init(); - const std::vector &RootNodes() const { return root_nodes_; } - const NodeItem *GetNodeItem(const NodePtr &node) const; - size_t NumNodes() const { return node_items_.size(); } - uint64_t GetSessionId() const { return root_runtime_param_.session_id; } - int TotalInputs() const { return total_inputs_; } - - const map &GetInputNodes() const { return input_nodes_; } - - const std::map> &GetInputOffsets() const { return input_offsets_; } - - const vector &GetNetOutputInputOffsets() const; - - const std::vector &GetOutputOffsets() const { return output_offsets_; } - - const std::vector &GetConstNodes() const { return const_nodes_; } - GeModelPtr GetGeModel(const NodePtr &node) const; NodeItem *MutableNodeItem(const NodePtr &node); @@ -67,46 +51,40 @@ class HybridModel { const uint8_t *GetVarMemBase() const { return var_mem_base_; } - void SetDeviceId(uint32_t device_id); + void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } void SetModelId(uint32_t model_id) { model_id_ = model_id; } uint32_t GetModelId() const { return model_id_; } - TensorValue *GetWeight(const NodeItem *const_node) const; - TensorValue *GetVariable(const string &name) const; NodePtr GetVariableNode(const string &name) const; const std::vector *GetTaskDefs(const NodePtr &node) const; - int TotalOutputs() const { return total_outputs_; } + const GraphItem *GetRootGraphItem() const; - GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } - void Print() const; + const GraphItem *GetSubgraphItem(const std::string &graph_name) const; + + const GraphItem *GetSubgraphItem(const ComputeGraphPtr &subgraph) const; private: friend class HybridModelBuilder; friend class HybridModelAsyncExecutor; + std::string model_name_; GeRootModelPtr ge_root_model_; - std::vector root_nodes_; std::map input_nodes_; - std::map> input_offsets_; - std::vector output_offsets_; - std::vector net_output_input_offsets_; - NodeItem *net_output_node_ = nullptr; - std::vector> node_items_; - std::vector const_nodes_; std::map constant_op_nodes_; std::map variable_nodes_; std::map> variable_tensors_; - std::map> weights_; std::map> task_defs_; - std::map known_shape_sub_graphs_; - int total_inputs_ = 0; - int total_outputs_ = 0; + std::map known_shape_sub_models_; + + std::unique_ptr root_graph_item_; + std::map> subgraph_items_; + std::map> node_items_; // runtime fields uint32_t device_id_ = 0; diff --git a/src/ge/hybrid/model/hybrid_model_builder.cc b/src/ge/hybrid/model/hybrid_model_builder.cc index 190890b7..436beada 100644 --- a/src/ge/hybrid/model/hybrid_model_builder.cc +++ b/src/ge/hybrid/model/hybrid_model_builder.cc @@ -23,7 +23,6 @@ #include "graph/manager/trans_var_data_utils.h" #include "graph/utils/graph_utils.h" #include "graph/utils/type_utils.h" -#include "framework/common/debug/log.h" #include "hybrid/common/npu_memory_allocator.h" #include "hybrid/node_executor/node_executor.h" @@ -32,6 +31,7 @@ namespace hybrid { namespace { const uint32_t kSubgraphIndex = 0U; const uint32_t kVarOutputIndex = 0U; +const uint32_t kAlignment = 32; const int kBytes = 8; int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { @@ -46,6 +46,9 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { var_size *= shape.GetDim(dim_index); } + + // padding up to multiple of kAlignment, and add extra kAlignment + var_size = (var_size + kAlignment * 2 - 1) / kAlignment * kAlignment; return var_size; } } // namespace @@ -56,20 +59,19 @@ HybridModelBuilder::HybridModelBuilder(HybridModel &hybrid_model) Status HybridModelBuilder::Build() { GE_CHK_STATUS_RET(ValidateParams(), "Failed to validate GeRootModel"); - graph_name_ = ge_root_model_->GetRootGraph()->GetName(); + hybrid_model_.model_name_ = ge_root_model_->GetRootGraph()->GetName(); GELOGI("[%s] Start to build hybrid model.", GetGraphName()); GE_CHK_STATUS_RET(InitRuntimeParams(), "[%s] Failed to InitRuntimeParams", GetGraphName()); - GE_CHK_STATUS_RET(NodeExecutorManager::GetInstance().EnsureInitialized(), "Failed to initialize executors"); GE_CHK_STATUS_RET(IndexSpecialNodes(), "[%s] Failed to index nodes", GetGraphName()); GE_CHK_STATUS_RET(IndexTaskDefs(), "[%s] Failed to index task defs", GetGraphName()); GE_CHK_STATUS_RET(LoadGraph(), "[%s] Failed to load graph", GetGraphName()); + GE_CHK_STATUS_RET(AssignUninitializedConstantOps(), "[%s] Failed to assign uninitialized constants", GetGraphName()); GE_CHK_STATUS_RET(TransAllVarData(), "[%s] Failed to trans all var data", GetGraphName()); GE_CHK_STATUS_RET(CopyVarData(), "[%s] Failed to copy var data", GetGraphName()); GE_CHK_STATUS_RET(InitModelMem(), "[%s] Failed to init memory", GetGraphName()); GE_CHK_STATUS_RET(InitWeights(), "[%s] Failed to init weights", GetGraphName()); GE_CHK_STATUS_RET(InitConstantOps(), "[%s] Failed to init constant op", GetGraphName()); GE_CHK_STATUS_RET(InitVariableTensors(), "[%s] Failed to init variables", GetGraphName()); - GE_CHK_STATUS_RET(ResolveRootNodes(), "[%s] Failed to resolve root nodes", GetGraphName()); GE_CHK_STATUS_RET(LoadTasks(), "[%s] Failed to load tasks", GetGraphName()); GELOGI("[%s] Done building hybrid model successfully.", GetGraphName()); return SUCCESS; @@ -81,45 +83,17 @@ Status HybridModelBuilder::ValidateParams() { return SUCCESS; } -Status HybridModelBuilder::ResolveRootNodes() { - for (auto &node : hybrid_model_.node_items_) { - if (node->node->GetInDataNodes().empty()) { - hybrid_model_.root_nodes_.emplace_back(node.get()); - GELOGI("[%s] Root node added. node name = %s", GetGraphName(), node->NodeName().c_str()); - } - } - - if (hybrid_model_.root_nodes_.empty()) { - GELOGE(PARAM_INVALID, "[%s] Root nodes is empty.", GetGraphName()); - return PARAM_INVALID; - } - - return SUCCESS; -} - -Status HybridModelBuilder::BuildNoteItem(const NodePtr &node, NodeItem &node_item) { - GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, node_item.is_dynamic), - "[%s] Failed to get shape status.", node->GetName().c_str()); - +Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); vector dependencies = node->GetOpDesc()->GetOpInferDepends(); GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), "[%s] Failed to parse node dependencies.", node_item.NodeName().c_str()); - auto it = node_ref_inputs_.find(node); - if (it != node_ref_inputs_.end()) { - for (auto &idx_and_node : it->second) { - // var and constant only have one output - node_item.const_input_shapes[idx_and_node.first] = - idx_and_node.second->GetOpDesc()->MutableOutputDesc(kVarOutputIndex); - } - } - node_item.outputs.resize(node_item.num_outputs); for (int i = 0; i < node_item.num_outputs; ++i) { auto out_data_anchor = node->GetOutDataAnchor(i); if (out_data_anchor == nullptr) { - GELOGE(INTERNAL_ERROR, "out anchor[%zu] of node %s is nullptr", i, node->GetName().c_str()); + GELOGE(INTERNAL_ERROR, "out anchor[%d] of node %s is nullptr", i, node->GetName().c_str()); return INTERNAL_ERROR; } @@ -137,27 +111,46 @@ Status HybridModelBuilder::BuildNoteItem(const NodePtr &node, NodeItem &node_ite } } + GE_CHK_STATUS_RET_NOLOG(ResolveRefIo(node_item)); return SUCCESS; } -Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item) { - auto &node_items = hybrid_model_.node_items_; - auto node_id = node->GetOpDesc()->GetId(); - if (node_id < 0 || static_cast(node_id) > node_items.size()) { - GELOGE(INTERNAL_ERROR, "[%s] Index out of range. node_id = %ld, num_nodes = %zu", node->GetName().c_str(), node_id, - node_items.size()); - return INTERNAL_ERROR; +Status HybridModelBuilder::ResolveRefIo(NodeItem &node_item) { + bool is_ref = false; + auto &op_desc = *node_item.op_desc; + (void)AttrUtils::GetBool(op_desc, ATTR_NAME_REFERENCE, is_ref); + if (!is_ref) { + return SUCCESS; } - auto &node_ptr = node_items[node_id]; - if (node_ptr != nullptr) { - *node_item = node_ptr.get(); + auto inputs = op_desc.GetAllInputName(); + auto outputs = op_desc.GetAllOutputName(); + for (auto &output : outputs) { + for (auto &input : inputs) { + if (input.first == output.first) { + auto input_idx = static_cast(input.second); + auto output_idx = static_cast(output.second); + node_item.reuse_inputs[output_idx] = input_idx; + GELOGD("[%s] Output[%d] reuse input[%d]", node_item.NodeName().c_str(), output_idx, input_idx); + } + } + } + + return SUCCESS; +} + +Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item) { + auto &node_items = hybrid_model_.node_items_; + auto it = node_items.find(node); + if (it != node_items.end()) { + *node_item = it->second.get(); return SUCCESS; } auto new_node = std::unique_ptr(new (std::nothrow) NodeItem(node)); GE_CHECK_NOTNULL(new_node); GE_CHECK_NOTNULL(new_node->op_desc); + GE_CHK_STATUS_RET(new_node->Init(), "Failed to init NodeItem [%s] .", node->GetName().c_str()); GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); // we do not need L2 Buffer @@ -166,21 +159,54 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n (void)AttrUtils::SetBool(new_node->op_desc, kIsFirstNode, false); (void)AttrUtils::SetBool(new_node->op_desc, kIsLastNode, false); - int32_t unknown_shape_type_val = 0; - (void)AttrUtils::GetInt(new_node->op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); - new_node->shape_inference_type = static_cast(unknown_shape_type_val); - if (new_node->shape_inference_type == DEPEND_SHAPE_RANGE || new_node->shape_inference_type == DEPEND_COMPUTE) { - new_node->has_observer = true; + if (new_node->is_dynamic && (new_node->IsControlOp() || new_node->NodeType() == PARTITIONEDCALL)) { + new_node->shape_inference_type = DEPEND_COMPUTE; } + new_node->node_id = node_index; + new_node->op_desc->SetId(node_index); + node_index += 1; + *node_item = new_node.get(); - node_items[node_id] = std::move(new_node); + node_items[node] = std::move(new_node); return SUCCESS; } Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { std::set dependent_input_nodes; auto &ge_node = node_item.node; + + // The input tensors become valid after computation is done for parent nodes of type DEPEND_COMPUTE. + // Wait for these parent nodes before execution. + for (const auto &in_anchor : ge_node->GetAllInDataAnchors()) { + const auto &peer_anchor = in_anchor->GetPeerOutAnchor(); + if (peer_anchor == nullptr) { + GELOGD("[%s] Input[%d] do not have peer anchor", node_item.NodeName().c_str(), in_anchor->GetIdx()); + continue; + } + + auto src_node = peer_anchor->GetOwnerNode(); + GE_CHECK_NOTNULL(src_node); + + auto src_node_item = MutableNodeItem(src_node); + GE_CHECK_NOTNULL(src_node_item); + + if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { + GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", + node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); + + src_node_item->has_observer = true; + node_item.dependents_for_execution.emplace_back(src_node); + } + + if (src_node_item->shape_inference_type == DEPEND_SHAPE_RANGE) { + GELOGD("[%s] Add input shape dependent node [%s] due to inference type = DEPEND_SHAPE_RANGE", + node_item.NodeName().c_str(), src_node_item->NodeName().c_str()); + src_node_item->has_observer = true; + dependent_input_nodes.emplace(src_node); + } + } + for (const auto &input_name : dependencies) { int input_index = node_item.op_desc->GetInputIndexByName(input_name); if (input_index < 0) { @@ -205,7 +231,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s } for (const auto &dep_node : dependent_input_nodes) { - node_item.dependent_node_list.emplace_back(dep_node); + node_item.dependents_for_shape_inference.emplace_back(dep_node); } return SUCCESS; @@ -262,9 +288,14 @@ Status HybridModelBuilder::DoLinkDataAnchors(OutDataAnchorPtr &out_data_anchor, Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { const auto &wrapped_node = graph.GetParentNode(); + std::set root_nodes; for (const auto &node : graph.GetDirectNode()) { GE_CHECK_NOTNULL(node); if (node->GetType() != DATA_TYPE) { + if (node->GetInDataNodes().empty()) { + root_nodes.emplace(node); + } + continue; } @@ -291,12 +322,28 @@ Status HybridModelBuilder::MergeInputNodes(ComputeGraph &graph) { for (auto &out_data_anchor : node->GetAllOutDataAnchors()) { GE_CHECK_NOTNULL(out_data_anchor); for (auto &peer_in_data_anchor : out_data_anchor->GetPeerInDataAnchors()) { + auto dst_node = peer_in_data_anchor->GetOwnerNode(); + root_nodes.emplace(dst_node); GE_CHK_STATUS_RET_NOLOG(DoUnlinkDataAnchors(out_data_anchor, peer_in_data_anchor)); GE_CHK_STATUS_RET_NOLOG(DoLinkDataAnchors(src_out_anchor, peer_in_data_anchor)); } } } + // transfer in control edges to all root nodes + for (auto &root_node : root_nodes) { + auto in_nodes = root_node->GetInAllNodes(); + std::set in_node_set(in_nodes.begin(), in_nodes.end()); + for (auto &in_control_node : wrapped_node->GetInControlNodes()) { + if (in_node_set.count(in_control_node) == 0) { + GELOGD("[%s] Restore control edge to [%s]", in_control_node->GetName().c_str(), root_node->GetName().c_str()); + GE_CHECK_NOTNULL(in_control_node->GetOutControlAnchor()); + (void)in_control_node->GetOutControlAnchor()->LinkTo(root_node->GetInControlAnchor()); + } + } + } + + wrapped_node->GetInControlAnchor()->UnlinkAll(); return SUCCESS; } @@ -307,6 +354,11 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { const auto &net_output_desc = net_output_node->GetOpDesc(); GE_CHECK_NOTNULL(net_output_desc); + auto all_in_nodes = net_output_node->GetInAllNodes(); + auto all_out_nodes = parent_node->GetOutAllNodes(); + net_output_node->GetInControlAnchor()->UnlinkAll(); + parent_node->GetOutControlAnchor()->UnlinkAll(); + for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { auto src_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(src_out_anchor); @@ -338,10 +390,25 @@ Status HybridModelBuilder::MergeNetOutputNode(ComputeGraph &graph) { } } + // transfer out control edges + std::set in_node_set(all_in_nodes.begin(), all_in_nodes.end()); + std::set out_node_set(all_out_nodes.begin(), all_out_nodes.end()); + for (auto &src_node : in_node_set) { + GELOGD("[%s] process in node.", src_node->GetName().c_str()); + auto out_nodes = src_node->GetOutAllNodes(); + std::set node_set(out_nodes.begin(), out_nodes.end()); + for (auto &dst_node : out_node_set) { + if (node_set.count(dst_node) == 0) { + src_node->GetOutControlAnchor()->LinkTo(dst_node->GetInControlAnchor()); + GELOGD("[%s] Restore control edge to [%s]", src_node->GetName().c_str(), dst_node->GetName().c_str()); + } + } + } + return SUCCESS; } -Status HybridModelBuilder::MergeSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { +Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph) { merged_graph = MakeShared("MergedGraph"); for (const auto &node : root_graph.GetDirectNode()) { GE_CHECK_NOTNULL(node); @@ -371,32 +438,74 @@ Status HybridModelBuilder::MergeSubgraphs(ComputeGraph &root_graph, ComputeGraph } auto subgraph = NodeUtils::GetSubgraph(*node, kSubgraphIndex); - GE_CHK_STATUS_RET(MergeInputNodes(*subgraph), "Failed to merge data nodes for subgraph: %s", - subgraph->GetName().c_str()); - GE_CHK_STATUS_RET(MergeNetOutputNode(*subgraph), "Failed to merge net output nodes for subgraph: %s", - subgraph->GetName().c_str()); - GELOGD("Merging subgraph %s successfully.", subgraph->GetName().c_str()); - for (auto &sub_node : subgraph->GetAllNodes()) { - auto sub_op_type = sub_node->GetType(); - if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { - continue; - } + GE_CHECK_NOTNULL(subgraph); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), "[%s] Failed to merge subgraph.", + subgraph->GetName().c_str()); + } - if (sub_op_type == CONSTANT || sub_op_type == CONSTANTOP || sub_op_type == VARIABLE) { - GELOGE(INTERNAL_ERROR, "Unexpected node in unknown subgraph. type = %s, node = %s::%s", sub_op_type.c_str(), - subgraph->GetName().c_str(), sub_node->GetName().c_str()); - return INTERNAL_ERROR; - } + // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. + GE_CHK_GRAPH_STATUS_RET(merged_graph->TopologicalSorting(), "Failed to invoke TopologicalSorting on merged graph."); + + for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { + GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); + GE_CHK_GRAPH_STATUS_RET(merged_graph->AddSubgraph(remained_subgraph), "Failed to add subgraph [%s]", + remained_subgraph->GetName().c_str()); + } + + return SUCCESS; +} - merged_graph->AddNode(sub_node); - GELOGD("%s::%s added to merged graph.", subgraph->GetName().c_str(), sub_node->GetName().c_str()); +Status HybridModelBuilder::UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, + ComputeGraph &sub_graph) { + auto parent_node = sub_graph.GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + + GE_CHK_STATUS_RET(MergeInputNodes(sub_graph), "[%s] Failed to merge data nodes for subgraph", + sub_graph.GetName().c_str()); + GE_CHK_STATUS_RET(MergeNetOutputNode(sub_graph), "[%s] Failed to merge net output nodes for subgraph", + sub_graph.GetName().c_str()); + GELOGD("[%s] Done merging subgraph inputs and outputs successfully.", sub_graph.GetName().c_str()); + + for (auto &sub_node : sub_graph.GetDirectNode()) { + auto sub_op_type = sub_node->GetType(); + if (sub_op_type == DATA_TYPE || sub_op_type == NETOUTPUT) { + continue; + } + + if (sub_op_type == CONSTANT || sub_op_type == VARIABLE) { + GELOGE(INTERNAL_ERROR, "Unexpected node in unknown subgraph. type = %s, node = %s::%s", sub_op_type.c_str(), + sub_graph.GetName().c_str(), sub_node->GetName().c_str()); + return INTERNAL_ERROR; } + + if (sub_op_type == PARTITIONEDCALL) { + bool is_unknown_shape = false; + GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*sub_node, is_unknown_shape), + "[%s] Failed to invoke GetNodeUnknownShapeStatus.", sub_node->GetName().c_str()); + if (is_unknown_shape) { + auto sub_sub_graph = NodeUtils::GetSubgraph(*sub_node, kSubgraphIndex); + GE_CHECK_NOTNULL(sub_sub_graph); + GE_CHK_STATUS_RET(UnfoldSubgraph(root_graph, parent_graph, *sub_sub_graph), "[%s] Failed to merge subgraph", + sub_sub_graph->GetName().c_str()); + continue; + } + } + + parent_graph.AddNode(sub_node); + GELOGD("[%s::%s] added to parent graph: [%s].", sub_graph.GetName().c_str(), sub_node->GetName().c_str(), + parent_graph.GetName().c_str()); } + GELOGD("[%s] Done merging subgraph. remove it from root graph.", sub_graph.GetName().c_str()); + root_graph.RemoveSubgraph(sub_graph.GetName()); return SUCCESS; } -Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { +Status HybridModelBuilder::BuildOutputMapping(GraphItem &graph_item, const NodeItem &node_item, bool is_root_graph) { + auto output_size = node_item.op_desc->GetAllInputsSize(); + GE_CHECK_LE(output_size, UINT32_MAX); + graph_item.output_edges_.resize(output_size); + for (auto &in_data_anchor : node_item.node->GetAllInDataAnchors()) { auto peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); GE_CHECK_NOTNULL(peer_out_anchor); @@ -408,11 +517,20 @@ Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { auto output_offset = src_node_item->output_start + peer_out_anchor->GetIdx(); GELOGI("Output[%d], node = %s, output_index = %d, output_offset = %d ", in_data_anchor->GetIdx(), src_node_item->NodeName().c_str(), peer_out_anchor->GetIdx(), output_offset); - hybrid_model_.output_offsets_.emplace_back(output_offset); + + graph_item.output_edges_[in_data_anchor->GetIdx()] = {src_node_item, peer_out_anchor->GetIdx()}; } - for (int i = 0; i < node_item.num_inputs; ++i) { - hybrid_model_.net_output_input_offsets_.emplace_back(node_item.input_start + i); + if (!is_root_graph) { + for (uint32_t i = 0; i < static_cast(output_size); ++i) { + uint32_t p_index = i; + // Net output of Subgraph of while do not have parent index + if (AttrUtils::GetInt(node_item.op_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { + GELOGD("[%s] Parent index not set for input[%u].", node_item.NodeName().c_str(), i); + } + + graph_item.output_index_mapping_.emplace_back(p_index); + } } return SUCCESS; @@ -420,82 +538,46 @@ Status HybridModelBuilder::ParseNetOutput(const NodeItem &node_item) { Status HybridModelBuilder::LoadGraph() { auto root_graph = ge_root_model_->GetRootGraph(); - GELOGI("Before merge subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), + std::shared_ptr merged_graph; + GELOGI("Before merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), root_graph->GetAllNodesSize()); - ComputeGraphPtr merged_graph; - GE_CHK_STATUS_RET_NOLOG(MergeSubgraphs(*root_graph, merged_graph)); - GELOGI("After merge subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", merged_graph->GetDirectNodesSize(), - merged_graph->GetAllNodesSize()); - - merged_graph->SetGraphID(runtime_param_.graph_id); - GE_DUMP(merged_graph, "hybrid_merged_graph"); - int input_start = 0; - int output_start = 0; - uint32_t data_op_index = 0; - hybrid_model_.node_items_.resize(merged_graph->GetDirectNodesSize()); - - int64_t node_index = 0; - for (auto &node : merged_graph->GetDirectNode()) { - OpDescPtr op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - op_desc->SetId(node_index++); - } - - for (const auto &node : merged_graph->GetDirectNode()) { - GE_CHECK_NOTNULL(node); - GE_CHECK_NOTNULL(node->GetOpDesc()); - const auto &op_type = node->GetType(); - - NodeItem *node_item = nullptr; - GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); - GE_CHK_STATUS_RET_NOLOG(BuildNoteItem(node, *node_item)); - GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task - - node_item->input_start = input_start; - node_item->output_start = output_start; - input_start += node_item->num_inputs; - output_start += node_item->num_outputs; - - if (op_type == DATA_TYPE || op_type == AIPP_DATA_TYPE) { - auto data_index = data_op_index; - if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, data_index)) { - GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); - } - hybrid_model_.input_nodes_.emplace(data_index, node_item); - data_op_index++; - } else if (op_type == NETOUTPUT) { - hybrid_model_.net_output_node_ = node_item; - GE_CHK_STATUS_RET_NOLOG(ParseNetOutput(*node_item)); - } else if (op_type == PARTITIONEDCALL) { // known graph - GE_CHK_STATUS_RET_NOLOG(ParsePartitionedCall(*node_item)); + GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraphs(*root_graph, merged_graph), "Failed to unfold subgraphs."); + root_graph = std::move(merged_graph); + GELOGI("After merging subgraphs DirectNodesSize = %zu, GetAllNodesSize = %zu", root_graph->GetDirectNodesSize(), + root_graph->GetAllNodesSize()); + GE_DUMP(root_graph, "hybrid_merged_graph"); + + GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); + GELOGD("Done loading root graph successfully."); + + for (auto &sub_graph : root_graph->GetAllSubgraphs()) { + GE_CHECK_NOTNULL(sub_graph); + GELOGD("Start to load subgraph [%s]", sub_graph->GetName().c_str()); + auto parent_node = sub_graph->GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + auto parent_node_item = MutableNodeItem(parent_node); + // parent node is in another known subgraph + if (parent_node_item == nullptr) { + GELOGD("[%s] Subgraph is in another known shaped subgraph, skip it.", sub_graph->GetName().c_str()); + continue; } - GELOGI("NodeItem created: %s", node_item->DebugString().c_str()); - } - - for (auto &it : hybrid_model_.input_nodes_) { - auto input_index = it.first; - auto input_node = it.second; - - if (input_node->outputs.empty()) { - GELOGE(INTERNAL_ERROR, "data output anchor is empty"); - return INTERNAL_ERROR; - } + if (sub_graph->GetGraphUnknownFlag()) { + GE_CHK_STATUS_RET(LoadDynamicSubgraph(*sub_graph, false), "Failed to load subgraph: [%s]", + sub_graph->GetName().c_str()); + } else { + GE_CHK_STATUS_RET(IdentifyVariableOutputs(*parent_node_item), "[%s] Failed to identify ref outputs.", + parent_node_item->NodeName().c_str()); - for (auto &out : input_node->outputs) { - std::vector offsets; - for (auto &dst_anchor_and_node : out) { - auto dst_node_item = dst_anchor_and_node.second; - offsets.emplace_back(dst_node_item->input_start + dst_anchor_and_node.first); + // if parent is function control op. need add a virtual partitioned call + if (parent_node_item->IsControlOp()) { + GE_CHK_STATUS_RET(LoadKnownShapedSubgraph(*sub_graph, parent_node_item), + "Failed to load function control op subgraph [%s]", sub_graph->GetName().c_str()); } - - hybrid_model_.input_offsets_.emplace(input_index, std::move(offsets)); } } - hybrid_model_.total_inputs_ = input_start; - hybrid_model_.total_outputs_ = output_start; - GELOGI("HybridGraph::LoadGraph OUT"); + GELOGI("Done loading all subgraphs successfully."); return SUCCESS; } @@ -507,7 +589,6 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_ string var_name = var_node->GetName(); auto tensor_desc = var_node->GetOpDesc()->MutableOutputDesc(0); uint8_t *var_logic = nullptr; - GE_CHK_STATUS_RET(var_manager_->GetVarAddr(var_name, *tensor_desc, &var_logic), "Failed to get var addr. var_name = %s, session_id = %ld", var_name.c_str(), hybrid_model_.GetSessionId()); @@ -559,10 +640,26 @@ Status HybridModelBuilder::HandleDtString(const GeTensor &tensor, void *var_addr return SUCCESS; } +Status HybridModelBuilder::AssignUninitializedConstantOps() { + for (auto &it : hybrid_model_.constant_op_nodes_) { + const string &var_name = it.first; + const NodePtr &var_node = it.second; + auto tensor_desc = var_node->GetOpDesc()->MutableOutputDesc(0); + if (!var_manager_->IsVarExist(var_name, *tensor_desc)) { + // allocate constant + GELOGD("[%s] Constant not allocated during graph building. now allocate it.", var_name.c_str()); + GE_CHK_STATUS_RET(var_manager_->AssignVarMem(var_name, *tensor_desc, RT_MEMORY_HBM)); + GE_CHK_STATUS_RET(var_manager_->SetAllocatedGraphId(var_name, runtime_param_.graph_id)); + } + } + + return SUCCESS; +} + Status HybridModelBuilder::InitConstantOps() { for (auto &it : hybrid_model_.constant_op_nodes_) { - string var_name = it.first; - NodePtr &var_node = it.second; + const string &var_name = it.first; + const NodePtr &var_node = it.second; std::unique_ptr var_tensor; GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); @@ -578,7 +675,7 @@ Status HybridModelBuilder::InitConstantOps() { if (ge_tensor->GetData().size() > 0) { GE_CHK_STATUS_RET_NOLOG(HandleDtString(*ge_tensor, v_output_addr)); - GELOGI("[IMAS]InitConstant memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%u] datasize[%zu]", + GELOGI("[IMAS]InitConstant memcpy graph_%u type[V] name[%s] output[%d] memaddr[%p] mem_size[%zu] datasize[%zu]", runtime_param_.graph_id, op_desc->GetName().c_str(), 0, v_output_addr, v_output_size, ge_tensor->GetData().size()); GE_CHK_RT_RET(rtMemcpy(v_output_addr, v_output_size, ge_tensor->GetData().data(), ge_tensor->GetData().size(), @@ -614,7 +711,8 @@ Status HybridModelBuilder::InitWeights() { } Status HybridModelBuilder::LoadTasks() { - for (auto &node_item : hybrid_model_.node_items_) { + for (auto &it : hybrid_model_.node_items_) { + auto &node_item = it.second; auto &node_ptr = node_item->node; if (node_item->node_type == NETOUTPUT) { continue; @@ -622,7 +720,6 @@ Status HybridModelBuilder::LoadTasks() { GELOGD("[%s] Start to build kernel task", node_ptr->GetName().c_str()); auto load_ret = node_item->node_executor->LoadTask(hybrid_model_, node_ptr, node_item->kernel_task); - if (load_ret != UNSUPPORTED && load_ret != SUCCESS) { GELOGE(load_ret, "[%s] Failed to load task", node_ptr->GetName().c_str()); return load_ret; @@ -634,6 +731,23 @@ Status HybridModelBuilder::LoadTasks() { return SUCCESS; } +Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr &ge_model) { + auto parent_node = sub_graph.GetParentNode(); + GE_CHECK_NOTNULL(parent_node); + auto op_type = parent_node->GetType(); + if (op_type == IF || op_type == CASE || op_type == WHILE) { + GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), + ge_model->GetModelTaskDefPtr()->task_size()); + subgraph_models_.emplace(sub_graph.GetName(), ge_model); + } else { + GELOGD("Set ge_model for subgraph: [%s], task_size = %d", sub_graph.GetName().c_str(), + ge_model->GetModelTaskDefPtr()->task_size()); + hybrid_model_.known_shape_sub_models_.emplace(sub_graph.GetParentNode(), ge_model); + } + + return SUCCESS; +} + Status HybridModelBuilder::IndexTaskDefs() { const auto &root_graph = ge_root_model_->GetRootGraph(); for (auto &it : ge_root_model_->GetSubgraphInstanceNameToModel()) { @@ -646,12 +760,9 @@ Status HybridModelBuilder::IndexTaskDefs() { continue; } - bool is_unknown_shape = false; - GE_CHK_GRAPH_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*sub_graph->GetParentNode(), is_unknown_shape), - "Failed to invoke GetNodeUnknownShapeStatus."); + bool is_unknown_shape = sub_graph->GetGraphUnknownFlag(); if (!is_unknown_shape) { - GELOGD("Set ge_model for subgraph: %s", sub_graph->GetName().c_str()); - hybrid_model_.known_shape_sub_graphs_.emplace(sub_graph->GetParentNode(), ge_model); + GE_CHK_STATUS_RET_NOLOG(LoadGeModel(*sub_graph, ge_model)); continue; } @@ -676,6 +787,8 @@ Status HybridModelBuilder::IndexTaskDefs() { op_index = task_def.kernel().context().op_index(); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { op_index = task_def.kernel_ex().op_index(); + } else if (task_type == RT_MODEL_TASK_HCCL) { + op_index = task_def.kernel_hccl().op_index(); } else { GELOGD("Skip task type: %d", static_cast(task_type)); continue; @@ -790,12 +903,12 @@ Status HybridModelBuilder::GetPeerNodeAcrossSubGraphs(const NodePtr &data_node, for (uint32_t i = 0; i < static_cast(input_size); ++i) { uint32_t p_index = 0; if (!AttrUtils::GetInt(net_output_desc->GetInputDesc(i), ATTR_NAME_PARENT_NODE_INDEX, p_index)) { - GELOGW("SubGraph: %s input tensor %zu attr %s not found.", src_graph->GetName().c_str(), i, + GELOGW("SubGraph: %s input tensor %u attr %s not found.", src_graph->GetName().c_str(), i, ATTR_NAME_PARENT_NODE_INDEX.c_str()); continue; } - GELOGD("NetOutput's input[%zu], parent_node_index = %u", i, p_index); + GELOGD("NetOutput's input[%u], parent_node_index = %u", i, p_index); if (p_index == out_index) { auto in_anchor = src_net_output_node->GetInDataAnchor(i); GE_CHECK_NOTNULL(in_anchor); @@ -830,7 +943,7 @@ Status HybridModelBuilder::InitRuntimeParams() { ret = ge::AttrUtils::GetInt(first_model, ATTR_MODEL_VAR_SIZE, value); runtime_param_.var_size = ret ? (uint64_t)value : 0; runtime_param_.graph_id = ge_root_model_->GetRootGraph()->GetGraphID(); - GELOGI("InitRuntimeParams(), session_id:%u, var_size:%lu. graph_id = %u", runtime_param_.session_id, + GELOGI("InitRuntimeParams(), session_id:%lu, var_size:%lu. graph_id = %u", runtime_param_.session_id, runtime_param_.var_size, runtime_param_.graph_id); var_manager_ = VarManager::Instance(runtime_param_.session_id); @@ -838,15 +951,19 @@ Status HybridModelBuilder::InitRuntimeParams() { return SUCCESS; } -Status HybridModelBuilder::ParsePartitionedCall(NodeItem &node_item) { +Status HybridModelBuilder::IdentifyVariableOutputs(NodeItem &node_item) { GELOGD("Start to parse outputs of node: %s", node_item.NodeName().c_str()); auto subgraph = NodeUtils::GetSubgraph(*node_item.node, kSubgraphIndex); GE_CHECK_NOTNULL(subgraph); auto net_output_node = subgraph->FindFirstNodeMatchType(NETOUTPUT); - GE_CHECK_NOTNULL(net_output_node); + if (net_output_node == nullptr) { + GELOGD("[%s] Subgraph do not got net output", subgraph->GetName().c_str()); + return SUCCESS; + } auto net_output_desc = net_output_node->GetOpDesc(); GE_CHECK_NOTNULL(net_output_desc); + // constant/variable connected to net output for (const auto &in_data_anchor : net_output_node->GetAllInDataAnchors()) { auto src_node = GetPeerNode(in_data_anchor); GE_CHECK_NOTNULL(src_node); @@ -864,6 +981,8 @@ Status HybridModelBuilder::ParsePartitionedCall(NodeItem &node_item) { node_item.ref_outputs.emplace(parent_index, src_node); } + // Data nodes marked with REF_VAR_SRC_VAR_NAME + // Using variable tensor as data's output for (auto &node : subgraph->GetDirectNode()) { if (node->GetType() != DATA) { continue; @@ -912,6 +1031,11 @@ Status HybridModelBuilder::GetParentNodeOutputIndex(const OpDesc &op_desc, int i Status HybridModelBuilder::InitModelMem() { hybrid_model_.var_mem_base_ = var_manager_->GetVarMemoryBase(RT_MEMORY_HBM); auto total_var_size = hybrid_model_.TotalVarMemSize(); + if (total_var_size == 0 && !hybrid_model_.constant_op_nodes_.empty()) { + total_var_size = var_manager_->GetVarMemSize(RT_MEMORY_HBM) > 0 ? var_manager_->GetVarMemMaxSize() : 0; + GELOGD("Model var size = 0. but got uninitialized constant. set var size to %zu.", total_var_size); + } + if (total_var_size > 0 && hybrid_model_.var_mem_base_ == nullptr) { GE_CHK_STATUS_RET(var_manager_->MallocVarMemory(total_var_size), "Malloc Var Memory Fail."); hybrid_model_.var_mem_base_ = var_manager_->GetVarMemoryBase(RT_MEMORY_HBM); @@ -951,5 +1075,154 @@ Status HybridModelBuilder::CopyVarData() { GELOGI("CopyVarData success."); return SUCCESS; } + +Status HybridModelBuilder::LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item) { + GELOGD("Start to load known shaped subgraph [%s]", graph.GetName().c_str()); + auto graph_item = std::unique_ptr(new (std::nothrow) GraphItem()); + GE_CHECK_NOTNULL(graph_item); + graph_item->is_dynamic_ = false; + auto subgraph_name = graph.GetName(); + auto wrapper_op_desc = MakeShared(subgraph_name + "_partitioned_call", PARTITIONEDCALL); + GE_CHECK_NOTNULL(wrapper_op_desc); + + for (auto &node : graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + const auto &op_type = node->GetType(); + + if (op_type == DATA) { + int32_t data_index = 0; + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, data_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + + (void)wrapper_op_desc->AddInputDesc(op_desc->GetInputDesc(0)); + graph_item->input_index_mapping_.emplace_back(data_index); + } else if (op_type == NETOUTPUT) { + int output_index = 0; + for (const auto &output_desc : op_desc->GetAllInputsDescPtr()) { + int32_t data_index = output_index++; + if (!AttrUtils::GetInt(output_desc, ATTR_NAME_PARENT_NODE_INDEX, data_index)) { + GELOGI("[%s] Failed to get attr [%s]", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + } + + GE_CHK_GRAPH_STATUS_RET(wrapper_op_desc->AddOutputDesc(*output_desc), + "[%s] Failed to add output desc. output index = %d", graph.GetName().c_str(), + output_index); + + graph_item->output_index_mapping_.emplace_back(data_index); + } + } + } + + auto temp_graph = MakeShared("temp"); + GE_CHECK_NOTNULL(temp_graph); + auto wrapper_node = temp_graph->AddNode(wrapper_op_desc); + GeModelPtr ge_model = subgraph_models_[subgraph_name]; + GE_CHECK_NOTNULL(ge_model); + hybrid_model_.known_shape_sub_models_.emplace(wrapper_node, ge_model); + + NodeItem *node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(wrapper_node, &node_item)); + node_item->input_start = 0; + node_item->output_start = 0; + node_item->outputs.resize(node_item->num_outputs); + graph_item->node_items_.emplace_back(node_item); + graph_item->output_node_ = node_item; + graph_item->total_inputs_ = node_item->num_inputs; + graph_item->total_outputs_ = node_item->num_outputs; + + GELOGD("NodeItem create for known shape subgraph [%s], NodeItem = %s", graph.GetName().c_str(), + node_item->DebugString().c_str()); + + GELOGD("Done parse known shape subgraph successfully. graph = [%s]", graph.GetName().c_str()); + graph_item->SetName(graph.GetName()); + GELOGD("Done loading known shape subgraph: [%s]", graph_item->GetName().c_str()); + hybrid_model_.subgraph_items_.emplace(graph.GetName(), std::move(graph_item)); + return SUCCESS; +} + +Status HybridModelBuilder::LoadDynamicSubgraph(ComputeGraph &graph, bool is_root_graph) { + GELOGD("Start to load subgraph [%s]", graph.GetName().c_str()); + // for known partitioned call, load all nodes + auto graph_item = std::unique_ptr(new (std::nothrow) GraphItem()); + GE_CHECK_NOTNULL(graph_item); + + graph_item->is_dynamic_ = true; + graph_item->node_items_.reserve(graph.GetDirectNodesSize()); + int input_start = 0; + int output_start = 0; + std::vector data_nodes; + for (auto &node : graph.GetDirectNode()) { + GE_CHECK_NOTNULL(node); + GE_CHECK_NOTNULL(node->GetOpDesc()); + const auto &op_type = node->GetType(); + + NodeItem *node_item = nullptr; + GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(node, &node_item)); + GE_CHK_STATUS_RET_NOLOG(BuildNodeItem(node, *node_item)); + GE_CHK_STATUS_RET_NOLOG(UpdateAnchorStatus(node)); // needed by FE generate task + + node_item->input_start = input_start; + node_item->output_start = output_start; + input_start += node_item->num_inputs; + output_start += node_item->num_outputs; + + if (op_type == DATA_TYPE || op_type == AIPP_DATA_TYPE) { + data_nodes.emplace_back(node_item); + } else if (op_type == NETOUTPUT) { + graph_item->output_node_ = node_item; + GE_CHK_STATUS_RET_NOLOG(BuildOutputMapping(*graph_item, *node_item, is_root_graph)); + } + + graph_item->node_items_.emplace_back(node_item); + GELOGD("NodeItem created: %s", node_item->DebugString().c_str()); + } + + graph_item->total_inputs_ = input_start; + graph_item->total_outputs_ = output_start; + GE_CHK_STATUS_RET_NOLOG(BuildInputMapping(*graph_item, data_nodes, is_root_graph)); + if (is_root_graph) { + graph_item->SetName("Root-Graph"); + GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str()); + hybrid_model_.root_graph_item_ = std::move(graph_item); + } else { + graph_item->SetName(graph.GetName()); + GELOGD("Done loading dynamic subgraph: [%s]", graph_item->GetName().c_str()); + hybrid_model_.subgraph_items_.emplace(graph.GetName(), std::move(graph_item)); + } + + return SUCCESS; +} + +Status HybridModelBuilder::BuildInputMapping(GraphItem &graph_item, vector &data_nodes, + bool is_root_graph) { + uint32_t data_op_index = 0; + for (auto &node_item : data_nodes) { + auto node = node_item->node; + int data_index = data_op_index; + if (is_root_graph) { + if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_INDEX, data_index)) { + GELOGI("ge_train: get new index %u, old %u", data_index, data_op_index); + } + data_op_index++; + } else { + if (!AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, data_index)) { + GELOGE(FAILED, "[%s] Failed to get attr [%s]", node->GetName().c_str(), ATTR_NAME_PARENT_NODE_INDEX.c_str()); + return FAILED; + } + } + + if (graph_item.input_nodes_.size() <= static_cast(data_index)) { + graph_item.input_nodes_.resize(data_index + 1); + } + + graph_item.input_nodes_[data_index] = node_item; + } + + return SUCCESS; +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/model/hybrid_model_builder.h b/src/ge/hybrid/model/hybrid_model_builder.h index 33cd1f03..1103aa1c 100644 --- a/src/ge/hybrid/model/hybrid_model_builder.h +++ b/src/ge/hybrid/model/hybrid_model_builder.h @@ -46,18 +46,20 @@ class HybridModelBuilder { static Status HandleDtString(const GeTensor &tensor, void *var_addr); static Status MergeInputNodes(ComputeGraph &compute_graph); static Status MergeNetOutputNode(ComputeGraph &compute_graph); - static Status MergeSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGraphPtr &merged_graph); + static Status UnfoldSubgraph(ComputeGraph &root_graph, ComputeGraph &parent_graph, ComputeGraph &sub_graph); static Status InitWeights(); - + static Status BuildInputMapping(GraphItem &graph_item, std::vector &data_nodes, bool is_root_graph); + static Status ResolveRefIo(NodeItem &node_item); + Status BuildOutputMapping(GraphItem &partitioned_call, const NodeItem &node_item, bool is_root_graph); Status ValidateParams(); Status LoadGraph(); + Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); Status LoadTasks(); - Status ParsePartitionedCall(NodeItem &node_item); - Status ParseNetOutput(const NodeItem &node_item); - Status BuildNoteItem(const NodePtr &node, NodeItem &node_item); + Status IdentifyVariableOutputs(NodeItem &node_item); + Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); - Status ResolveRootNodes(); Status IndexTaskDefs(); Status IndexSpecialNodes(); Status InitRuntimeParams(); @@ -65,19 +67,23 @@ class HybridModelBuilder { Status TransAllVarData(); Status CopyVarData(); Status VarNodeToTensor(const NodePtr &var_node, std::unique_ptr &tensor); + Status AssignUninitializedConstantOps(); Status InitConstantOps(); Status InitVariableTensors(); + Status LoadDynamicSubgraph(ComputeGraph &graph, bool is_root_graph); + Status LoadKnownShapedSubgraph(ComputeGraph &graph, NodeItem *parent_node_item); - const char *GetGraphName() const { return graph_name_.c_str(); } + const char *GetGraphName() const { return hybrid_model_.model_name_.c_str(); } const NodeItem *GetNodeItem(const NodePtr &node) const; NodeItem *MutableNodeItem(const NodePtr &node); GeRootModelPtr ge_root_model_; - std::string graph_name_; std::map> weights_; + std::map subgraph_models_; HybridModel &hybrid_model_; std::map>> node_ref_inputs_; + int node_index = 0; RuntimeParam &runtime_param_; VarManager *var_manager_ = nullptr; diff --git a/src/ge/hybrid/model/node_item.cc b/src/ge/hybrid/model/node_item.cc index b5d4fbda..bfc29c84 100644 --- a/src/ge/hybrid/model/node_item.cc +++ b/src/ge/hybrid/model/node_item.cc @@ -16,6 +16,10 @@ #include "node_item.h" #include +#include "common/debug/log.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/node_utils.h" +#include "hybrid/node_executor/node_executor.h" namespace ge { namespace hybrid { @@ -28,20 +32,61 @@ NodeItem::NodeItem(NodePtr node) : node(std::move(node)) { this->node_type = this->node->GetType(); } +Status NodeItem::Init() { + int32_t unknown_shape_type_val = 0; + (void)AttrUtils::GetInt(op_desc, ::ge::ATTR_NAME_UNKNOWN_SHAPE_TYPE, unknown_shape_type_val); + shape_inference_type = static_cast(unknown_shape_type_val); + + GE_CHK_STATUS_RET(NodeUtils::GetNodeUnknownShapeStatus(*node, is_dynamic), "[%s] Failed to get shape status.", + node->GetName().c_str()); + + if (is_dynamic) { + for (int i = 0; i < num_inputs; ++i) { + const auto &input_desc = op_desc->MutableInputDesc(i); + GE_CHECK_NOTNULL(input_desc); + if (input_desc->MutableShape().IsUnknownShape()) { + is_input_shape_static.push_back(false); + } else { + num_static_input_shapes++; + is_input_shape_static.push_back(true); + GELOGD("[%s] The shape of input[%d] is static. shape = [%s]", NodeName().c_str(), i, + input_desc->MutableShape().ToString().c_str()); + } + } + + for (int i = 0; i < num_outputs; ++i) { + const auto &output_desc = op_desc->MutableOutputDesc(i); + GE_CHECK_NOTNULL(output_desc); + if (output_desc->MutableShape().IsUnknownShape()) { + is_output_shape_static = false; + break; + } + } + } + + return SUCCESS; +} + +bool NodeItem::IsControlOp() const { + auto op_type = op_desc->GetType(); + return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR; +} + std::string NodeItem::DebugString() const { std::stringstream ss; ss << "Node: "; ss << "id = " << node_id; - ss << ", name = " << node->GetName(); - ss << ", type = " << node->GetType(); + ss << ", name = [" << node->GetName(); + ss << "], type = " << node->GetType(); ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); + ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False"); ss << ", unknown_shape_op_type = " << shape_inference_type; ss << ", input_start = " << input_start; ss << ", num_inputs = " << num_inputs; ss << ", output_start = " << output_start; ss << ", num_outputs = " << num_outputs; ss << ", dependent_nodes = ["; - for (const auto &dep_node : dependent_node_list) { + for (const auto &dep_node : dependents_for_shape_inference) { ss << dep_node->GetName() << ", "; } ss << "]"; @@ -55,5 +100,17 @@ std::string NodeItem::DebugString() const { return ss.str(); } + +void NodeItem::SetToDynamic() { + num_static_input_shapes = 0; + is_dynamic = true; + for (size_t i = 0; i < is_input_shape_static.size(); ++i) { + is_input_shape_static[i] = false; + } + if (kernel_task != nullptr && !kernel_task->IsSupportDynamicShape()) { + GELOGD("[%s] Dynamic shape is not supported, clear node task.", node_name.c_str()); + kernel_task = nullptr; + } +} } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/model/node_item.h b/src/ge/hybrid/model/node_item.h index b12d100b..ff024b36 100644 --- a/src/ge/hybrid/model/node_item.h +++ b/src/ge/hybrid/model/node_item.h @@ -18,6 +18,7 @@ #define GE_HYBRID_MODEL_NODE_ITEM_H_ #include +#include "external/ge/ge_api_error_codes.h" #include "graph/node.h" #include "graph/op_desc.h" #include "framework/common/types.h" @@ -33,10 +34,18 @@ struct NodeItem { explicit NodeItem(NodePtr node); ~NodeItem() = default; + Status Init(); + const std::string &NodeName() const { return node_name; } const std::string &NodeType() const { return node_type; } + bool IsControlOp() const; + + bool NeedInfershape() const; + + void SetToDynamic(); + std::string DebugString() const; NodePtr node; @@ -52,17 +61,22 @@ struct NodeItem { UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; std::string node_name; std::string node_type; - std::vector dependent_node_list; + std::vector dependents_for_shape_inference; + std::vector dependents_for_execution; std::set to_const_output_id_list; - // src_output_id, dst_anchor_id, dst_node vector inputs; + // src_output_id, dst_anchor_id, dst_node vector>> outputs; std::shared_ptr kernel_task; const NodeExecutor *node_executor = nullptr; - std::map const_input_shapes; std::map ref_outputs; + std::map reuse_inputs; + + std::vector is_input_shape_static; + bool is_output_shape_static = true; + int num_static_input_shapes = 0; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 3f198bba..71280649 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -16,10 +16,8 @@ #include "aicore_node_executor.h" #include "cce/taskdown_common.hpp" -#include "graph/debug/ge_attr_define.h" -#include "hybrid/model/hybrid_model.h" +#include "hybrid/executor/hybrid_execution_context.h" #include "init/gelib.h" -#include "framework/common/debug/log.h" namespace ge { namespace hybrid { @@ -27,16 +25,47 @@ REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, AiCore AiCoreNodeTask::AiCoreNodeTask(std::vector> &&tasks) : tasks_(std::move(tasks)) {} +Status AiCoreNodeExecutor::Initialize() { + auto ge_lib = GELib::GetInstance(); + GE_CHECK_NOTNULL(ge_lib); + if (!ge_lib->InitFlag()) { + GELOGE(GE_CLI_GE_NOT_INITIALIZED, "Ge_lib is uninitialized, failed."); + return GE_CLI_GE_NOT_INITIALIZED; + } + + auto &kernel_manager = ge_lib->OpsKernelManagerObj(); + auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); + GE_CHECK_NOTNULL(aic_ops_store); + + compiler_.reset(new (std::nothrow) AiCoreTaskCompiler(aic_ops_store)); + GE_CHECK_NOTNULL(compiler_); + return SUCCESS; +} + Status AiCoreNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); - GELOGI("AiCoreNodeExecutor[%s] LoadTask Start.", node->GetName().c_str()); + GELOGI("AiCoreNodeExecutor(%s) LoadTask Start.", node->GetName().c_str()); auto *task_defs = model.GetTaskDefs(node); - Status ret = SUCCESS; - GE_IF_BOOL_EXEC(task_defs != nullptr && !task_defs->empty(), ret = CreateTask(model, *task_defs, node, task)); + if (task_defs == nullptr || task_defs->empty()) { + bool dynamic_flag = false; + if (!AttrUtils::GetBool(node->GetOpDesc(), "support_dynamicshape", dynamic_flag) || !dynamic_flag) { + GELOGD("Skip create task of node (%s) as 'support_dynamicshape' is false and cann't get task_defs.", + node->GetName().c_str()); + return SUCCESS; + } else { + GELOGE(FAILED, "Task_defs is empty for node (%s) which 'support_dynamicshape' is true, failed.", + node->GetName().c_str()); + return FAILED; + } + } - GELOGI("AiCoreNodeExecutor[%s] LoadTask End, ret[%u].", node->GetName().c_str(), ret); - return ret; + AiCoreTaskBuilder builder(node->GetOpDesc(), *task_defs); + std::unique_ptr node_task; + GE_CHK_STATUS_RET(builder.BuildTask(node_task, true), "[%s] Failed to build op tasks.", node->GetName().c_str()); + task = std::move(node_task); + GELOGI("AiCoreNodeExecutor(%s) LoadTask End.", node->GetName().c_str()); + return SUCCESS; } Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key) { @@ -45,18 +74,21 @@ Status AiCoreNodeExecutor::GenNodeKey(const NodePtr &node, std::string &node_key GE_CHECK_NOTNULL(op_desc); // make sure unique, (op_id + input_shape) is unique - node_key = std::to_string(op_desc->GetId()) + "/"; + node_key = std::to_string(op_desc->GetId()) + "-"; node_key.append(std::to_string(op_desc->GetInputsSize())); - auto input_descs = op_desc->GetAllInputsDesc(); - for (auto input_desc : input_descs) { - node_key.push_back('/'); - std::vector dims = input_desc.GetShape().GetDims(); - GE_IF_BOOL_EXEC(dims.size() == 0, continue); // scalar - for (std::size_t i = 0; i < dims.size() - 1; i++) { - node_key.append(std::to_string(dims[i])); - node_key.push_back(','); + auto input_descs = op_desc->GetAllInputsDescPtr(); + for (auto &input_desc : input_descs) { + node_key.push_back('-'); + auto &shape = input_desc->MutableShape(); + auto num_dims = shape.GetDimNum(); + if (num_dims == 0) { + continue; + } // scalar + for (std::size_t i = 0; i < num_dims - 1; i++) { + node_key.append(std::to_string(shape.GetDim(i))); + node_key.push_back('_'); } - node_key.append(std::to_string(dims[dims.size() - 1])); + node_key.append(std::to_string(shape.GetDim(num_dims - 1))); } return SUCCESS; } @@ -65,8 +97,10 @@ bool AiCoreNodeTaskRegistry::AddTask(const std::string &node_key, const std::sha GE_CHECK_NOTNULL(task); std::lock_guard lock(mutex_); auto iter = reg_node_tasks_.find(node_key); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(iter != reg_node_tasks_.end(), return false, - "AiCoreNodeTaskRegistry[%s] AddTask failed, key already exist.", node_key.c_str()); + if (iter != reg_node_tasks_.end()) { + GELOGE(FAILED, "AiCoreNodeTaskRegistry(%s) AddTask failed, key already exist.", node_key.c_str()); + return false; + } auto ret = reg_node_tasks_.emplace(node_key, task); return ret.second; } @@ -80,231 +114,89 @@ std::shared_ptr AiCoreNodeTaskRegistry::GetTask(const std::string &nod Status AiCoreNodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { GE_CHECK_NOTNULL(node); - GELOGI("AiCoreNodeExecutor[%s] CompileTask Start.", node->GetName().c_str()); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + GELOGI("AiCoreNodeExecutor(%s) CompileTask Start.", node->GetName().c_str()); AiCoreNodeTaskRegistry ®istry = AiCoreNodeTaskRegistry::GetInstance(); - std::string node_key; - GE_CHK_STATUS_RET(GenNodeKey(node, node_key), "GenNodeKey failed. op name = %s", node->GetName().c_str()); + std::string shape_key; + GE_CHK_STATUS_RET(GenNodeKey(node, shape_key), "GenNodeKey failed, op name = %s.", node->GetName().c_str()); + auto node_key = std::to_string(model.GetModelId()) + "/" + shape_key; GELOGD("NodeKey for %s = %s", node->GetName().c_str(), node_key.c_str()); task = registry.GetTask(node_key); - GE_CHK_TRUE_EXEC_INFO(task != nullptr, return SUCCESS, "AiCoreNodeExecutor[%s] CompileTask Skip.", - node->GetName().c_str()); + if (task != nullptr) { + GELOGI("AiCoreNodeExecutor(%s) CompileTask Skip.", node->GetName().c_str()); + return SUCCESS; + } std::vector task_defs; - GE_CHK_STATUS_RET_NOLOG(compiler_->CompileOp(node, task_defs)); + auto ori_node_name = node->GetName(); + op_desc->SetName(ori_node_name + "_" + shape_key); + GE_CHK_STATUS_RET(compiler_->CompileOp(node, task_defs), "Compile op(%s) failed.", ori_node_name.c_str()); + op_desc->SetName(ori_node_name); GELOGD("successfully generated task_defs: %s", node->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(CreateTask(model, task_defs, node, task)); + AiCoreTaskBuilder builder(node->GetOpDesc(), task_defs); + std::unique_ptr node_task; + GE_CHK_STATUS_RET(builder.BuildTask(node_task, false), "[%s] Failed to build op tasks.", node->GetName().c_str()); + task = std::move(node_task); GELOGD("successfully created node task: %s", node->GetName().c_str()); - GE_CHK_BOOL_EXEC(registry.AddTask(node_key, task), return INTERNAL_ERROR, "Add NodeTask failed. op name = %s", - node->GetName().c_str()); // should not happen. - GELOGI("AiCoreNodeExecutor[%s] CompileTask End.", node->GetName().c_str()); - return SUCCESS; -} - -Status AiCoreNodeExecutor::BuildAiCoreTask(const domi::KernelDef &kernel_def, const OpDescPtr &op_desc, - AiCoreOpTask **task) { - GE_CHECK_NOTNULL(op_desc); - GE_CHECK_NOTNULL(task); - - const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_type != cce::ccKernelType::TE, return UNSUPPORTED, - "Only TBE kernel is supported, but [%s] got %u", op_desc->GetName().c_str(), - context.kernel_type()); - - auto *aicore_task = new (std::nothrow) AiCoreOpTask(); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicore_task == nullptr, return MEMALLOC_FAILED, "Create AiCore op task failed."); - - auto builder = AiCoreTaskBuilder(op_desc, kernel_def); - auto ret = builder.BuildTask(*aicore_task); - GE_IF_BOOL_EXEC(ret != SUCCESS, delete aicore_task; aicore_task = nullptr; return ret); - - *task = aicore_task; - return SUCCESS; -} - -Status AiCoreNodeExecutor::CreateTask(const HybridModel &model, const std::vector &task_defs, - const NodePtr &node, std::shared_ptr &task) { - GE_CHECK_NOTNULL(node); - GELOGD("To CreateTask, task def size = %zu", task_defs.size()); - std::vector> aicore_op_tasks; - aicore_op_tasks.reserve(task_defs.size()); - for (size_t i = 0; i < task_defs.size(); ++i) { - const domi::TaskDef &task_def = task_defs[i]; - GELOGD("Op[%s] Task[%d], type = %u, DebugString = %s", node->GetName().c_str(), i, task_def.type(), - task_def.DebugString().c_str()); - auto task_type = static_cast(task_def.type()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(task_type == RT_MODEL_TASK_KERNEL_EX, return UNSUPPORTED, - "BuildKernelExTask is not supported"); - GE_CHK_BOOL_TRUE_EXEC_INFO(task_type != RT_MODEL_TASK_KERNEL, continue, "Skip task type %d", - static_cast(task_type)); - - const domi::KernelDef &kernel_def = task_def.kernel(); - AiCoreOpTask *aicore_op_task = nullptr; - // not use hybrid model now - GE_CHK_STATUS_RET_NOLOG(BuildAiCoreTask(kernel_def, node->GetOpDesc(), &aicore_op_task)); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aicore_op_task == nullptr, return FAILED, "BuildAiCoreTask[%s] failed.", - node->GetName().c_str()); - - aicore_op_tasks.emplace_back(std::unique_ptr(aicore_op_task)); + if (!registry.AddTask(node_key, task)) { + GELOGE(INTERNAL_ERROR, "Add NodeTask failed, op name = %s.", node->GetName().c_str()); + return INTERNAL_ERROR; } - if (!aicore_op_tasks.empty()) { - auto aic_task = std::shared_ptr(new AiCoreNodeTask(std::move(aicore_op_tasks))); - task = std::move(aic_task); - GELOGD("Generate AiCoreOpTask success"); - return SUCCESS; - } - - GELOGE(INTERNAL_ERROR, "Failed to build task. node = %s", node->GetName().c_str()); - return INTERNAL_ERROR; -} - -Status AiCoreNodeExecutor::Initialize() { - std::shared_ptr ge_lib = GELib::GetInstance(); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((ge_lib == nullptr) || !ge_lib->InitFlag(), return GE_CLI_GE_NOT_INITIALIZED, - "Get ge_lib failed."); - - auto &kernel_manager = ge_lib->OpsKernelManagerObj(); - auto aic_ops_store = kernel_manager.GetOpsKernelInfoStore("AIcoreEngine"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aic_ops_store == nullptr, return GE_CLI_GE_NOT_INITIALIZED, - "Failed to get kernel info store for AIcoreEngine."); - - compiler_.reset(new (std::nothrow) AiCoreTaskCompiler(aic_ops_store)); - GE_CHECK_NOTNULL(compiler_); + GELOGI("AiCoreNodeExecutor(%s) CompileTask End.", node->GetName().c_str()); return SUCCESS; } -Status AiCoreNodeExecutor::Finalize() { return NodeExecutor::Finalize(); } - Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); - GELOGI("AiCoreNodeTask[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); - for (size_t i = 0; i < tasks_.size(); i++) { - GE_CHECK_NOTNULL(tasks_[i]); - GE_CHK_STATUS_RET_NOLOG(tasks_[i]->LaunchKernel(context.GetStream())); + GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); + for (auto &task : tasks_) { + GE_CHK_STATUS_RET_NOLOG(task->LaunchKernel(context.GetStream())); } if (done_callback != nullptr) { GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); } - GELOGI("AiCoreNodeTask[%s] ExecuteAsync End.", op_desc->GetName().c_str()); + GELOGD("[%s] ExecuteAsync End.", op_desc->GetName().c_str()); return SUCCESS; } -Status AiCoreNodeTask::UpdateAtomicArgs(TaskContext &context, std::unique_ptr &task) { - GE_CHECK_NOTNULL(task); +Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); - - // refresh atomic output addr - std::vector atomic_output_indexes; // here atomic just clean output - (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indexes); - GE_RETURN_WITH_LOG_IF_TRUE(atomic_output_indexes.size() > static_cast(context.NumOutputs()), - "AtomicAddrClean op's arg_size error."); - auto *arg_off = reinterpret_cast(task->args_.get()) + task->offset_; - auto *arg_base = reinterpret_cast(arg_off); - int index = 0; - for (size_t i = 0; i < atomic_output_indexes.size(); ++i) { - const auto output = context.GetOutput(atomic_output_indexes[i]); - GE_CHECK_NOTNULL(output); - arg_base[index++] = reinterpret_cast(output->GetData()); + GELOGI("[%s] AiCoreNodeTask UpdateArgs Start.", op_desc->GetName().c_str()); + for (auto &task : tasks_) { + GE_CHK_STATUS_RET_NOLOG(task->UpdateArgs(context)); } - - // refresh atomic workspace addr - auto workspace_sizes = op_desc->GetWorkspaceBytes(); - uint64_t ops_workspace_num = static_cast(workspace_sizes.size()); - uint64_t workspace_num = static_cast(context.NumWorkspaces()); - GE_CHK_BOOL_EXEC(ops_workspace_num == workspace_num, return PARAM_INVALID, - "The workspace_num in op_desc %lu is not equal to it %lu in context.", ops_workspace_num, - workspace_num); - GE_IF_BOOL_EXEC(workspace_num == 0, return SUCCESS); - - map> workspace_info; - workspace_info = op_desc->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); - if (!workspace_info.empty()) { - bool is_fusion_node = false; - (void)ge::AttrUtils::GetBool(op_desc, ATOMIC_ATTR_IS_FUSION_NODE, is_fusion_node); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(is_fusion_node, return PARAM_INVALID, - "Atomic desc[%s] shouldn't be fusion_node in AiCoreNodeTask", - op_desc->GetName().c_str()); - - for (auto iter = workspace_info.begin(); iter != workspace_info.end(); ++iter) { - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc->GetName() != iter->first, return PARAM_INVALID, - "The node name %s and the node name %s in workspace info are inconsistent.", - op_desc->GetName().c_str(), iter->first.c_str()); - GE_IF_BOOL_EXEC(iter->second.empty(), continue); - - for (auto &info_iter : iter->second) { - auto workspace_index = static_cast(info_iter.first); - - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(workspace_index >= workspace_num, return PARAM_INVALID, - "The workspace index %lu is more than the size %lu of workspace vector.", - workspace_index, workspace_num); - - const auto workspace = context.MutableWorkspace(workspace_index); - arg_base[index++] = reinterpret_cast(workspace); - } - } - } - + GELOGI("[%s] AiCoreNodeTask UpdateArgs End.", op_desc->GetName().c_str()); return SUCCESS; } -Status AiCoreNodeTask::UpdateAllArgs(TaskContext &context, std::unique_ptr &task) { - GE_CHECK_NOTNULL(task); - auto *arg_off = reinterpret_cast(task->args_.get()) + task->offset_; - auto *arg_base = reinterpret_cast(arg_off); - int index = 0; - for (int i = 0; i < context.NumInputs(); ++i) { - const auto input = context.GetInput(i); - GE_CHECK_NOTNULL(input); - arg_base[index++] = reinterpret_cast(input->GetData()); - } - - for (int i = 0; i < context.NumOutputs(); ++i) { - const auto output = context.GetOutput(i); - GE_CHECK_NOTNULL(output); - arg_base[index++] = reinterpret_cast(output->GetData()); - } - - auto op_desc = context.GetNodeItem().op_desc; - GE_CHECK_NOTNULL(op_desc); - auto workspace_sizes = op_desc->GetWorkspaceBytes(); - int ops_workspace_num = static_cast(workspace_sizes.size()); - int workspace_num = static_cast(context.NumWorkspaces()); - GE_CHK_BOOL_EXEC(ops_workspace_num == workspace_num, return PARAM_INVALID, - "The workspace_num in op_desc %lu is not equal to it %lu in context.", ops_workspace_num, - workspace_num); - for (int i = 0; i < workspace_num; ++i) { - const auto workspace = context.MutableWorkspace(i); - arg_base[index++] = reinterpret_cast(workspace); +Status AiCoreNodeTask::UpdateTilingData(TaskContext &context) { + GELOGD("[%s] PrepareWithShape started", context.GetNodeName()); + for (auto &task : tasks_) { + GE_CHK_STATUS_RET_NOLOG(task->PrepareWithShape(context)); } - + GELOGD("[%s] Done PrepareWithShape successfully.", context.GetNodeName()); return SUCCESS; } -Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { - auto op_desc = context.GetNodeItem().op_desc; - GE_CHECK_NOTNULL(op_desc); - GELOGI("AiCoreNodeTask[%s] UpdateArgs Start.", op_desc->GetName().c_str()); - GE_IF_BOOL_EXEC(tasks_.size() == 1, return UpdateAllArgs(context, tasks_[0])); - - std::vector atomic_output_indexes; // here atomic just clean output - (void)ge::AttrUtils::GetListInt(op_desc, ge::ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indexes); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(atomic_output_indexes.empty(), return FAILED, "ATOMIC_ATTR_OUTPUT_INDEX is empty."); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(tasks_.size() != 2, return FAILED, "AtomicAddrClean op task num != 2."); - - GE_CHK_STATUS_RET_NOLOG(UpdateAtomicArgs(context, tasks_[0])); - GE_CHK_STATUS_RET_NOLOG(UpdateAllArgs(context, tasks_[1])); +bool AiCoreNodeTask::IsSupportDynamicShape() { + for (size_t i = 0; i < tasks_.size(); ++i) { + if (!tasks_[i]->IsDynamicShapeSupported()) { + GELOGD("[%s] Task does not support dynamic shape.", tasks_[i]->GetName().c_str()); + return false; + } + } - GELOGI("AiCoreNodeTask[%s] UpdateArgs End.", op_desc->GetName().c_str()); - return SUCCESS; + return true; } } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h index a8b24e68..506202fa 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h +++ b/src/ge/hybrid/node_executor/aicore/aicore_node_executor.h @@ -25,7 +25,6 @@ namespace ge { namespace hybrid { - class AiCoreNodeTaskRegistry { public: ~AiCoreNodeTaskRegistry() = default; @@ -47,32 +46,27 @@ class AiCoreNodeTaskRegistry { class AiCoreNodeTask : public NodeTask { public: explicit AiCoreNodeTask(std::vector> &&tasks); - ~AiCoreNodeTask() = default; - Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + ~AiCoreNodeTask() override = default; + bool IsSupportDynamicShape() override; + Status UpdateTilingData(TaskContext &context) override; + Status UpdateArgs(TaskContext &context) override; + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; private: - static Status UpdateAllArgs(TaskContext &context, std::unique_ptr &task); - static Status UpdateAtomicArgs(TaskContext &context, std::unique_ptr &task); std::vector> tasks_; }; class AiCoreNodeExecutor : public NodeExecutor { public: Status Initialize() override; - Status Finalize() override; - Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const override; Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const override; private: - static Status CreateTask(const HybridModel &model, const std::vector &task_defs, const NodePtr &node, - std::shared_ptr &task); - static Status BuildAiCoreTask(const domi::KernelDef &kernel_def, const OpDescPtr &op_desc, AiCoreOpTask **task); static Status GenNodeKey(const NodePtr &node, std::string &node_key); std::unique_ptr compiler_; }; - } // namespace hybrid } // namespace ge #endif // GE_HYBRID_KERNEL_AICORE_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 27256e9a..9ec0cc22 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -14,19 +14,313 @@ * limitations under the License. */ -#include "aicore_op_task.h" +#include "hybrid/node_executor/aicore/aicore_op_task.h" +#include "cce/taskdown_common.hpp" #include "framework/common/debug/log.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/node_executor/aicore/aicore_task_builder.h" + +using optiling::OpRunInfo; namespace ge { namespace hybrid { +namespace { +constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; +constexpr char const *kAttrOpParamSize = "op_para_size"; +constexpr char const *kAttrAtomicOpParamSize = "atomic_op_para_size"; +} // namespace -Status AiCoreOpTask::LaunchKernel(rtStream_t stream) { - GELOGI("AiCoreOpTask LaunchKernel Start (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); +Status AiCoreOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET_NOLOG(InitWithTaskDef(op_desc, task_def)); + GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(op_desc)); + return SUCCESS; +} + +Status AiCoreOpTask::InitWithTaskDef(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET(ValidateTaskDef(task_def), "[%s] Failed to validate task def: [%s]", op_desc.GetName().c_str(), + task_def.DebugString().c_str()); + + const domi::KernelDef &kernel_def = task_def.kernel(); + const domi::KernelContext &context = kernel_def.context(); + stub_name_ = kernel_def.stub_func(); + GE_CHK_RT_RET(rtGetFunctionByName(stub_name_.c_str(), &stub_func_)); + args_size_ = kernel_def.args_size(); + block_dim_ = kernel_def.block_dim(); + + // malloc args memory + args_.reset(new (std::nothrow) uint8_t[args_size_]); + GE_CHECK_NOTNULL(args_); + errno_t err = memcpy_s(args_.get(), args_size_, kernel_def.args().data(), args_size_); + if (err != EOK) { + GELOGE(INTERNAL_ERROR, "AiCoreTask memcpy args failed."); + return INTERNAL_ERROR; + } + + if (context.args_offset().size() < sizeof(uint16_t)) { + GELOGE(INTERNAL_ERROR, "Invalid args_offset, size = %zu.", context.args_offset().size()); + return INTERNAL_ERROR; + } + + const auto *args_offset_buffer = reinterpret_cast(context.args_offset().data()); + uint32_t offset = *args_offset_buffer; + if (offset > args_size_) { + GELOGE(INTERNAL_ERROR, "[%s] Arg offset out of range. offset = %u, arg size = %u", GetName().c_str(), offset, + args_size_); + return INTERNAL_ERROR; + } + + arg_base_ = reinterpret_cast(args_.get() + offset); + max_arg_count_ = (args_size_ - offset) / sizeof(void *); + GELOGD("[%s] Done setting kernel args successfully. stub_func = %s, block_dim = %d, arg base = %p, arg size = %u", + op_desc.GetName().c_str(), stub_name_.c_str(), block_dim_, arg_base_, args_size_); + + return SUCCESS; +} + +Status AiCoreOpTask::ValidateTaskDef(const domi::TaskDef &task_def) { + auto task_type = static_cast(task_def.type()); + if (task_type != RT_MODEL_TASK_KERNEL) { + GELOGE(INTERNAL_ERROR, "Invalid task type (%d) in AiCore CreateTask.", static_cast(task_type)); + return INTERNAL_ERROR; + } + + const domi::KernelDef &kernel_def = task_def.kernel(); + const domi::KernelContext &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type != cce::ccKernelType::TE) { + GELOGE(INTERNAL_ERROR, "Invalid kernel type(%d) in AiCore TaskDef.", static_cast(kernel_type)); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +Status AiCoreOpTask::PrepareWithShape(TaskContext &context) { + if (tiling_buffer_ != nullptr) { + return UpdateTilingInfo(context); + } + + return SUCCESS; +} + +Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { + auto node = context.GetNodeItem().node; + GE_CHECK_NOTNULL(node); + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + + GELOGD("[%s] Start to update tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); + OpRunInfo tiling_info; + tiling_info.block_dim = -1; // codex: Using uninitialized value + + auto execution_context = context.GetExecutionContext(); + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); + GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); + + // update op args by tiling info + block_dim_ = static_cast(tiling_info.block_dim); + op_desc->SetWorkspaceBytes(tiling_info.workspaces); + + tiling_data_ = tiling_info.tiling_data.str(); + if (tiling_data_.empty()) { + GELOGE(INTERNAL_ERROR, "[%s] Tiling data is empty.", stub_name_.c_str()); + return INTERNAL_ERROR; + } + + if (tiling_data_.size() > tiling_buffer_->GetSize()) { + GELOGE(INTERNAL_ERROR, "[%s] Tiling data size now (%zu) shouldn't larger than we alloc before (%zu).", + stub_name_.c_str(), tiling_data_.size(), tiling_buffer_->GetSize()); + return INTERNAL_ERROR; + } + + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CopyTilingInfo] Start"); + GE_CHK_RT_RET(rtMemcpy(tiling_buffer_->GetData(), tiling_buffer_->GetSize(), tiling_data_.c_str(), + tiling_data_.size(), RT_MEMCPY_HOST_TO_DEVICE)); + RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CopyTilingInfo] End"); + + GELOGD("[%s] Done updating tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); + return SUCCESS; +} + +Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { + GELOGD("[%s] Start to invoke OpParaCalculate.", node->GetName().c_str()); + GE_CHK_STATUS_RET(OpParaCalculate(*node, tiling_info), "Failed calc tiling data of node %s.", + node->GetName().c_str()); + GELOGD("[%s] Done invoking OpParaCalculate successfully.", node->GetName().c_str()); + return SUCCESS; +} + +Status AiCoreOpTask::UpdateArgs(TaskContext &task_context) { + size_t expected_arg_count = task_context.NumInputs() + task_context.NumOutputs() + task_context.NumWorkspaces(); + if (tiling_buffer_ != nullptr) { + ++expected_arg_count; + } + if (expected_arg_count > max_arg_count_) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid arg memory, max arg count = %u, but expect = %zu", GetName().c_str(), + max_arg_count_, expected_arg_count); + return INTERNAL_ERROR; + } + + int index = 0; + for (int i = 0; i < task_context.NumInputs(); ++i) { + const auto input = task_context.GetInput(i); + GE_CHECK_NOTNULL(input); + arg_base_[index++] = reinterpret_cast(input->GetData()); + } + for (int i = 0; i < task_context.NumOutputs(); ++i) { + const auto output = task_context.GetOutput(i); + GE_CHECK_NOTNULL(output); + arg_base_[index++] = reinterpret_cast(output->GetData()); + } + + int workspace_num = static_cast(task_context.NumWorkspaces()); + for (int i = 0; i < workspace_num; ++i) { + const auto workspace = task_context.MutableWorkspace(i); + GE_CHECK_NOTNULL(workspace); + arg_base_[index++] = reinterpret_cast(workspace); + } + + if (tiling_buffer_ != nullptr) { + arg_base_[index++] = reinterpret_cast(tiling_buffer_->GetData()); + } + + if (task_context.IsTraceEnabled()) { + for (int i = 0; i < index; ++i) { + GELOGD("[%s] Arg[%d] = %lu", stub_name_.c_str(), i, arg_base_[i]); + } + } + + return SUCCESS; +} + +Status AiCoreOpTask::LaunchKernel(rtStream_t stream) { + GELOGD("AiCoreOpTask LaunchKernel Start (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); GE_CHK_RT_RET(rtKernelLaunch(stub_func_, block_dim_, args_.get(), args_size_, nullptr, stream)); - GELOGI("AiCoreOpTask LaunchKernel End (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); + GELOGD("AiCoreOpTask LaunchKernel End (task = %s, block_dim = %u).", stub_name_.c_str(), block_dim_); return SUCCESS; } +Status AiCoreOpTask::InitTilingInfo(const OpDesc &op_desc) { + bool dynamic_supported = false; + (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, dynamic_supported); + if (!dynamic_supported) { + GELOGD("[%s] Dynamic shape is not supported.", op_desc.GetName().c_str()); + return SUCCESS; + } + + GELOGD("Start alloc tiling data of node %s.", op_desc.GetName().c_str()); + int64_t max_size = -1; + (void)AttrUtils::GetInt(op_desc, GetKeyForOpParamSize(), max_size); + GELOGD("Got op param size by key: %s, ret = %ld", GetKeyForOpParamSize().c_str(), max_size); + if (max_size <= 0) { + GELOGE(PARAM_INVALID, "[%s] Invalid op_param_size: %ld.", op_desc.GetName().c_str(), max_size); + return PARAM_INVALID; + } + + auto allocator = NpuMemoryAllocator::GetAllocator(); + GE_CHECK_NOTNULL(allocator); + tiling_buffer_ = TensorBuffer::Create(allocator, static_cast(max_size)); + GE_CHECK_NOTNULL(tiling_buffer_); + + GELOGD("[%s] Done allocating tiling buffer, size=%ld.", op_desc.GetName().c_str(), max_size); + return SUCCESS; +} + +bool AiCoreOpTask::IsDynamicShapeSupported() { return tiling_buffer_ != nullptr; } + +const std::string &AiCoreOpTask::GetName() const { return stub_name_; } + +std::string AiCoreOpTask::GetKeyForOpParamSize() const { return kAttrOpParamSize; } + +Status AtomicAddrCleanOpTask::Init(const OpDesc &op_desc, const domi::TaskDef &task_def) { + GE_CHK_STATUS_RET_NOLOG(AiCoreOpTask::Init(op_desc, task_def)); + return InitAtomicAddrCleanIndices(op_desc); +} + +Status AtomicAddrCleanOpTask::InitAtomicAddrCleanIndices(const OpDesc &op_desc) { + GELOGD("[%s] Start to setup AtomicAddrClean task.", op_desc.GetName().c_str()); + std::vector atomic_output_indices; + (void)ge::AttrUtils::GetListInt(op_desc, ATOMIC_ATTR_OUTPUT_INDEX, atomic_output_indices); + map> workspace_info; // op_name, ws_index, ws_offset + workspace_info = op_desc.TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); + if (atomic_output_indices.empty() && workspace_info.empty()) { + GELOGE(INTERNAL_ERROR, "[%s] Neither ATOMIC_ATTR_OUTPUT_INDEX nor EXT_ATTR_ATOMIC_WORKSPACE_INFO is empty.", + op_desc.GetName().c_str()); + return INTERNAL_ERROR; + } + + for (auto output_index : atomic_output_indices) { + GELOGD("[%s] Adding output index [%ld]", op_desc.GetName().c_str(), output_index); + GE_CHECK_GE(output_index, 0); + GE_CHECK_LE(output_index, INT32_MAX); + atomic_output_indices_.emplace_back(static_cast(output_index)); + } + + for (auto &iter : workspace_info) { + for (auto &info_iter : iter.second) { + auto workspace_index = info_iter.first; + GELOGD("[%s] Adding workspace index [%ld]", op_desc.GetName().c_str(), workspace_index); + GE_CHECK_GE(workspace_index, 0); + GE_CHECK_LE(workspace_index, INT32_MAX); + atomic_workspace_indices_.emplace_back(static_cast(workspace_index)); + } + } + + size_t arg_count = atomic_workspace_indices_.size() + atomic_output_indices_.size(); + if (tiling_buffer_ != nullptr) { + arg_count += 1; + } + + if (arg_count > max_arg_count_) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid arg memory, max arg count = %u, but expect = %zu", GetName().c_str(), + max_arg_count_, arg_count); + return INTERNAL_ERROR; + } + + return SUCCESS; +} + +std::string AtomicAddrCleanOpTask::GetKeyForOpParamSize() const { return kAttrAtomicOpParamSize; } + +Status AtomicAddrCleanOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) { + GELOGD("[%s] Start to invoke OpAtomicCalculate.", node->GetName().c_str()); + GE_CHK_STATUS_RET(OpAtomicCalculate(*node, tiling_info), "Failed calc tiling data of node %s.", + node->GetName().c_str()); + GELOGD("[%s] Done invoking OpAtomicCalculate successfully.", node->GetName().c_str()); + return SUCCESS; +} + +Status AtomicAddrCleanOpTask::UpdateArgs(TaskContext &task_context) { + // refresh atomic output addr + int index = 0; + for (auto atomic_output_index : atomic_output_indices_) { + const auto output_tensor = task_context.GetOutput(atomic_output_index); + GE_CHECK_NOTNULL(output_tensor); + arg_base_[index++] = reinterpret_cast(output_tensor->GetData()); + } + + // refresh atomic workspace addr + for (auto atomic_ws_index : atomic_workspace_indices_) { + const auto workspace_tensor = task_context.GetOutput(atomic_ws_index); + GE_CHECK_NOTNULL(workspace_tensor); + arg_base_[index++] = reinterpret_cast(workspace_tensor->GetData()); + } + + if (tiling_buffer_ != nullptr) { + arg_base_[index++] = reinterpret_cast(tiling_buffer_->GetData()); + } else { + GELOGD("[%s] Not a dynamic op", GetName().c_str()); + } + + if (task_context.IsTraceEnabled()) { + for (int i = 0; i < index; ++i) { + GELOGD("[%s] Arg[%d] = %lu", GetName().c_str(), i, arg_base_[i]); + } + } + + return SUCCESS; +} } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_op_task.h b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h index d23688a5..41ab0d79 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_op_task.h +++ b/src/ge/hybrid/node_executor/aicore/aicore_op_task.h @@ -18,27 +18,70 @@ #define GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ #include +#include #include "common/ge_inner_error_codes.h" #include "runtime/stream.h" +#include "hybrid/common/tensor_value.h" +#include "hybrid/node_executor/task_context.h" +#include "proto/task.pb.h" +#include "register/op_tiling.h" + namespace ge { namespace hybrid { class AiCoreOpTask { public: AiCoreOpTask() = default; - ~AiCoreOpTask() = default; + virtual ~AiCoreOpTask() = default; + + virtual Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def); + + bool IsDynamicShapeSupported(); + + // do preparation with shape(without actual io memory) + Status PrepareWithShape(TaskContext &context); + + virtual Status UpdateArgs(TaskContext &task_context); + Status LaunchKernel(rtStream_t stream); + const std::string &GetName() const; + + protected: + Status UpdateTilingInfo(TaskContext &context); + virtual std::string GetKeyForOpParamSize() const; + virtual Status CalcTilingInfo(const NodePtr &node, optiling::OpRunInfo &tiling_info); + + std::unique_ptr tiling_buffer_ = nullptr; + std::string tiling_data_; + uintptr_t *arg_base_ = nullptr; + uint32_t max_arg_count_ = 0; + private: - friend class AiCoreTaskBuilder; - friend class AiCoreNodeTask; + static Status ValidateTaskDef(const domi::TaskDef &task_def); + Status InitWithTaskDef(const OpDesc &node, const domi::TaskDef &task_def); + Status InitTilingInfo(const OpDesc &op_desc); + std::string stub_name_; void *stub_func_ = nullptr; std::unique_ptr args_ = nullptr; uint32_t args_size_ = 0; uint32_t block_dim_ = 1; - uint16_t offset_ = 0; }; +class AtomicAddrCleanOpTask : public AiCoreOpTask { + public: + Status Init(const OpDesc &op_desc, const domi::TaskDef &task_def) override; + Status UpdateArgs(TaskContext &task_context) override; + + protected: + std::string GetKeyForOpParamSize() const override; + Status CalcTilingInfo(const NodePtr &node, optiling::OpRunInfo &tiling_info) override; + + private: + Status InitAtomicAddrCleanIndices(const OpDesc &op_desc); + std::vector atomic_output_indices_; + std::vector atomic_workspace_indices_; +}; } // namespace hybrid } // namespace ge #endif // GE_HYBRID_KERNEL_AICORE_OP_TASK_H_ diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc index 5b263007..bad91806 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.cc @@ -15,76 +15,78 @@ */ #include "aicore_task_builder.h" -#include -#include "graph/op_desc.h" -#include "cce/taskdown_common.hpp" -#include "framework/common/debug/log.h" -#include "graph/debug/ge_attr_define.h" +#include "common/debug/log.h" +#include "aicore_node_executor.h" namespace ge { namespace hybrid { -std::mutex g_reg_mutex; - -AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def) - : op_desc_(op_desc), kernel_def_(kernel_def) { - std::string session_graph_id; - GE_IF_BOOL_EXEC(AttrUtils::GetStr(*op_desc_, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), - GELOGD("Get original type of session_graph_id.")); - // get bin_file_key - stub_name_ = (session_graph_id.empty()) ? op_desc_->GetName() : session_graph_id + "_" + op_desc_->GetName(); -} - -Status AiCoreTaskBuilder::SetKernelArgs(AiCoreOpTask &task) { - const domi::KernelContext &context = kernel_def_.context(); - // get kernel_type - auto kernel_type = static_cast(context.kernel_type()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(kernel_type != cce::ccKernelType::TE, return UNSUPPORTED, - "Invalid kernel type[%d] in AiCore TaskDef.", static_cast(kernel_type)); - - task.args_size_ = kernel_def_.args_size(); - task.block_dim_ = kernel_def_.block_dim(); - - // malloc args memory - task.args_.reset(new (std::nothrow) uint8_t[task.args_size_]); - // task.args_ = std::make_unique(task.args_size_); - GE_CHECK_NOTNULL(task.args_); - errno_t err = memcpy_s(task.args_.get(), task.args_size_, kernel_def_.args().data(), task.args_size_); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(err != EOK, return INTERNAL_ERROR, "AiCoreTask memcpy failed."); - - const auto *args_offset_tmp = reinterpret_cast(const_cast(context.args_offset().data())); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(context.args_offset().size() / sizeof(uint16_t) < 1, return FAILED, - "context.args_offset().size() / sizeof(uint16_t) less than 1"); - task.offset_ = *args_offset_tmp; - return SUCCESS; +namespace { +const size_t kNumTaskWithAtomicAddrCleanTask = 2; } - const char *AiCoreKernelRegistry::GetUnique(const string &stub_key) { std::lock_guard lock(mutex_); auto it = unique_stubs_.find(stub_key); - GE_IF_BOOL_EXEC(it != unique_stubs_.end(), return it->c_str()); + if (it != unique_stubs_.end()) { + return it->c_str(); + } it = unique_stubs_.insert(unique_stubs_.end(), stub_key); return it->c_str(); } -Status AiCoreTaskBuilder::SetStub(AiCoreOpTask &task) { - AiCoreKernelRegistry ®istry = AiCoreKernelRegistry::GetInstance(); - std::lock_guard lock(g_reg_mutex); - const char *unique_key = registry.GetUnique(stub_name_); +AiCoreTaskBuilder::AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector &task_defs) + : op_desc_(op_desc), task_defs_(task_defs) {} - GE_CHK_RT_RET(rtGetFunctionByName(unique_key, &(task.stub_func_))); - task.stub_name_ = stub_name_; +Status AiCoreTaskBuilder::BuildTask(std::unique_ptr &node_task, bool ignore_failure_on_atomic) { + GE_CHECK_NOTNULL(op_desc_); + if (task_defs_.size() > kNumTaskWithAtomicAddrCleanTask) { + GELOGE(INTERNAL_ERROR, "[%s] At most 2 task was supported, but got %zu", op_desc_->GetName().c_str(), + task_defs_.size()); + return INTERNAL_ERROR; + } - return SUCCESS; -} + std::vector> op_tasks; + if (ExpectAtomicAddrCleanTask()) { + if (task_defs_.size() != kNumTaskWithAtomicAddrCleanTask) { + if (ignore_failure_on_atomic) { + GELOGI("[%s] AtomicAddrClean task was expected, but got %zu task_defs", op_desc_->GetName().c_str(), + task_defs_.size()); + return SUCCESS; + } else { + GELOGE(INTERNAL_ERROR, "[%s] AtomicAddrClean task was expected, but got %zu task_defs", + op_desc_->GetName().c_str(), task_defs_.size()); + return INTERNAL_ERROR; + } + } -Status AiCoreTaskBuilder::BuildTask(AiCoreOpTask &task) { - GE_CHECK_NOTNULL(op_desc_); - GELOGI("AiCoreTaskBuilder[%s] BuildTask Start.", op_desc_->GetName().c_str()); - GE_CHK_STATUS_RET_NOLOG(SetKernelArgs(task)); - GE_CHK_STATUS_RET_NOLOG(SetStub(task)); - GELOGI("AiCoreTaskBuilder[%s] BuildTask End.", op_desc_->GetName().c_str()); + GELOGD("[%s] Build AtomicAddrClean task.", op_desc_->GetName().c_str()); + auto atomic_task = std::unique_ptr(new (std::nothrow) AtomicAddrCleanOpTask()); + GE_CHECK_NOTNULL(atomic_task); + GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, task_defs_.front()), "[%s] Failed to init task for AtomicAddrClean", + op_desc_->GetName().c_str()); + op_tasks.emplace_back(std::move(atomic_task)); + } + + // build aicore task + auto aicore_task = std::unique_ptr(new (std::nothrow) AiCoreOpTask()); + GE_CHECK_NOTNULL(aicore_task); + GE_CHK_STATUS_RET(aicore_task->Init(*op_desc_, task_defs_.back()), "[%s] Failed to init task for AtomicAddrClean", + op_desc_->GetName().c_str()); + op_tasks.emplace_back(std::move(aicore_task)); + + node_task.reset(new (std::nothrow) AiCoreNodeTask(std::move(op_tasks))); + GE_CHECK_NOTNULL(node_task); return SUCCESS; } +bool AiCoreTaskBuilder::ExpectAtomicAddrCleanTask() { + if (op_desc_->HasAttr(ATOMIC_ATTR_OUTPUT_INDEX)) { + GELOGD("[%s] Node has ATOMIC_ATTR_OUTPUT_INDEX", op_desc_->GetName().c_str()); + return true; + } + map> workspace_info; + workspace_info = op_desc_->TryGetExtAttr(EXT_ATTR_ATOMIC_WORKSPACE_INFO, workspace_info); + + return !workspace_info.empty(); +} } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h index 18cb309c..4610e57a 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_builder.h @@ -17,14 +17,13 @@ #ifndef GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ #define GE_HYBRID_KERNEL_AICORE_TASK_BUILDER_H_ -#include +#include #include -#include -#include #include "aicore_op_task.h" -#include "proto/task.pb.h" +#include "framework/common/debug/ge_log.h" #include "graph/utils/attr_utils.h" #include "graph/op_kernel_bin.h" +#include "proto/task.pb.h" namespace ge { namespace hybrid { @@ -45,16 +44,16 @@ class AiCoreKernelRegistry { class AiCoreTaskBuilder { public: - AiCoreTaskBuilder(const OpDescPtr &op_desc, const domi::KernelDef &kernel_def); + AiCoreTaskBuilder(const OpDescPtr &op_desc, const std::vector &task_defs); ~AiCoreTaskBuilder() = default; - Status BuildTask(AiCoreOpTask &task); + + Status BuildTask(std::unique_ptr &node_task, bool ignore_failure_on_atomic); private: - Status SetKernelArgs(AiCoreOpTask &task); - Status SetStub(AiCoreOpTask &task); - const OpDescPtr &op_desc_; - const domi::KernelDef &kernel_def_; - std::string stub_name_; + bool ExpectAtomicAddrCleanTask(); + + OpDescPtr op_desc_; + const std::vector &task_defs_; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc index ac89afbd..588f179d 100644 --- a/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc +++ b/src/ge/hybrid/node_executor/aicore/aicore_task_compiler.cc @@ -34,7 +34,6 @@ Status AiCoreTaskCompiler::DoCompileOp(OpsKernelInfoStore &ops_store, const Node GE_CHECK_NOTNULL(node); vector node_vec; node_vec.emplace_back(node); - std::lock_guard lk(mu_); GE_CHK_STATUS_RET(ops_store.CompileOpRun(node_vec), "Failed to execute CompileOp, node = %s", node->GetName().c_str()); GE_CHK_STATUS_RET(ops_store.CalcOpRunningParam(*node), "Failed to execute CalcOpRunningParam, node = %s", @@ -44,9 +43,8 @@ Status AiCoreTaskCompiler::DoCompileOp(OpsKernelInfoStore &ops_store, const Node Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vector &tasks) const { GE_CHECK_NOTNULL(node); - GELOGI("AiCoreTaskCompiler[%s] CompileOp Start.", node->GetName().c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(aic_kernel_store_ == nullptr, return FAILED, - "Failed to get AiCore kernel store, node = %s", node->GetName().c_str()); + GELOGI("AiCoreTaskCompiler(%s) CompileOp Start.", node->GetName().c_str()); + GE_CHECK_NOTNULL(aic_kernel_store_); GE_CHK_STATUS_RET_NOLOG(DoCompileOp(*aic_kernel_store_, node)); GELOGD("successfully compiled op: %s", node->GetName().c_str()); @@ -56,9 +54,11 @@ Status AiCoreTaskCompiler::CompileOp(const NodePtr &node, std::vector output_offsets(op_desc->GetOutputsSize(), kMemBase); op_desc->SetInputOffset(input_offsets); op_desc->SetOutputOffset(output_offsets); + std::vector workspaces(op_desc->GetWorkspaceBytes().size(), kMemBase); + op_desc->SetWorkspace(std::move(workspaces)); GE_CHK_STATUS_RET_NOLOG(DoGenerateTask(*aic_kernel_store_, *node, tasks)); GELOGD("successfully generated task: %s", node->GetName().c_str()); - GELOGI("AiCoreTaskCompiler[%s] CompileOp End.", node->GetName().c_str()); + GELOGI("AiCoreTaskCompiler(%s) CompileOp End.", node->GetName().c_str()); return SUCCESS; } @@ -91,6 +91,5 @@ Status AiCoreTaskCompiler::DoGenerateTask(OpsKernelInfoStore &store, const Node GE_CHK_RT(rtModelDestroy(rt_model_)); return ret; } - } // namespace hybrid -} // namespace ge \ No newline at end of file +} // namespace ge diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc index d5c3c03c..332675bf 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc @@ -199,6 +199,5 @@ void AicpuExtInfoHandler::GetShapeAndType(const AicpuShapeAndType *shape_and_typ data_type = static_cast(shape_and_type->type); shape = std::move(GeShape(dims)); } - } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h index e96d794c..a42678b1 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h @@ -24,7 +24,6 @@ namespace ge { namespace hybrid { - using AicpuShapeAndType = aicpu::FWKAdapter::ShapeAndType; using AicpuExtInfo = aicpu::FWKAdapter::ExtInfo; diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index 372f35f5..46d9a0aa 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -40,19 +40,28 @@ Status AicpuNodeTaskBase::AllocTensorBuffer(size_t size, std::unique_ptris_dynamic) { + // dynamic node must have ext info + GE_CHK_STATUS_RET(aicpu_ext_handle_.Parse(kernel_ext_info), + "Node[%s] parse kernel ext info failed, kernel_ext_info_size=%zu.", node_name_.c_str(), + kernel_ext_info.size()); + } + + // if no ext info no need copy to device. + if (kernel_ext_info.empty()) { + GELOGI("Node[%s] kernel_ext_info is empty, no need copy to device, is_dynamic=%s.", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false"); + return SUCCESS; + } // copy task args buf GE_CHK_STATUS_RET(AllocTensorBuffer(kernel_ext_info.size(), ext_info_addr_dev_), "Node[%s] alloc kernel_ext_info buf failed, size=%zu", node_name_.c_str(), kernel_ext_info.size()); - // if no input and no output(DEPEND_COMPUTE equal no output), copy once, or else copy when update args. - if (node_item_->num_inputs == 0 && ((unknown_type_ == DEPEND_COMPUTE) || (node_item_->num_outputs == 0))) { - GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_->GetData(), ext_info_addr_dev_->GetSize(), kernel_ext_info.data(), - kernel_ext_info.size(), RT_MEMCPY_HOST_TO_DEVICE)); - } + // copy default ext info to device + GE_CHK_RT_RET(rtMemcpy(ext_info_addr_dev_->GetData(), ext_info_addr_dev_->GetSize(), kernel_ext_info.data(), + kernel_ext_info.size(), RT_MEMCPY_HOST_TO_DEVICE)); + return SUCCESS; } @@ -139,16 +148,18 @@ Status AicpuNodeTaskBase::UpdateExtInfo() { } Status AicpuNodeTaskBase::UpdateArgs(TaskContext &context) { - GELOGI("Node[%s] update args begin. unknown_type=%d", node_name_.c_str(), unknown_type_); + GELOGI("Node[%s] update args begin. is_dynamic=%s, unknown_type=%d", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false", unknown_type_); if (node_item_->num_inputs == 0 && node_item_->num_outputs == 0) { GELOGI("Node[%s] has no input and output, no need update args.", node_name_.c_str()); return SUCCESS; } GE_CHK_STATUS_RET(UpdateIoAddr(context), "Node[%s] update io addr failed.", node_name_.c_str()); - - GE_CHK_STATUS_RET(UpdateExtInfo(), "Node[%s] update ext info failed.", node_name_.c_str()); - + if (node_item_->is_dynamic) { + // dynamic node need update ext info. + GE_CHK_STATUS_RET(UpdateExtInfo(), "Node[%s] update ext info failed.", node_name_.c_str()); + } GELOGI("Node[%s] update args end.", node_name_.c_str()); return SUCCESS; } @@ -275,9 +286,12 @@ Status AicpuTfNodeTask::Init(const HybridModel &model) { fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = reinterpret_cast(kernel_workspace_->GetData()); fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = reinterpret_cast(input_output_addr_->GetData()); - // set ext info addr and ext info num - fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); - fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info_addr_dev_->GetSize(); + + if (ext_info_addr_dev_ != nullptr) { + // set ext info addr and ext info num + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); + fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = ext_info_addr_dev_->GetSize(); + } fwk_op_kernel.fwkKernelBase.fwk_kernel.stepIDAddr = GetStepIdAddr(model); @@ -506,7 +520,8 @@ Status AicpuTfNodeTask::UpdateIoAddr(TaskContext &context) { io_addrs.emplace_back(reinterpret_cast(inputData->GetData())); } - if (unknown_type_ != DEPEND_COMPUTE) { + // known shape or not depend compute + if (!node_item_->is_dynamic || unknown_type_ != DEPEND_COMPUTE) { // unknown type 4 do this in call back. GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); for (auto j = 0; j < node_item_->num_outputs; ++j) { @@ -548,14 +563,17 @@ Status AicpuTfNodeTask::LaunchTask(TaskContext &context) { } Status AicpuTfNodeTask::TaskCallback(TaskContext &context) { - GELOGI("Node[%s] task callback start. unknown_type=%d.", node_name_.c_str(), unknown_type_); + GELOGI("Node[%s] task callback start. is_dynamic=%s, unknown_type=%d.", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false", unknown_type_); Status callback_ret = SUCCESS; - // check need update shape, call update shape. - if (unknown_type_ == DEPEND_SHAPE_RANGE) { - // check result - callback_ret = UpdateOutputShapeFromExtInfo(); - } else if (unknown_type_ == DEPEND_COMPUTE) { - callback_ret = UpdateShapeAndDataByResultSummary(context); + if (node_item_->is_dynamic) { + // check need update shape, call update shape. + if (unknown_type_ == DEPEND_SHAPE_RANGE) { + // check result + callback_ret = UpdateOutputShapeFromExtInfo(); + } else if (unknown_type_ == DEPEND_COMPUTE) { + callback_ret = UpdateShapeAndDataByResultSummary(context); + } } GELOGI("Node[%s] task callback end.", node_name_.c_str()); return callback_ret; @@ -612,8 +630,13 @@ Status AicpuNodeTask::Init(const HybridModel &model) { GE_CHK_STATUS_RET(InitExtInfo(kernel_ext_info), "Node[%s] init ext info failed.", node_name.c_str()); - aicpu_param_head->extInfoLength = ext_info_addr_dev_->GetSize(); - aicpu_param_head->extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); + if (ext_info_addr_dev_ == nullptr) { + aicpu_param_head->extInfoLength = 0; + aicpu_param_head->extInfoAddr = 0; + } else { + aicpu_param_head->extInfoLength = ext_info_addr_dev_->GetSize(); + aicpu_param_head->extInfoAddr = reinterpret_cast(ext_info_addr_dev_->GetData()); + } GELOGI("Node[%s] init end.", node_name.c_str()); return SUCCESS; @@ -664,10 +687,12 @@ Status AicpuNodeTask::LaunchTask(TaskContext &context) { } Status AicpuNodeTask::TaskCallback(TaskContext &context) { - GELOGI("Node[%s] task callback start, unknown_type=%d.", node_name_.c_str(), unknown_type_); + GELOGI("Node[%s] task callback start, is_dynamic = %s, unknown_type=%d.", node_name_.c_str(), + node_item_->is_dynamic ? "true" : "false", unknown_type_); Status callback_ret = SUCCESS; + // check need update shape, call update shape. - if (unknown_type_ == DEPEND_SHAPE_RANGE) { + if (node_item_->is_dynamic && unknown_type_ == DEPEND_SHAPE_RANGE) { // check result callback_ret = UpdateOutputShapeFromExtInfo(); } else { diff --git a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h index ce3f9707..8aca6ff7 100644 --- a/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h +++ b/src/ge/hybrid/node_executor/aicpu/aicpu_node_executor.h @@ -24,7 +24,6 @@ namespace ge { namespace hybrid { - class AicpuNodeTaskBase : public NodeTask { public: AicpuNodeTaskBase(const NodeItem *node_item, const domi::TaskDef &task_def) @@ -70,8 +69,10 @@ class AicpuNodeTaskBase : public NodeTask { const std::string node_type_; + // valid when node_item_->is_dynamic is true UnknowShapeOpType unknown_type_ = DEPEND_IN_SHAPE; + // valid when node_item_->is_dynamic is true AicpuExtInfoHandler aicpu_ext_handle_; // ext info addr, device mem @@ -169,7 +170,6 @@ class AiCpuNodeExecutor : public NodeExecutor { Status PrepareTask(NodeTask &task, TaskContext &context) const override; }; - } // namespace hybrid } // namespace ge #endif // GE_HYBRID_KERNEL_AICPU_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc index 81960c48..2e1893f2 100644 --- a/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc +++ b/src/ge/hybrid/node_executor/compiledsubgraph/known_node_executor.cc @@ -26,7 +26,6 @@ namespace ge { namespace hybrid { - REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::COMPILED_SUBGRAPH, KnownNodeExecutor); Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { @@ -50,19 +49,12 @@ Status KnownNodeTask::ExecuteAsync(TaskContext &context, std::function d rtError_t rt_ret; GELOGI("rtModelExecute start."); - rt_ret = rtModelExecute(davinci_model_->GetRtModelHandle(), davinci_model_->GetRtModelStream(), 0); + rt_ret = rtModelExecute(davinci_model_->GetRtModelHandle(), context.GetStream(), 0); GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtModelExecute error, ret: Ox%X", rt_ret); return FAILED;); GELOGI("rtModelExecute end"); - GELOGI("rtStreamSynchronize start."); - rt_ret = rtStreamSynchronize(davinci_model_->GetRtModelStream()); - GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtStreamSynchronize error, ret: Ox%X", rt_ret); - return FAILED;); - GELOGI("rtStreamSynchronize end."); - context.RegisterCallback(done_callback); GELOGI("[%s] KnownNodeTask::ExecuteAsync success.", context.GetNodeName()); - return SUCCESS; } @@ -89,7 +81,8 @@ Status KnownNodeTask::UpdateArgs(TaskContext &context) { GE_CHK_STATUS_RET(davinci_model_->UpdateKnownNodeArgs(inputs, outputs), "known node task update known node args failed."); - GELOGI("[%s] KnownNodeExecutor::UpdateArgs success.", context.GetNodeName()); + GELOGI("[%s] KnownNodeExecutor::UpdateArgs success, task_size = %d:", context.GetNodeName(), + davinci_model_->GetTaskList().size()); return SUCCESS; } @@ -98,13 +91,22 @@ Status KnownNodeTask::Init(TaskContext &context) { GE_CHK_STATUS_RET(context.AllocateOutputs(), "known node task allocate output failed."); // init davinicmodel - davinci_model_->InitRuntimeParams(); - GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed."); + if (!load_flag_) { + davinci_model_->InitRuntimeParams(); + GE_CHK_STATUS_RET(davinci_model_->InitVariableMem(), "init variable mem failed."); + } + // allocate mem base void *buffer = nullptr; if (davinci_model_->TotalMemSize() != 0) { - GE_CHK_STATUS_RET(context.AllocateWorkspace(davinci_model_->TotalMemSize(), &buffer), - "known node task allocate workspace failed."); + GE_CHK_STATUS_RET( + context.AllocateWorkspace(davinci_model_->TotalMemSize(), &buffer, davinci_model_->GetRuntimeParam().mem_base), + "known node task allocate workspace failed."); + bool addr_not_changed = false; + if (davinci_model_->GetRuntimeParam().mem_base == buffer) { + addr_not_changed = true; + } + davinci_model_->SetKnownNodeAddrNotChanged(addr_not_changed); // update mem base davinci_model_->UpdateMemBase(static_cast(buffer)); GELOGI("KnownNodeTask::Init mem base is %p, size %u.", davinci_model_->GetRuntimeParam().mem_base, @@ -124,7 +126,6 @@ Status KnownNodeTask::Init(TaskContext &context) { Status KnownNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GELOGI("[%s] KnownNodeExecutor::PrepareTask in.", context.GetNodeName()); - GE_CHK_STATUS_RET(task.Init(context), "known node init davinci model failed."); GE_CHK_STATUS_RET(task.UpdateArgs(context), "known node task update args failed."); @@ -161,6 +162,5 @@ Status KnownNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, context.GetNodeItem().NodeName().c_str()); return SUCCESS; } - } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.cc b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc new file mode 100644 index 00000000..aee7fb77 --- /dev/null +++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.cc @@ -0,0 +1,344 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "control_op_executor.h" +#include "graph/utils/node_utils.h" +#include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_executor.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::CONTROL_OP, ControlOpNodeExecutor); + +Status ControlOpNodeTask::ExecuteSubgraph(const GraphItem *subgraph, TaskContext &task_context, + const std::function &done_callback) { + GELOGD("[%s] Start to execute subgraph.", subgraph->GetName().c_str()); + auto execution_context = const_cast(task_context.GetExecutionContext()); + auto executor = MakeShared(subgraph, execution_context); + GE_CHECK_NOTNULL(executor); + GE_CHK_STATUS_RET(executor->ExecuteAsync(task_context), "[%s] Failed to execute partitioned call.", + subgraph->GetName().c_str()); + + auto callback = [executor, done_callback]() mutable { + if (done_callback != nullptr) { + done_callback(); + } + // executor must outlive task context + executor.reset(); + }; + + GE_CHK_STATUS_RET_NOLOG(task_context.RegisterCallback(callback)); + GELOGD("[%s] Done executing subgraph successfully.", subgraph->GetName().c_str()); + return SUCCESS; +} + +Status ControlOpNodeTask::CopyTensorValueToHost(const TensorValue &tensor, int32_t &value) { + GE_CHECK_NOTNULL(tensor.GetData()); + GE_CHECK_GE(tensor.GetSize(), sizeof(value)); + GE_CHK_RT_RET(rtMemcpy(&value, sizeof(value), tensor.GetData(), sizeof(value), RT_MEMCPY_DEVICE_TO_HOST)); + return SUCCESS; +} + +Status ControlOpNodeTask::UpdateArgs(TaskContext &context) { + // do nothing + return SUCCESS; +} + +Status ControlOpNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { + auto ret = DoExecuteAsync(task_context, done_callback); + task_context.SetStatus(ret); + + if (done_callback) { + done_callback(); + } + + return ret; +} + +Status IfOpNodeTask::Init(const NodePtr &node, const HybridModel &model) { + GELOGD("[%s] Start to init IfOpNodeTask.", node->GetName().c_str()); + auto then_subgraph = NodeUtils::GetSubgraph(*node, kThenBranchIndex); + GE_CHECK_NOTNULL(then_subgraph); + GELOGD("[%s] Adding subgraph [%s] to then-subgraph.", node->GetName().c_str(), then_subgraph->GetName().c_str()); + then_ = model.GetSubgraphItem(then_subgraph); + GE_CHECK_NOTNULL(then_); + + auto else_subgraph = NodeUtils::GetSubgraph(*node, kElseBranchIndex); + GE_CHECK_NOTNULL(else_subgraph); + GELOGD("[%s] Adding subgraph [%s] to else-subgraph.", node->GetName().c_str(), else_subgraph->GetName().c_str()); + else_ = model.GetSubgraphItem(else_subgraph); + GE_CHECK_NOTNULL(else_); + + GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); + return SUCCESS; +} + +const GraphItem *IfOpNodeTask::SelectBranch(int32_t cond) const { return cond != 0 ? then_ : else_; } + +Status IfOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { + auto cond_tensor = task_context.GetInput(kIfCondIndex); + GE_CHECK_NOTNULL(cond_tensor); + int32_t cond_val = 0; + GE_CHK_STATUS_RET(CopyTensorValueToHost(*cond_tensor, cond_val), "[%s] Failed to get cond value.", + task_context.GetNodeName()); + + auto subgraph = SelectBranch(cond_val); + GELOGD("[%s] Taking subgraph [%s] by cond = [%d]", task_context.GetNodeName(), subgraph->GetName().c_str(), cond_val); + GE_CHK_STATUS_RET(ExecuteSubgraph(subgraph, task_context, done_callback), + "[%s] Failed to execute subgraph. cond = %d", task_context.GetNodeName(), cond_val); + + GELOGD("[%s] Done executing with cond = %d successfully.", task_context.GetNodeName(), cond_val); + return SUCCESS; +} + +Status CaseOpNodeTask::Init(const NodePtr &node, const HybridModel &model) { + size_t num_subgraphs = node->GetOpDesc()->GetSubgraphInstanceNames().size(); + GE_CHECK_LE(num_subgraphs, kMaxBranchNum); + GE_CHECK_GE(num_subgraphs, kMinBranchNum); + auto num_branches = static_cast(num_subgraphs); + GELOGD("[%s] Start to init CaseOpNodeTask with %u branches.", node->GetName().c_str(), num_branches); + + for (uint32_t i = 0; i < num_branches; ++i) { + auto sub_graph = NodeUtils::GetSubgraph(*node, i); + GE_CHECK_NOTNULL(sub_graph); + auto graph_item = model.GetSubgraphItem(sub_graph); + GE_CHECK_NOTNULL(graph_item); + GELOGD("[%s] Adding subgraph [%s] to branch %u.", node->GetName().c_str(), sub_graph->GetName().c_str(), i); + subgraphs_.emplace_back(graph_item); + } + + GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); + return SUCCESS; +} + +const GraphItem *CaseOpNodeTask::SelectBranch(int32_t branch_index) const { + // subgraphs_ is non-empty. checked int Init + if (branch_index < 0 || static_cast(branch_index) >= subgraphs_.size()) { + GELOGI("Branch index out of range. index = %d, num_subgraphs = %zu, will taking last branch.", branch_index, + subgraphs_.size()); + branch_index = subgraphs_.size() - 1; + } + + return subgraphs_[branch_index]; +} + +Status CaseOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { + auto branch_tensor = task_context.GetInput(kCaseBranchIndex); + GE_CHECK_NOTNULL(branch_tensor); + int32_t branch_index = 0; + GE_CHK_STATUS_RET(CopyTensorValueToHost(*branch_tensor, branch_index), "[%s] Failed to get branch index.", + task_context.GetNodeName()); + + const GraphItem *subgraph = SelectBranch(branch_index); + GELOGI("[%s] Taking subgraph [%s] by branch = [%d]", task_context.GetNodeName(), subgraph->GetName().c_str(), + branch_index); + + std::vector inputs; + std::vector outputs; + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + inputs.emplace_back(*input_tensor); + } + + GE_CHK_STATUS_RET(ExecuteSubgraph(subgraph, task_context, done_callback), "[%s] Failed to execute else-subgraph.", + task_context.GetNodeName()); + + GELOGD("[%s] Done executing subgraph[%d] successfully.", task_context.GetNodeName(), branch_index); + return SUCCESS; +} + +Status WhileOpNodeTask::Init(const NodePtr &node, const HybridModel &model) { + GELOGD("[%s] Start to init WhileOpNodeTask.", node->GetName().c_str()); + auto cond_subgraph = NodeUtils::GetSubgraph(*node, kCondBranchIndex); + GE_CHECK_NOTNULL(cond_subgraph); + GELOGD("[%s] Adding subgraph [%s] to cond-subgraph.", node->GetName().c_str(), cond_subgraph->GetName().c_str()); + cond_ = model.GetSubgraphItem(cond_subgraph); + GE_CHECK_NOTNULL(cond_); + + auto body_subgraph = NodeUtils::GetSubgraph(*node, kBodyBranchIndex); + GE_CHECK_NOTNULL(body_subgraph); + GELOGD("[%s] Adding subgraph [%s] to body-subgraph.", node->GetName().c_str(), body_subgraph->GetName().c_str()); + body_ = model.GetSubgraphItem(body_subgraph); + GE_CHECK_NOTNULL(body_); + + GELOGD("[%s] Done initialization successfully.", node->GetName().c_str()); + return SUCCESS; +} + +Status WhileOpNodeTask::DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const { + if (task_context.NumInputs() != task_context.NumOutputs()) { + GELOGE(INTERNAL_ERROR, "[%s] Invalid while args. num_inputs = %d, num_outputs = %d", task_context.GetNodeName(), + task_context.NumInputs(), task_context.NumOutputs()); + return INTERNAL_ERROR; + } + + bool is_continue = false; + GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), "[%s] Failed to execute iteration 0.", + task_context.GetNodeName()); + if (!is_continue) { + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + auto input_tensor_desc = task_context.GetInputDesc(i); + auto output_tensor_desc = task_context.MutableOutputDesc(i); + GE_CHECK_NOTNULL(input_tensor); + GE_CHECK_NOTNULL(input_tensor_desc); + GE_CHECK_NOTNULL(output_tensor_desc); + GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor)); + *output_tensor_desc = *input_tensor_desc; + } + + return SUCCESS; + } + + // backup original input tensor desc + std::vector ori_input_desc; + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto tensor_desc = task_context.GetInputDesc(i); + GE_CHECK_NOTNULL(tensor_desc); + ori_input_desc.emplace_back(*tensor_desc); + } + + int iteration = 1; + while (true) { + GELOGD("[%s] Start to execute, iteration = %d", task_context.GetNodeName(), iteration); + GE_CHK_STATUS_RET(ExecuteOneLoop(task_context, is_continue), "[%s] Failed to execute iteration %d.", + task_context.GetNodeName(), iteration); + + if (!is_continue) { + GELOGD("[%s] Quit from loop. current iteration = %d", task_context.GetNodeName(), iteration); + break; + } + + ++iteration; + } + + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + auto tensor_desc = task_context.MutableInputDesc(i); + GE_CHECK_NOTNULL(input_tensor); + GE_CHECK_NOTNULL(tensor_desc); + // restore original input tensor desc + *tensor_desc = std::move(ori_input_desc[i]); + GE_CHK_STATUS_RET_NOLOG(task_context.SetOutput(i, *input_tensor)); + } + + return SUCCESS; +} + +Status WhileOpNodeTask::ExecuteCond(TaskContext &task_context, bool &is_continue) const { + std::vector inputs; + std::vector input_desc; + std::vector output_desc; + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.GetInput(i); + GE_CHECK_NOTNULL(input_tensor); + inputs.emplace_back(*input_tensor); + input_desc.emplace_back(task_context.GetInputDesc(i)); + } + + auto execution_context = const_cast(task_context.GetExecutionContext()); + auto executor = MakeShared(cond_, execution_context, task_context.IsForceInferShape()); + GE_CHECK_NOTNULL(executor); + GELOGD("[%s] Start to execute cond-subgraph.", task_context.GetNodeName()); + GE_CHK_STATUS_RET(executor->ExecuteAsync(inputs, input_desc), "Failed to execute partitioned call."); + GELOGD("[%s] Done executing cond-subgraph successfully.", cond_->GetName().c_str()); + GE_CHK_STATUS_RET_NOLOG(task_context.RegisterCallback([executor]() mutable { executor.reset(); })); + + // get cond output + GE_CHK_STATUS_RET(executor->Synchronize(), "[%s] Failed to sync cond-subgraph result.", cond_->GetName().c_str()); + std::vector cond_outputs; + GE_CHK_STATUS_RET(executor->GetOutputs(cond_outputs), "[%s] Failed to get cond-output.", cond_->GetName().c_str()); + if (cond_outputs.empty()) { + GELOGE(INTERNAL_ERROR, "[%s] Cond output is empty.", task_context.GetNodeName()); + return INTERNAL_ERROR; + } + + int cond_val = 0; + GE_CHK_STATUS_RET(CopyTensorValueToHost(cond_outputs[0], cond_val), "[%s] Failed to get cond result.", + task_context.GetNodeName()); + is_continue = cond_val != 0; + return SUCCESS; +} + +Status WhileOpNodeTask::MoveOutputs2Inputs(TaskContext &task_context) { + // set outputs to inputs for next iteration + for (int i = 0; i < task_context.NumInputs(); ++i) { + auto input_tensor = task_context.MutableInput(i); + auto output_tensor = task_context.MutableOutput(i); + GE_CHECK_NOTNULL(input_tensor); + GE_CHECK_NOTNULL(output_tensor); + *input_tensor = *output_tensor; + output_tensor->Destroy(); + + auto output_tensor_desc = task_context.MutableOutputDesc(i); + GE_CHECK_NOTNULL(output_tensor_desc); + GELOGD("[%s] To update input shape[%d] by output shape. from [%s] to [%s]", task_context.GetNodeName(), i, + task_context.MutableInputDesc(i)->GetShape().ToString().c_str(), + output_tensor_desc->GetShape().ToString().c_str()); + *task_context.MutableInputDesc(i) = *output_tensor_desc; + } + + return SUCCESS; +} + +Status WhileOpNodeTask::ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const { + GE_CHK_STATUS_RET(ExecuteCond(task_context, is_continue), "[%s] Failed to execute cond-subgraph", + task_context.GetNodeName()); + if (!is_continue) { + return SUCCESS; + } + + GELOGD("[%s] Start to execute body-subgraph.", task_context.GetNodeName()); + GE_CHK_STATUS_RET(ExecuteSubgraph(body_, task_context, nullptr), "[%s] Failed to execute cond-subgraph", + task_context.GetNodeName()); + GELOGD("[%s] Done executing body-subgraph successfully.", task_context.GetNodeName()); + + // set outputs to inputs for next iteration + GE_CHK_STATUS_RET(MoveOutputs2Inputs(task_context), "[%s] Failed to move outputs to inputs", + task_context.GetNodeName()); + + return SUCCESS; +} + +Status ControlOpNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, + shared_ptr &task) const { + auto node_item = model.GetNodeItem(node); + GE_CHECK_NOTNULL(node_item); + + unique_ptr node_task; + auto node_type = node->GetType(); + if (node_type == IF) { + node_task.reset(new (std::nothrow) IfOpNodeTask()); + } else if (node_type == CASE) { + node_task.reset(new (std::nothrow) CaseOpNodeTask()); + } else if (node_type == WHILE) { + node_task.reset(new (std::nothrow) WhileOpNodeTask()); + } else { + GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str()); + return PARAM_INVALID; + } + + GE_CHECK_NOTNULL(node_task); + GE_CHK_STATUS_RET(node_task->Init(node, model), "[%s] Failed to init ControlOpNodeTask.", node->GetName().c_str()); + + task = std::move(node_task); + return SUCCESS; +} + +Status ControlOpNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { return SUCCESS; } +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/controlop/control_op_executor.h b/src/ge/hybrid/node_executor/controlop/control_op_executor.h new file mode 100644 index 00000000..0619c6a0 --- /dev/null +++ b/src/ge/hybrid/node_executor/controlop/control_op_executor.h @@ -0,0 +1,100 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_CONTROLOP_CONTROL_OP_EXECUTOR_H_ +#define GE_HYBRID_CONTROLOP_CONTROL_OP_EXECUTOR_H_ + +#include +#include "hybrid/node_executor/node_executor.h" +#include "hybrid/model/graph_item.h" + +namespace ge { +namespace hybrid { +class ControlOpNodeTask : public NodeTask { + public: + virtual Status Init(const NodePtr &node, const HybridModel &model) = 0; + Status UpdateArgs(TaskContext &context) override; + + Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; + + protected: + virtual Status DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const = 0; + static Status CopyTensorValueToHost(const TensorValue &tensor_value, int32_t &value); + static Status ExecuteSubgraph(const GraphItem *subgraph, TaskContext &task_context, + const std::function &done_callback); +}; + +class IfOpNodeTask : public ControlOpNodeTask { + public: + Status Init(const NodePtr &node, const HybridModel &model) override; + + protected: + const GraphItem *SelectBranch(int32_t cond) const; + Status DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const override; + + private: + static constexpr int kIfCondIndex = 0; + static constexpr int kThenBranchIndex = 0; + static constexpr int kElseBranchIndex = 1; + + const GraphItem *then_ = nullptr; + const GraphItem *else_ = nullptr; +}; + +class CaseOpNodeTask : public ControlOpNodeTask { + public: + Status Init(const NodePtr &node, const HybridModel &model) override; + + protected: + const GraphItem *SelectBranch(int32_t branch_index) const; + Status DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const override; + + private: + static constexpr int kCaseBranchIndex = 0; + static constexpr size_t kMaxBranchNum = INT32_MAX; + static constexpr size_t kMinBranchNum = 1; + + std::vector subgraphs_; +}; + +class WhileOpNodeTask : public ControlOpNodeTask { + public: + Status Init(const NodePtr &node, const HybridModel &model) override; + + protected: + Status DoExecuteAsync(TaskContext &task_context, const std::function &done_callback) const override; + Status ExecuteCond(TaskContext &task_context, bool &is_continue) const; + + static Status MoveOutputs2Inputs(TaskContext &task_context); + + Status ExecuteOneLoop(TaskContext &task_context, bool &is_continue) const; + + private: + static constexpr int kCondBranchIndex = 0; + static constexpr int kBodyBranchIndex = 1; + + const GraphItem *cond_ = nullptr; + const GraphItem *body_ = nullptr; +}; + +class ControlOpNodeExecutor : public NodeExecutor { + public: + Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const override; + Status PrepareTask(NodeTask &task, TaskContext &context) const override; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_CONTROLOP_CONTROL_OP_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc new file mode 100644 index 00000000..f4fb7530 --- /dev/null +++ b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.cc @@ -0,0 +1,207 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hccl/hccl_node_executor.h" +#include "graph/manager/util/hcom_util.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/fmk_error_codes.h" +#include "common/ge/ge_util.h" +#include "common/ge/plugin_manager.h" +#include "graph/attr_value.h" +#include "graph/debug/ge_attr_define.h" +#include "hccl/hcom.h" + +namespace ge { +namespace hybrid { + +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, HcclNodeExecutor); + +Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GELOGI("[%s] HcclNodeTask::ExecuteAsync in.", context.GetNodeName()); + if (context.handle_ == nullptr) { + GELOGE(FAILED, "hccl handle is nullptr! "); + return FAILED; + } + auto EnqueueHcomOpertion = (hcclResult_t(*)(HcomOpertion, std::function))dlsym( + context.handle_, "EnqueueHcomOpertion"); + if (EnqueueHcomOpertion == nullptr) { + GELOGE(FAILED, "Failed to invoke EnqueueHcomOpertion hcom unknown node function."); + if (dlclose(context.handle_) != 0) { + GELOGW("Failed to close handle %s", dlerror()); + } + return FAILED; + } + + vector inputs; + for (int i = 0; i < context.NumInputs(); ++i) { + TensorValue *tv = context.MutableInput(i); + GE_CHECK_NOTNULL(tv); + inputs.emplace_back(tv->MutableData()); + } + + vector outputs; + for (int i = 0; i < context.NumOutputs(); ++i) { + TensorValue *tv = context.MutableOutput(i); + GE_CHECK_NOTNULL(tv); + outputs.emplace_back(tv->MutableData()); + } + + const NodeItem &node_item = context.GetNodeItem(); + const OpDescPtr op_desc = MakeShared(*(node_item.op_desc)); + GE_CHECK_NOTNULL(op_desc); + + HcomOpertion op_info; + op_info.hcclType = op_desc->GetType(); + op_info.inputPtr = inputs.empty() ? nullptr : inputs[0]; + op_info.outputPtr = outputs.empty() ? nullptr : outputs[0]; + ge::DataType src_data_type = op_desc->GetInputDescPtr(0)->GetDataType(); + auto iter = kConstOpHcclDataType.find(static_cast(src_data_type)); + if (iter == kConstOpHcclDataType.end()) { + GELOGE(PARAM_INVALID, "kConstOpHcclDataType find failed."); + return PARAM_INVALID; + } + op_info.dataType = iter->second; + hcclRedOp_t op_type = HCCL_REP_OP_SUM; + if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HCOMREDUCESCATTER || + op_desc->GetType() == HVDCALLBACKALLREDUCE) { + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclOperationType(op_desc, op_type), "GetHcclOperationType failed"); + op_info.opType = op_type; + } + int64_t root_id = 0; + if (op_desc->GetType() == HCOMBROADCAST) { + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed"); + } + op_info.root = root_id; + auto callback = [this](hcclResult_t status) { + if (status != HCCL_SUCCESS) { + GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", status); + } + std::lock_guard lock(this->hccl_mutex_); + this->cond_.notify_all(); + GELOGI("hccl callback success."); + }; + int32_t count = 0; + GE_CHK_STATUS_RET(HcomOmeUtil::GetHcomCount(op_desc, static_cast(op_info.dataType), false, count), + "GetHcomCount failed"); + GELOGI("[%s] HcclNodeTask::ExecuteAsync hccl_type %s, count %d, data_type %d, op_type %d, root %d.", + context.GetNodeName(), op_info.hcclType.c_str(), count, op_info.dataType, op_info.opType, op_info.root); + op_info.count = count; + + hcclResult_t hccl_ret = EnqueueHcomOpertion(op_info, callback); + if (hccl_ret != HCCL_SUCCESS) { + GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", hccl_ret); + return HCCL_E_INTERNAL; + } + + // pending until hccl finished + std::unique_lock ulock(hccl_mutex_); + cond_.wait(ulock); + + context.RegisterCallback(done_callback); + GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName()); + return SUCCESS; +} + +Status HcclNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } + +Status HcclNodeTask::Init(TaskContext &context) { + GELOGI("[%s] HcclNodeExecutor::Init success.", context.GetNodeName()); + return SUCCESS; +} + +Status HcclNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { + GELOGI("[%s] HcclNodeExecutor::PrepareTask in.", context.GetNodeName()); + + GE_CHK_STATUS_RET(task.Init(context), "hccl node load hccl so failed."); + // allocate output mem + GE_CHK_STATUS_RET(context.AllocateOutputs(), "hccl node task allocate output failed."); + + GE_CHK_STATUS_RET(task.UpdateArgs(context), "hccl node task update args failed."); + GELOGI("[%s] HcclNodeExecutor::PrepareTask success.", context.GetNodeName()); + return SUCCESS; +} + +Status HcclNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const { + GELOGI("[%s] HcclNodeExecutor::LoadTask in.", node->GetName().c_str()); + GE_CHECK_NOTNULL(node); + + task = MakeShared(); + GE_CHECK_NOTNULL(task); + GELOGI("[%s] HcclNodeExecutor::LoadTask success.", node->GetName().c_str()); + return SUCCESS; +} + +Status HcclNodeExecutor::ExecuteTask(NodeTask &task, TaskContext &context, + const std::function &callback) const { + context.handle_ = handle_; + GE_CHK_STATUS_RET(task.ExecuteAsync(context, callback), "Failed to execute task. node = %s", + context.GetNodeItem().NodeName().c_str()); + return SUCCESS; +} + +Status HcclNodeExecutor::Initialize() { + std::string file_name = "libhcom_graph_adaptor.so"; + std::string path = PluginManager::GetPath(); + path.append(file_name); + string canonical_path = RealPath(path.c_str()); + if (canonical_path.empty()) { + GELOGW("failed to get realpath of %s", path.c_str()); + return FAILED; + } + + GELOGI("FileName:%s, Path:%s.", file_name.c_str(), canonical_path.c_str()); + handle_ = dlopen(canonical_path.c_str(), RTLD_NOW | RTLD_GLOBAL); + if (handle_ == nullptr) { + GELOGE(GE_PLGMGR_SO_NOT_EXIST, "Failed in dlopen %s! ", dlerror()); + return FAILED; + } + auto HcomExcutorInitialize = (hcclResult_t(*)())dlsym(handle_, "HcomExcutorInitialize"); + if (HcomExcutorInitialize == nullptr) { + GELOGE(FAILED, "Failed to invoke HcomExcutorInitialize hcom unknown node function."); + return FAILED; + } + hcclResult_t hccl_ret = HcomExcutorInitialize(); + if (hccl_ret == HCCL_E_PTR) { + GELOGI("Hccl comm is null, hcom executor initialize is not required."); + } else if (hccl_ret == HCCL_SUCCESS) { + GELOGI("Hcom executor initialize success."); + } else { + GELOGE(FAILED, "Call HcomExcutorInitialize failed, ret: 0x%X", hccl_ret); + return FAILED; + } + return SUCCESS; +} + +Status HcclNodeExecutor::Finalize() { + auto HcomExcutorFinalize = (hcclResult_t(*)())dlsym(handle_, "HcomExcutorFinalize"); + if (HcomExcutorFinalize == nullptr) { + GELOGE(FAILED, "Failed to invoke HcomExcutorFinalize hcom unknown node function."); + return FAILED; + } + hcclResult_t hccl_ret = HcomExcutorFinalize(); + if (hccl_ret != HCCL_SUCCESS) { + GELOGE(FAILED, "Call HcomExcutorFinalize failed, ret: 0x%X", hccl_ret); + return FAILED; + } + // dlclose file handle + if (dlclose(handle_) != 0) { + GELOGW("Failed to close handle %s", dlerror()); + } + GELOGI("Hcom executor finalize success."); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h new file mode 100644 index 00000000..8791c4e3 --- /dev/null +++ b/src/ge/hybrid/node_executor/hccl/hccl_node_executor.h @@ -0,0 +1,59 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef HYBRID_HCCL_NODE_EXECUTOR_H_ +#define HYBRID_HCCL_NODE_EXECUTOR_H_ +#include "hybrid/node_executor/node_executor.h" +#include "hybrid/model/hybrid_model.h" +#include "graph/op_desc.h" + +namespace ge { +namespace hybrid { +class HybridModel; + +class HcclNodeTask : public NodeTask { + public: + HcclNodeTask() {} + + ~HcclNodeTask() {} + + Status UpdateArgs(TaskContext &context) override; + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + Status Init(TaskContext &context) override; + + private: + std::shared_ptr davinci_model_ = nullptr; + bool load_flag_ = false; + std::mutex hccl_mutex_; + std::condition_variable cond_; +}; + +class HcclNodeExecutor : public NodeExecutor { + public: + Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const; + Status PrepareTask(NodeTask &task, TaskContext &context) const; + Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const; + Status Initialize() override; + Status Finalize() override; + ~HcclNodeExecutor() {} + + private: + void *handle_; +}; +} // namespace hybrid +} // namespace ge + +#endif // HYBRID_HCCL_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.cc b/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.cc new file mode 100644 index 00000000..4798b87e --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.cc @@ -0,0 +1,191 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h" +#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "graph/passes/folding_pass.h" +#include "hybrid/model/hybrid_model.h" +#include "inc/kernel_factory.h" +#include "ge_local_engine/engine/host_cpu_engine.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HOST_AICPU, HostAiCpuNodeExecutor); + +Status HostCpuNodeTaskBase::UpdateArgs(TaskContext &) { + // no need update args + return SUCCESS; +} + +Status HostCpuNodeTaskBase::ExecuteAsync(TaskContext &context, std::function done_callback) { + GELOGD("[%s] Start execute.", context.GetNodeName()); + + std::vector inputs; + std::vector outputs; + GE_CHK_STATUS_RET(ProcessInputs(context, inputs), "node:%s type:%s, process inputs failed.", node_->GetName().c_str(), + node_->GetType().c_str()); + GE_CHK_STATUS_RET(Execute(context, inputs, outputs), "node:%s type:%s, task execute failed.", + node_->GetName().c_str(), node_->GetType().c_str()); + GE_CHK_STATUS_RET(ProcessOutputs(context, outputs), "node:%s type:%s, process outputs failed.", + node_->GetName().c_str(), node_->GetType().c_str()); + + if (done_callback) { + GELOGD("[%s] Start invoke callback.", context.GetNodeName()); + done_callback(); + } + GELOGD("[%s] Done execute successfully.", context.GetNodeName()); + return SUCCESS; +} + +Status HostCpuNodeTaskBase::ProcessInputs(TaskContext &context, std::vector &inputs) { + int32_t input_num = context.NumInputs(); + for (auto i = 0; i < input_num; ++i) { + auto tensor_value = context.GetInput(i); + GE_CHECK_NOTNULL(tensor_value); + GeTensorPtr input_ptr = + MakeShared(node_->GetOpDesc()->GetInputDesc(i), + reinterpret_cast(tensor_value->GetData()), tensor_value->GetSize()); + if (input_ptr == nullptr) { + GELOGE(MEMALLOC_FAILED, "Make shared failed"); + return MEMALLOC_FAILED; + } + inputs.push_back(input_ptr); + } + return SUCCESS; +} + +Status HostCpuNodeTaskBase::ProcessOutputs(TaskContext &context, std::vector &outputs) { + int32_t output_num = context.NumOutputs(); + if (static_cast(output_num) != outputs.size()) { + GELOGE(INTERNAL_ERROR, "node %s type %s has %d output, but kernel compute only has %zu output.", + node_->GetName().c_str(), node_->GetType().c_str(), output_num, outputs.size()); + return INTERNAL_ERROR; + } + + // alloc output + GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); + + // copy data to output + for (auto i = 0; i < output_num; ++i) { + GeTensorPtr &tensor = outputs[i]; + GE_CHECK_NOTNULL(tensor); + auto tensor_data = tensor->GetData(); + auto tensor_value = context.MutableOutput(i); + GE_CHECK_NOTNULL(tensor_value); + if (tensor_data.GetSize() > tensor_value->GetSize()) { + GELOGE(INTERNAL_ERROR, "node:%s type:%s [%d]th compute data size=%zu, but context data size=%zu.", + node_->GetName().c_str(), node_->GetType().c_str(), i, tensor_data.GetSize(), tensor_value->GetSize()); + return INTERNAL_ERROR; + } + + GELOGI("node:%s type:%s [%d]th output data=%p, out size=%zu, data size=%zu.", node_->GetName().c_str(), + node_->GetType().c_str(), i, tensor_value->GetData(), tensor_value->GetSize(), tensor_data.GetSize()); + if (tensor_data.GetSize() > 0) { + GE_CHK_RT_RET(rtMemcpy(tensor_value->MutableData(), tensor_value->GetSize(), tensor_data.GetData(), + tensor_data.GetSize(), RT_MEMCPY_HOST_TO_HOST)); + } + GELOGI("node:%s type:%s [%d]th set data success, data size=%zu.", node_->GetName().c_str(), + node_->GetType().c_str(), i, tensor_data.GetSize()); + } + + return SUCCESS; +} + +Status CpuKernelNodeTask::Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) { + std::vector const_inputs; + for (const auto &input : inputs) { + const_inputs.emplace_back(input); + } + return FoldingPass::RunOpKernel(node_, const_inputs, outputs); +} + +Status HostKernelNodeTask::Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) { + auto kernel = KernelFactory::Instance().Create(node_->GetType()); + if (kernel == nullptr) { + GELOGE(UNSUPPORTED, "node %s type %s is not supported by host kernel.", node_->GetName().c_str(), + node_->GetType().c_str()); + return UNSUPPORTED; + } + + std::vector const_inputs; + for (const auto &input : inputs) { + const_inputs.emplace_back(input); + } + Status compute_ret = kernel->Compute(node_->GetOpDesc(), const_inputs, outputs); + if (compute_ret != SUCCESS) { + GELOGE(compute_ret, "node %s type %s compute failed or not imply.", node_->GetName().c_str(), + node_->GetType().c_str()); + return compute_ret; + } + + return SUCCESS; +} + +Status HostAiCpuNodeTask::ProcessInputs(TaskContext &context, std::vector &inputs) { return SUCCESS; } + +Status HostAiCpuNodeTask::ProcessOutputs(TaskContext &context, std::vector &outputs) { return SUCCESS; } + +Status HostAiCpuNodeTask::Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) { + RunContext run_context; + auto host_kernel = hybrid::host_aicpu::KernelFactory::Instance().CreateKernel(node_); + if (host_kernel == nullptr) { + GELOGE(UNSUPPORTED, "node %s type %s is not supported by host kernel.", node_->GetName().c_str(), + node_->GetType().c_str()); + return UNSUPPORTED; + } + + Status compute_ret = host_kernel->Compute(context); + if (compute_ret != SUCCESS) { + GELOGE(compute_ret, "node %s type %s compute failed or not imply.", node_->GetName().c_str(), + node_->GetType().c_str()); + return compute_ret; + } + + return SUCCESS; +} + +Status HostAiCpuNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { + return task.UpdateArgs(context); +} + +Status HostAiCpuNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, + std::shared_ptr &task) const { + GE_CHECK_NOTNULL(node); + const std::string &name = node->GetName(); + const std::string &type = node->GetType(); + if (HostCpuEngine::GetInstance().CheckSupported(type)) { + GELOGI("create CpuKernelNodeTask for node %s, type %s.", name.c_str(), type.c_str()); + task = MakeShared(node); + GE_CHECK_NOTNULL(task); + } else if (KernelFactory::Instance().Create(type) != nullptr) { + GELOGI("create HostKernelNodeTask for node %s, type %s.", name.c_str(), type.c_str()); + task = MakeShared(node); + GE_CHECK_NOTNULL(task); + } else if (hybrid::host_aicpu::KernelFactory::Instance().CreateKernel(node) != nullptr) { + GELOGI("create HostAiCpuNodeTask for node %s, type %s.", name.c_str(), type.c_str()); + task = MakeShared(node); + GE_CHECK_NOTNULL(task); + } else { + GELOGE(UNSUPPORTED, "node %s type %s is not support in HostAiCpuNodeExecutor now.", name.c_str(), type.c_str()); + return UNSUPPORTED; + } + return SUCCESS; +} +} // namespace hybrid +} // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h b/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h new file mode 100644 index 00000000..406d1597 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/host_aicpu_node_executor.h @@ -0,0 +1,83 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_KERNEL_HOST_AICPU_NODE_EXECUTOR_H_ +#define GE_HYBRID_KERNEL_HOST_AICPU_NODE_EXECUTOR_H_ + +#include "inc/kernel.h" +#include "hybrid/node_executor/node_executor.h" + +namespace ge { +namespace hybrid { +class HostCpuNodeTaskBase : public NodeTask { + public: + explicit HostCpuNodeTaskBase(const NodePtr &node) : node_(node) {} + ~HostCpuNodeTaskBase() = default; + virtual Status UpdateArgs(TaskContext &context); + virtual Status ExecuteAsync(TaskContext &context, std::function done_callback); + + protected: + NodePtr node_; + + private: + virtual Status Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) = 0; + virtual Status ProcessInputs(TaskContext &context, std::vector &inputs); + virtual Status ProcessOutputs(TaskContext &context, std::vector &outputs); +}; + +class CpuKernelNodeTask : public HostCpuNodeTaskBase { + public: + explicit CpuKernelNodeTask(const NodePtr &node) : HostCpuNodeTaskBase(node) {} + ~CpuKernelNodeTask() = default; + + private: + Status Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) override; +}; + +class HostKernelNodeTask : public HostCpuNodeTaskBase { + public: + explicit HostKernelNodeTask(const NodePtr &node) : HostCpuNodeTaskBase(node) {} + ~HostKernelNodeTask() = default; + + private: + Status Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) override; +}; + +class HostAiCpuNodeTask : public HostCpuNodeTaskBase { + public: + explicit HostAiCpuNodeTask(const NodePtr &node) : HostCpuNodeTaskBase(node) {} + ~HostAiCpuNodeTask() = default; + + private: + Status Execute(TaskContext &context, const std::vector &inputs, + std::vector &outputs) override; + Status ProcessInputs(TaskContext &context, std::vector &inputs) override; + Status ProcessOutputs(TaskContext &context, std::vector &outputs) override; +}; + +class HostAiCpuNodeExecutor : public NodeExecutor { + public: + Status PrepareTask(NodeTask &task, TaskContext &context) const override; + + virtual Status LoadTask(const HybridModel &model, const NodePtr &node, + std::shared_ptr &task) const override; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_KERNEL_HOST_AICPU_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.cc b/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.cc new file mode 100644 index 00000000..02ce40e2 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.cc @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hostaicpu/kernel/assign_kernel.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "hybrid/node_executor/hostaicpu/kernel_factory.h" + +namespace { +const size_t kAssignInputNum = 2; +const size_t kAssignRefInputIndex = 0; +const size_t kAssignValueInputIndex = 1; +const size_t kAssignRefOutputIndex = 0; +} // namespace + +namespace ge { +namespace hybrid { +namespace host_aicpu { +Status AssignKernel::Compute(TaskContext& context) { + GELOGI("AssignKernel [%s, %s] compute begin.", node_->GetName().c_str(), node_->GetType().c_str()); + + auto ref_tensor = context.MutableInput(kAssignRefInputIndex); + GE_CHECK_NOTNULL(ref_tensor); + const auto value_tensor = context.GetInput(kAssignValueInputIndex); + GE_CHECK_NOTNULL(value_tensor); + if (value_tensor->GetSize() > ref_tensor->GetSize()) { + GELOGE(INTERNAL_ERROR, "[%s] value_input_size=%zu, but ref_input_size=%zu.", node_->GetName().c_str(), + value_tensor->GetSize(), ref_tensor->GetSize()); + return INTERNAL_ERROR; + } + + GELOGI("[%s] value_input_data=%p, ref_input_size=%zu, value_input_size=%zu.", node_->GetName().c_str(), + ref_tensor->GetSize(), ref_tensor->GetData(), value_tensor->GetSize()); + if (value_tensor->GetSize() > 0) { + GE_CHK_RT_RET(rtMemcpy(ref_tensor->MutableData(), ref_tensor->GetSize(), value_tensor->GetData(), + value_tensor->GetSize(), RT_MEMCPY_HOST_TO_HOST)); + } + GE_CHK_STATUS_RET(context.SetOutput(kAssignRefOutputIndex, *ref_tensor), "[%s] Failed to set output.", + context.GetNodeName()); + + GELOGI("AssignKernel [%s, %s] compute success.", node_->GetName().c_str(), node_->GetType().c_str()); + return SUCCESS; +} + +REGISTER_KERNEL_CREATOR(Assign, AssignKernel); +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.h b/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.h new file mode 100644 index 00000000..6af30926 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/assign_kernel.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_HOST_AICPU_KERNEL_ASSIGN_KERNEL_H_ +#define GE_HYBRID_HOST_AICPU_KERNEL_ASSIGN_KERNEL_H_ + +#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +class AssignKernel : public Kernel { + public: + AssignKernel(const NodePtr &node) : Kernel(node) {} + ~AssignKernel() override = default; + AssignKernel &operator=(const AssignKernel &op) = delete; + AssignKernel(const AssignKernel &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(TaskContext &context) override; +}; +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_HOST_AICPU_KERNEL_ASSIGN_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/kernel.h b/src/ge/hybrid/node_executor/hostaicpu/kernel/kernel.h new file mode 100644 index 00000000..0e22f62a --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/kernel.h @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_HOST_AICPU_KERNEL_KERNEL_H_ +#define GE_HYBRID_HOST_AICPU_KERNEL_KERNEL_H_ + +#include "common/ge_inner_error_codes.h" +#include "graph/node.h" +#include "hybrid/node_executor/task_context.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +/** + * The base class for all host_kernel. + */ +class Kernel { + public: + Kernel(const NodePtr &node) : node_(node) {} + virtual ~Kernel() = default; + virtual Status Compute(TaskContext &context) = 0; + + protected: + const NodePtr &node_; +}; +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_HOST_AICPU_KERNEL_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.cc b/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.cc new file mode 100644 index 00000000..433f8d2f --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.cc @@ -0,0 +1,33 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "hybrid/node_executor/hostaicpu/kernel_factory.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +Status NoOpKernel::Compute(TaskContext& context) { + GELOGI("NoOpKernel [%s, %s] no need to compute.", node_->GetName().c_str(), node_->GetType().c_str()); + return SUCCESS; +} + +REGISTER_KERNEL_CREATOR(NoOp, NoOpKernel); +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h b/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h new file mode 100644 index 00000000..3c05c754 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/no_op_kernel.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_HOST_AICPU_KERNEL_NO_OP_KERNEL_H_ +#define GE_HYBRID_HOST_AICPU_KERNEL_NO_OP_KERNEL_H_ + +#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +class NoOpKernel : public Kernel { + public: + NoOpKernel(const NodePtr &node) : Kernel(node) {} + ~NoOpKernel() override = default; + NoOpKernel &operator=(const NoOpKernel &op) = delete; + NoOpKernel(const NoOpKernel &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(TaskContext &context) override; +}; +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_HOST_AICPU_KERNEL_NO_OP_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.cc b/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.cc new file mode 100644 index 00000000..dfd8f1fe --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.cc @@ -0,0 +1,145 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h" +#include +#include "common/fp16_t.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/utils/type_utils.h" +#include "hybrid/node_executor/hostaicpu/kernel_factory.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +Status RandomUniformKernel::Compute(TaskContext& context) { + GELOGI("RandomUniformKernel [%s, %s] compute begin.", node_->GetName().c_str(), node_->GetType().c_str()); + int64_t seed = 0; + int64_t seed2 = 0; + (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed", seed); + (void)AttrUtils::GetInt(node_->GetOpDesc(), "seed2", seed2); + DataType data_type = DT_UNDEFINED; + if (AttrUtils::GetDataType(node_->GetOpDesc(), VAR_ATTR_DTYPE, data_type) != GRAPH_SUCCESS) { + GELOGE(PARAM_INVALID, "get attr VAR_ATTR_DTYPE failed"); + return PARAM_INVALID; + } + + switch (data_type) { + case DT_FLOAT16: + if (GenerateFP16(node_->GetOpDesc(), seed, seed2, context) != SUCCESS) { + GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_FLOAT"); + return FAILED; + } + break; + case DT_FLOAT: + if (Generate(node_->GetOpDesc(), seed, seed2, context) != SUCCESS) { + GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_FLOAT"); + return FAILED; + } + break; + case DT_DOUBLE: + if (Generate(node_->GetOpDesc(), seed, seed2, context) != SUCCESS) { + GELOGE(FAILED, "Generate random_distribution for RandomUniformOp failed, data_type=DT_DOUBLE"); + return FAILED; + } + break; + default: + GELOGE(UNSUPPORTED, "Supported DataType for RandomUniformOp is DT_FLOAT16 / DT_FLOAT / DT_DOUBLE, but dtype=%s", + TypeUtils::DataTypeToSerialString(data_type).c_str()); + return UNSUPPORTED; + } + + GELOGI("RandomUniformKernel [%s, %s] compute success.", node_->GetName().c_str(), node_->GetType().c_str()); + return SUCCESS; +} + +template +Status RandomUniformKernel::Generate(const ge::OpDescPtr& op_desc_ptr, int64_t seed, int64_t seed2, + TaskContext& context) { + GE_CHECK_NOTNULL(op_desc_ptr); + // RandomUniformOp has and only has one output + int64_t data_num = op_desc_ptr->GetOutputDesc(0).GetShape().GetShapeSize(); + std::unique_ptr buf(new (std::nothrow) T[data_num]()); + if (buf == nullptr) { + GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", static_cast(sizeof(T) * data_num)); + return MEMALLOC_FAILED; + } + + int64_t final_seed; + if (seed == 0) { + if (seed2 == 0) { + std::random_device rd; + final_seed = rd(); + } else { + final_seed = seed2; + } + } else { + final_seed = seed; + } + std::mt19937_64 gen(final_seed); + std::uniform_real_distribution distribution(0, 1); + for (int64_t i = 0; i < data_num; i++) { + *(buf.get() + i) = distribution(gen); + } + + std::shared_ptr output = MakeShared(buf.get(), data_num * sizeof(T)); + GE_CHECK_NOTNULL(output); + GE_CHK_STATUS_RET(context.SetOutput(0, *output), "[%s] Failed to set output.", context.GetNodeName()); + + return SUCCESS; +} + +Status RandomUniformKernel::GenerateFP16(const ge::OpDescPtr& op_desc_ptr, int64_t seed, int64_t seed2, + TaskContext& context) { + GE_CHECK_NOTNULL(op_desc_ptr); + // RandomUniformOp has and only has one output + int64_t data_num = op_desc_ptr->GetOutputDesc(0).GetShape().GetShapeSize(); + std::unique_ptr buf(new (std::nothrow) fp16_t[data_num]()); + if (buf == nullptr) { + GELOGE(MEMALLOC_FAILED, "New sizeof(fp16_t) * data_num(%zu) memory failed", + static_cast(sizeof(fp16_t) * data_num)); + return MEMALLOC_FAILED; + } + + int64_t final_seed; + if (seed == 0) { + if (seed2 == 0) { + std::random_device rd; + final_seed = rd(); + } else { + final_seed = seed2; + } + } else { + final_seed = seed; + } + std::mt19937_64 gen(final_seed); + std::uniform_real_distribution distribution(0, 1); + for (int64_t i = 0; i < data_num; i++) { + *(buf.get() + i) = static_cast(distribution(gen)); + } + + std::shared_ptr output = MakeShared(buf.get(), data_num * sizeof(fp16_t)); + GE_CHECK_NOTNULL(output); + GE_CHK_STATUS_RET(context.SetOutput(0, *output), "[%s] Failed to set output.", context.GetNodeName()); + + return SUCCESS; +} + +REGISTER_KERNEL_CREATOR(RandomUniform, RandomUniformKernel); +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h b/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h new file mode 100644 index 00000000..343c6d08 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/random_uniform_kernel.h @@ -0,0 +1,48 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_HOST_AICPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ +#define GE_HYBRID_HOST_AICPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ + +#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +class RandomUniformKernel : public Kernel { + public: + RandomUniformKernel(const NodePtr &node) : Kernel(node) {} + ~RandomUniformKernel() override = default; + RandomUniformKernel &operator=(const RandomUniformKernel &op) = delete; + RandomUniformKernel(const RandomUniformKernel &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(TaskContext &context) override; + + private: + template + Status Generate(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, TaskContext &context); + + static Status GenerateFP16(const ge::OpDescPtr &op_desc_ptr, int64_t seed, int64_t seed2, TaskContext &context); +}; +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_HOST_AICPU_KERNEL_RANDOM_UNIFORM_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.cc b/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.cc new file mode 100644 index 00000000..a8259500 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.cc @@ -0,0 +1,43 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hostaicpu/kernel/variable_kernel.h" +#include "framework/common/debug/ge_log.h" +#include "framework/common/util.h" +#include "hybrid/node_executor/hostaicpu/kernel_factory.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +Status VariableKernel::Compute(TaskContext& context) { + GELOGI("VariableKernel [%s, %s] compute begin.", node_->GetName().c_str(), node_->GetType().c_str()); + + auto tensor = context.GetVariable(node_->GetName()); + if (tensor == nullptr) { + GELOGE(PARAM_INVALID, "tensor is NULL."); + return PARAM_INVALID; + } + // Constant & Variable Op has and only has one output + GE_CHK_STATUS_RET(context.SetOutput(0, *tensor), "[%s] Failed to set output.", context.GetNodeName()); + GELOGI("VariableKernel [%s, %s] compute success.", node_->GetName().c_str(), node_->GetType().c_str()); + return SUCCESS; +} + +REGISTER_KERNEL_CREATOR(Variable, VariableKernel); +REGISTER_KERNEL_CREATOR(Constant, VariableKernel); +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.h b/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.h new file mode 100644 index 00000000..cb0a6834 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel/variable_kernel.h @@ -0,0 +1,42 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_HOST_AICPU_KERNEL_VARIABLE_KERNEL_H_ +#define GE_HYBRID_HOST_AICPU_KERNEL_VARIABLE_KERNEL_H_ + +#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +class VariableKernel : public Kernel { + public: + VariableKernel(const NodePtr &node) : Kernel(node) {} + ~VariableKernel() override = default; + VariableKernel &operator=(const VariableKernel &op) = delete; + VariableKernel(const VariableKernel &op) = delete; + + /** + * @brief compute for node_task. + * @return result + */ + Status Compute(TaskContext &context) override; +}; +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_HOST_AICPU_KERNEL_VARIABLE_KERNEL_H_ diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.cc b/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.cc new file mode 100644 index 00000000..ca398500 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.cc @@ -0,0 +1,55 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "hybrid/node_executor/hostaicpu/kernel_factory.h" +#include "framework/common/debug/ge_log.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +KernelFactory &KernelFactory::Instance() { + static KernelFactory instance; + return instance; +} + +std::shared_ptr KernelFactory::CreateKernel(const NodePtr &node) { + if (node == nullptr) { + GELOGW("node is NULL."); + return nullptr; + } + auto iter = kernel_creator_map_.find(node->GetType()); + if (iter != kernel_creator_map_.end()) { + return iter->second(node); + } + GELOGE(FAILED, "Not supported, type = %s, name = %s", node->GetType().c_str(), node->GetName().c_str()); + return nullptr; +} + +void KernelFactory::RegisterCreator(const std::string &type, const KERNEL_CREATOR_FUNC &func) { + if (func == nullptr) { + GELOGW("Func is NULL."); + return; + } + auto iter = kernel_creator_map_.find(type); + if (iter != kernel_creator_map_.end()) { + GELOGW("%s creator already exist", type.c_str()); + return; + } + kernel_creator_map_[type] = func; +} +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.h b/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.h new file mode 100644 index 00000000..9ead2005 --- /dev/null +++ b/src/ge/hybrid/node_executor/hostaicpu/kernel_factory.h @@ -0,0 +1,86 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_NODE_EXECUTOR_HOST_AICPU_KERNEL_FACTORY_H_ +#define GE_HYBRID_NODE_EXECUTOR_HOST_AICPU_KERNEL_FACTORY_H_ + +#include +#include +#include +#include "common/ge/ge_util.h" +#include "hybrid/node_executor/hostaicpu/kernel/kernel.h" + +namespace ge { +namespace hybrid { +namespace host_aicpu { +using KERNEL_CREATOR_FUNC = std::function(const NodePtr &)>; + +/** + * manage all the host_aicpu_kernel, support create kernel. + */ +class KernelFactory { + public: + static KernelFactory &Instance(); + + /** + * @brief create Kernel. + * @param [in] node + * @return not nullptr success + * @return nullptr fail + */ + std::shared_ptr CreateKernel(const NodePtr &node); + + /** + * @brief Register Kernel create function. + * @param [in] type: Kernel type + * @param [in] func: Kernel create func + */ + void RegisterCreator(const std::string &type, const KERNEL_CREATOR_FUNC &func); + + KernelFactory(const KernelFactory &) = delete; + KernelFactory &operator=(const KernelFactory &) = delete; + KernelFactory(KernelFactory &&) = delete; + KernelFactory &operator=(KernelFactory &&) = delete; + + private: + KernelFactory() = default; + ~KernelFactory() = default; + + // the kernel creator function map + std::map kernel_creator_map_; +}; + +class KernelRegistrar { + public: + KernelRegistrar(const std::string &type, const KERNEL_CREATOR_FUNC &func) { + KernelFactory::Instance().RegisterCreator(type, func); + } + ~KernelRegistrar() = default; + + KernelRegistrar(const KernelRegistrar &) = delete; + KernelRegistrar &operator=(const KernelRegistrar &) = delete; + KernelRegistrar(KernelRegistrar &&) = delete; + KernelRegistrar &operator=(KernelRegistrar &&) = delete; +}; + +#define REGISTER_KERNEL_CREATOR(type, clazz) \ + std::shared_ptr Creator_##type##Kernel(const NodePtr &node) { return MakeShared(node); } \ + KernelRegistrar g_##type##Kernel_creator(#type, Creator_##type##Kernel) +} // namespace host_aicpu +} // namespace hybrid +} // namespace ge + +#endif // GE_HYBRID_NODE_EXECUTOR_HOST_AICPU_KERNEL_FACTORY_H_ diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc index c3bc9a41..7cd10a83 100644 --- a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc +++ b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.cc @@ -17,14 +17,12 @@ #include "hybrid/node_executor/hostcpu/ge_local_node_executor.h" #include "graph/debug/ge_attr_define.h" #include "framework/common/util.h" -#include "framework/common/types.h" +#include "hybrid/model/hybrid_model.h" #include "inc/kernel.h" #include "inc/kernel_factory.h" -#include "common/ge/ge_util.h" namespace ge { namespace hybrid { - REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::GE_LOCAL, GeLocalNodeExecutor); const std::unordered_map> RefInputTask::out_ref_input_index_ = { @@ -64,7 +62,7 @@ Status RefInputTask::RefOneByOne(TaskContext &context) { for (uint32_t out_index = 0; out_index < output_num; ++out_index) { auto input = context.GetInput(out_index); GE_CHECK_NOTNULL(input); - context.SetOutput(out_index, *input); + GE_CHK_STATUS_RET(context.SetOutput(out_index, *input)); GELOGD("node %s type %s output[%u] ref input[%u] addr=%p.", node_name_.c_str(), node_type_.c_str(), out_index, out_index, input->GetData()); } @@ -84,7 +82,7 @@ Status RefInputTask::RefByOrder(const std::vector &ref_order, TaskCont auto ref_input_index = ref_order[out_index]; auto input = context.GetInput(ref_input_index); GE_CHECK_NOTNULL(input); - context.SetOutput(out_index, *input); + GE_CHK_STATUS_RET(context.SetOutput(out_index, *input)); GELOGD("node %s type %s output[%d] ref input[%u] addr=%p.", node_name_.c_str(), node_type_.c_str(), out_index, ref_input_index, input->GetData()); } @@ -132,7 +130,7 @@ Status DependInputShapeTask::Execute(TaskContext &context) { } // alloc output - GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); + GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs(NpuMemoryAllocator::AttrWithDefaultPadding())); // copy data to output for (auto i = 0; i < output_num; ++i) { @@ -194,6 +192,16 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &no node_type.c_str()); return MEMALLOC_FAILED; } + } else if (node_type == CONSTANTOP || node_type == VARIABLE) { + GELOGI("node %s type %s, use ConstantNodeTask.", node->GetName().c_str(), node_type.c_str()); + auto tensor = model.GetVariable(node->GetName()); + if (tensor == nullptr) { + GELOGE(INTERNAL_ERROR, "Failed to get tensor by name: %s", node->GetName().c_str()); + return INTERNAL_ERROR; + } + + task = MakeShared(tensor); + GE_CHECK_NOTNULL(task); } else { GELOGE(UNSUPPORTED, "node %s type %s is not support in GeLocalNodeExecutor now.", node->GetName().c_str(), node_type.c_str()); @@ -202,5 +210,20 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &no return SUCCESS; } +ConstantNodeTask::ConstantNodeTask(const TensorValue *tensor) : tensor_(tensor) {} + +Status ConstantNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } + +Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GELOGD("[%s] Start execute.", context.GetNodeName()); + GE_CHK_STATUS_RET(context.SetOutput(0, *tensor_), "[%s] Failed to set output.", context.GetNodeName()); + if (done_callback) { + GELOGD("[%s] Start invoke callback.", context.GetNodeName()); + done_callback(); + } + + GELOGD("[%s] Done execute successfully.", context.GetNodeName()); + return SUCCESS; +} } // namespace hybrid } // namespace ge \ No newline at end of file diff --git a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h index beb1f50d..0195e76c 100644 --- a/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h +++ b/src/ge/hybrid/node_executor/hostcpu/ge_local_node_executor.h @@ -23,7 +23,6 @@ namespace ge { namespace hybrid { - class RefInputTask : public NodeTask { public: explicit RefInputTask(const NodePtr &node) : node_name_(node->GetName()), node_type_(node->GetType()) {} @@ -68,6 +67,18 @@ class DependInputShapeTask : public NodeTask { static const std::unordered_set depend_input_shape_ops_; }; +class ConstantNodeTask : public NodeTask { + public: + explicit ConstantNodeTask(const TensorValue *tensor); + ~ConstantNodeTask() = default; + Status UpdateArgs(TaskContext &context) override; + + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + + private: + const TensorValue *tensor_; +}; + class GeLocalNodeExecutor : public NodeExecutor { public: Status PrepareTask(NodeTask &task, TaskContext &context) const override; diff --git a/src/ge/hybrid/node_executor/node_executor.cc b/src/ge/hybrid/node_executor/node_executor.cc index f3b86948..0f4c5494 100644 --- a/src/ge/hybrid/node_executor/node_executor.cc +++ b/src/ge/hybrid/node_executor/node_executor.cc @@ -16,6 +16,7 @@ #include "hybrid/node_executor/node_executor.h" #include "framework/common/debug/log.h" +#include "graph/utils/node_utils.h" #include "init/gelib.h" #include "hybrid/model/hybrid_model.h" @@ -25,9 +26,11 @@ namespace { const char *const kEngineNameAiCore = "AIcoreEngine"; const char *const kEngineNameGeLocal = "DNN_VM_GE_LOCAL_OP_STORE"; const char *const kEngineNameAiCpu = "aicpu_kernel"; +const char *const kEngineNameHccl = "ops_kernel_info_hccl"; } // namespace Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs()); + GE_CHK_STATUS_RET_NOLOG(task.UpdateTilingData(context)); // update op_desc before alloc ws GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces()); GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context)); return SUCCESS; @@ -48,6 +51,7 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, } Status NodeExecutorManager::EnsureInitialized() { + GE_CHK_STATUS_RET(InitializeExecutors()); std::lock_guard lk(mu_); if (initialized_) { return SUCCESS; @@ -56,6 +60,7 @@ Status NodeExecutorManager::EnsureInitialized() { engine_mapping_.emplace(kEngineNameAiCore, NodeExecutorManager::ExecutorType::AICORE); engine_mapping_.emplace(kEngineNameGeLocal, NodeExecutorManager::ExecutorType::GE_LOCAL); engine_mapping_.emplace(kEngineNameAiCpu, NodeExecutorManager::ExecutorType::AICPU_TF); + engine_mapping_.emplace(kEngineNameHccl, NodeExecutorManager::ExecutorType::HCCL); std::shared_ptr instance_ptr = GELib::GetInstance(); if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { @@ -66,23 +71,7 @@ Status NodeExecutorManager::EnsureInitialized() { OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); for (auto &it : ops_kernel_manager.GetAllOpsKernelInfoStores()) { GELOGD("add kernel store: %s", it.first.c_str()); - kernel_stores_.emplace(it.first, it.second); - } - - GELOGI("Start to Initialize NodeExecutors"); - for (auto &it : builders_) { - auto engine_type = it.first; - auto build_fn = it.second; - GE_CHECK_NOTNULL(build_fn); - auto executor = std::unique_ptr(build_fn()); - if (executor == nullptr) { - GELOGE(INTERNAL_ERROR, "Failed to create executor for engine type = %d", engine_type); - return INTERNAL_ERROR; - } - - GELOGD("Executor of engine type = %d was created successfully", engine_type); - GE_CHK_STATUS_RET(executor->Initialize(), "Failed to initialize NodeExecutor of type = %d", engine_type); - executors_.emplace(engine_type, std::move(executor)); + kernel_stores_.emplace(it.first, it.second.get()); } initialized_ = true; @@ -93,6 +82,11 @@ Status NodeExecutorManager::EnsureInitialized() { NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node &node) const { auto op_type = node.GetType(); if (op_type == PARTITIONEDCALL) { + bool is_dynamic = false; + (void)NodeUtils::GetNodeUnknownShapeStatus(node, is_dynamic); + if (is_dynamic) { + return ExecutorType::DYNAMIC_SUBGRAPH; + } return ExecutorType::COMPILED_SUBGRAPH; } @@ -101,6 +95,10 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node return ExecutorType::GE_LOCAL; } + if (op_type == IF || op_type == CASE || op_type == WHILE) { + return ExecutorType::CONTROL_OP; + } + auto op_desc = node.GetOpDesc(); // checked before const auto &lib_name = op_desc->GetOpKernelLibName(); auto it = engine_mapping_.find(lib_name); @@ -116,10 +114,11 @@ Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executo auto executor_type = ResolveExecutorType(node); const auto it = executors_.find(executor_type); if (it == executors_.end()) { - GELOGE(INTERNAL_ERROR, "Failed to get executor by type: %d", executor_type); + GELOGE(INTERNAL_ERROR, "Failed to get executor by type: %d.", executor_type); return INTERNAL_ERROR; } + GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), executor_type); *executor = it->second.get(); return SUCCESS; } @@ -132,6 +131,11 @@ void NodeExecutorManager::RegisterExecutorBuilder(NodeExecutorManager::ExecutorT Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { auto op_desc = node.GetOpDesc(); GE_CHECK_NOTNULL(op_desc); + if (op_desc->GetType() == PARTITIONEDCALL) { + GELOGD("[%s] Skipping CalcOpRunningParam for PartitionedCall.", node.GetName().c_str()); + return SUCCESS; + } + auto it = kernel_stores_.find(op_desc->GetOpKernelLibName()); if (it == kernel_stores_.end()) { GELOGE(INTERNAL_ERROR, "Failed to get OpKernelStore. libName = %s, node = %s", @@ -139,9 +143,91 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { return INTERNAL_ERROR; } + // calc hccl output size independent, hccl ops kernel manager should GetSize for + // input which is the output size of input-op, but sometimes return error + // when multi-thread + if (op_desc->GetOpKernelLibName() == kEngineNameHccl) { + for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { + GeTensorDesc output_tensor = op_desc->GetOutputDesc(static_cast(i)); + Format format = output_tensor.GetFormat(); + DataType data_type = output_tensor.GetDataType(); + GeShape output_shape = output_tensor.GetShape(); + int64_t output_mem_size = 0; + GE_CHK_STATUS_RET(TensorUtils::CalcTensorMemSize(output_shape, format, data_type, output_mem_size), + "hccl calc tensor mem size failed."); + output_mem_size = + ((output_mem_size + MEMORY_ALIGN_RATIO * MEMORY_ALIGN_SIZE - 1) / MEMORY_ALIGN_SIZE) * MEMORY_ALIGN_SIZE; + TensorUtils::SetSize(output_tensor, output_mem_size); + GE_CHK_STATUS_RET(op_desc->UpdateOutputDesc(static_cast(i), output_tensor), + "hccl update output size failed."); + GELOGD("%s output desc[%u], dim_size: %zu, mem_size: %ld.", node.GetName().c_str(), i, + output_tensor.GetShape().GetDimNum(), output_mem_size); + } + return SUCCESS; + } return it->second->CalcOpRunningParam(node); } +Status NodeExecutorManager::InitializeExecutors() { + std::lock_guard lk(mu_); + if (executor_initialized_) { + ++ref_count_; + GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_); + return SUCCESS; + } + + GELOGI("Start to Initialize NodeExecutors"); + for (auto &it : builders_) { + auto engine_type = it.first; + auto build_fn = it.second; + GE_CHECK_NOTNULL(build_fn); + auto executor = std::unique_ptr(build_fn()); + if (executor == nullptr) { + GELOGE(INTERNAL_ERROR, "Failed to create executor for engine type = %d", engine_type); + return INTERNAL_ERROR; + } + + GELOGD("Executor of engine type = %d was created successfully", engine_type); + auto ret = executor->Initialize(); + if (ret != SUCCESS) { + GELOGE(ret, "Failed to initialize NodeExecutor of type = %d, clear executors", engine_type); + for (auto &executor_it : executors_) { + executor_it.second->Finalize(); + } + executors_.clear(); + return ret; + } + + executors_.emplace(engine_type, std::move(executor)); + } + + ++ref_count_; + executor_initialized_ = true; + GELOGI("Initializing NodeExecutors successfully."); + return SUCCESS; +} + +void NodeExecutorManager::FinalizeExecutors() { + std::lock_guard lk(mu_); + if (!executor_initialized_) { + GELOGD("No need for finalizing for not initialized."); + return; + } + + if (--ref_count_ > 0) { + GELOGD("Ref count = %d, do not finalize executors.", ref_count_); + return; + } + + GELOGD("Start to invoke Finalize on executors."); + for (auto &it : executors_) { + it.second->Finalize(); + } + executors_.clear(); + executor_initialized_ = false; + GELOGD("Done invoking Finalize successfully."); +} + NodeExecutorRegistrar::NodeExecutorRegistrar(NodeExecutorManager::ExecutorType executor_type, NodeExecutor *(*builder)()) { NodeExecutorManager::GetInstance().RegisterExecutorBuilder(executor_type, builder); diff --git a/src/ge/hybrid/node_executor/node_executor.h b/src/ge/hybrid/node_executor/node_executor.h index 613c0bb1..23e52428 100644 --- a/src/ge/hybrid/node_executor/node_executor.h +++ b/src/ge/hybrid/node_executor/node_executor.h @@ -14,70 +14,182 @@ * limitations under the License. */ -#ifndef GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ -#define GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ +#ifndef GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_ +#define GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_ #include "external/ge/ge_api_error_codes.h" #include "common/opskernel/ops_kernel_info_store.h" #include "graph/node.h" -#include "proto/task.pb.h" #include "task_context.h" namespace ge { +const uint32_t MEMORY_ALIGN_RATIO = 2; +const uint32_t MEMORY_ALIGN_SIZE = 32; namespace hybrid { class HybridModel; - +// Base class of Node Task class NodeTask { public: NodeTask() = default; virtual ~NodeTask() = default; + + /** + * Update tiling data + * @param context instance of TaskContext + * @return SUCCESS on success, error code otherwise + */ + virtual Status UpdateTilingData(TaskContext &context) { return SUCCESS; } + + /** + * Init + * @param context instance of TaskContext + * @return SUCCESS on success, error code otherwise + */ + virtual Status Init(TaskContext &context) { return SUCCESS; } + + /** + * Whether this task supports dynamic shape + * @return true if this task supports dynamic shape, false otherwise + */ + virtual bool IsSupportDynamicShape() { return true; } + + /** + * Update args for execution + * @param context instance of TaskContext + * @return SUCCESS on success, error code otherwise + */ virtual Status UpdateArgs(TaskContext &context) = 0; + + /** + * Execute task async + * @param context instance of TaskContext + * @param done_callback callback function, will be invoked after task is done + * @return SUCCESS on success, error code otherwise + */ virtual Status ExecuteAsync(TaskContext &context, std::function done_callback) = 0; - virtual Status Init(TaskContext &context) { return SUCCESS; } }; +// Node executor class NodeExecutor { public: NodeExecutor() = default; virtual ~NodeExecutor() = default; + /** + * Initialize node executor + * @return SUCCESS on success, error code otherwise + */ virtual Status Initialize() { return SUCCESS; } + /** + * Finalize node executor + * @return SUCCESS on success, error code otherwise + */ virtual Status Finalize() { return SUCCESS; } + /** + * Load task in load stage + * @param model instance of HybridModel + * @param node node + * @param task generated node task + * @return SUCCESS on success, error code otherwise + */ virtual Status LoadTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const; + /** + * Compile task in run stage + * @param model instance of HybridModel + * @param node node + * @param task generated node task + * @return SUCCESS on success, error code otherwise + */ virtual Status CompileTask(const HybridModel &model, const NodePtr &node, std::shared_ptr &task) const; + /** + * Preparation actions before execution + * @param task instance of NodeTask + * @param context instance of TaskContext + * @return SUCCESS on success, error code otherwise + */ virtual Status PrepareTask(NodeTask &task, TaskContext &context) const; + + /** + * Execute task + * @param task instance of NodeTask + * @param context instance of TaskContext + * @param callback callback function which will be invoked after computation is done + * @return SUCCESS on success, error code otherwise + */ virtual Status ExecuteTask(NodeTask &task, TaskContext &context, const std::function &callback) const; }; class NodeExecutorManager { public: - enum class ExecutorType { AICORE, GE_LOCAL, AICPU_TF, AICPU_CUSTOM, COMPILED_SUBGRAPH, HCCL, RESERVED }; + enum class ExecutorType { + AICORE, + AICPU_TF, + AICPU_CUSTOM, + COMPILED_SUBGRAPH, + DYNAMIC_SUBGRAPH, + GE_LOCAL, + CONTROL_OP, + HCCL, + RESERVED + }; static NodeExecutorManager &GetInstance() { static NodeExecutorManager instance; return instance; } - Status CalcOpRunningParam(Node &node) const; - + /** + * Register build of executor + * @param executor_type type of executor + * @param builder build function + */ void RegisterExecutorBuilder(ExecutorType executor_type, const std::function &builder); + /** + * Initialize executor if needed + * @return SUCCESS on success, error code otherwise + */ Status EnsureInitialized(); + Status InitializeExecutors(); + + void FinalizeExecutors(); + + /** + * CalcOpRunningParam + * @param node node + * @return SUCCESS on success, error code otherwise + */ + Status CalcOpRunningParam(Node &node) const; + + /** + * Get executor by node + * @param node node + * @param executor executor + * @return SUCCESS on success, error code otherwise + */ Status GetExecutor(Node &node, const NodeExecutor **executor) const; + /** + * Resolve executor type by node + * @param node node + * @return executor type + */ ExecutorType ResolveExecutorType(Node &node) const; + private: std::map> executors_; std::map> builders_; - std::map> kernel_stores_; + std::map kernel_stores_; std::map engine_mapping_; std::mutex mu_; bool initialized_ = false; + bool executor_initialized_ = false; + int ref_count_ = 0; }; class NodeExecutorRegistrar { @@ -99,4 +211,4 @@ class NodeExecutorRegistrar { ::ge::hybrid::NodeExecutorRegistrar( \ engine_type, []() -> ::ge::hybrid::NodeExecutor * { return new (std::nothrow) executor(); }) -#endif // GE_HYBRID_KERNEL_NODE_EXECUTOR_H_ +#endif // GE_HYBRID_NODE_EXECUTOR_NODE_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc new file mode 100644 index 00000000..cda9a275 --- /dev/null +++ b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.cc @@ -0,0 +1,81 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "partitioned_call_node_executor.h" +#include "graph/utils/node_utils.h" + +namespace ge { +namespace hybrid { +REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::DYNAMIC_SUBGRAPH, PartitionedCallNodeExecutor); + +PartitionedCallNodeTask::PartitionedCallNodeTask(const GraphItem *graph_item) : graph_item_(graph_item) {} + +PartitionedCallNodeTask::~PartitionedCallNodeTask() { + GELOGD("[%s] PartitionedCallNodeTask destroyed.", graph_item_->GetName().c_str()); +} + +Status PartitionedCallNodeTask::Init(TaskContext &context) { + auto execution_context = const_cast(context.GetExecutionContext()); + subgraph_executor_.reset(new (std::nothrow) SubgraphExecutor(graph_item_, execution_context)); + GE_CHECK_NOTNULL(subgraph_executor_); + return SUCCESS; +} + +Status PartitionedCallNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GE_CHK_STATUS_RET(subgraph_executor_->ExecuteAsync(context), "[%s] Failed to set inputs", + graph_item_->GetName().c_str()); + + auto callback = [=]() { Callback(done_callback); }; + + GE_CHK_STATUS_RET(context.RegisterCallback(callback), "[%s] Failed to register callback", + graph_item_->GetName().c_str()); + GELOGD("[%s] Done executing subgraph successfully.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status PartitionedCallNodeTask::Callback(const std::function &done_callback) { + GELOGD("[%s] On subgraph callback", graph_item_->GetName().c_str()); + if (done_callback != nullptr) { + done_callback(); + } + + GELOGD("[%s] To release sub graph tensors.", graph_item_->GetName().c_str()); + subgraph_executor_.reset(); + GELOGD("[%s] Done releasing sub graph tensors.", graph_item_->GetName().c_str()); + return SUCCESS; +} + +Status PartitionedCallNodeTask::UpdateArgs(TaskContext &context) { return SUCCESS; } + +Status PartitionedCallNodeExecutor::LoadTask(const ge::hybrid::HybridModel &model, const ge::NodePtr &node, + std::shared_ptr &task) const { + GELOGD("Load dynamic partitioned call: [%s]", node->GetName().c_str()); + auto subgraph = NodeUtils::GetSubgraph(*node, 0); + GE_CHECK_NOTNULL(subgraph); + auto partitioned_call = model.GetSubgraphItem(subgraph); + GE_CHECK_NOTNULL(partitioned_call); + task.reset(new (std::nothrow) PartitionedCallNodeTask(partitioned_call)); + GE_CHECK_NOTNULL(task); + GELOGD("Done loading dynamic partitioned call: [%s]", node->GetName().c_str()); + return SUCCESS; +} + +Status PartitionedCallNodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const { + GE_CHK_STATUS_RET(task.Init(context), "[%s] Failed to init task.", context.GetNodeName()); + return SUCCESS; +} +} // namespace hybrid +} // namespace ge diff --git a/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h new file mode 100644 index 00000000..fd87d6c1 --- /dev/null +++ b/src/ge/hybrid/node_executor/partitioned_call/partitioned_call_node_executor.h @@ -0,0 +1,54 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_HYBRID_NODE_EXECUTOR_SUBGRAPH_SUBGRAPH_EXECUTOR_H_ +#define GE_HYBRID_NODE_EXECUTOR_SUBGRAPH_SUBGRAPH_EXECUTOR_H_ + +#include "hybrid/node_executor/node_executor.h" +#include "hybrid/model/hybrid_model.h" +#include "hybrid/executor/node_state.h" +#include "hybrid/executor/subgraph_executor.h" +#include "common/thread_pool.h" + +namespace ge { +namespace hybrid { +class PartitionedCallNodeTask : public NodeTask { + public: + explicit PartitionedCallNodeTask(const GraphItem *graph_item); + ~PartitionedCallNodeTask() override; + + Status Init(TaskContext &context) override; + + Status UpdateArgs(TaskContext &context) override; + + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; + + private: + Status Callback(const std::function &done_callback); + + const GraphItem *graph_item_; + std::unique_ptr subgraph_executor_; + GraphExecutionContext *context_ = nullptr; +}; + +class PartitionedCallNodeExecutor : public NodeExecutor { + public: + Status LoadTask(const HybridModel &model, const NodePtr &node, shared_ptr &task) const override; + Status PrepareTask(NodeTask &task, TaskContext &context) const override; +}; +} // namespace hybrid +} // namespace ge +#endif // GE_HYBRID_NODE_EXECUTOR_SUBGRAPH_SUBGRAPH_EXECUTOR_H_ diff --git a/src/ge/hybrid/node_executor/task_context.cc b/src/ge/hybrid/node_executor/task_context.cc index 42c653be..ee35bffa 100644 --- a/src/ge/hybrid/node_executor/task_context.cc +++ b/src/ge/hybrid/node_executor/task_context.cc @@ -19,12 +19,16 @@ #include "framework/common/debug/log.h" #include "graph/utils/tensor_utils.h" #include "hybrid/executor/hybrid_execution_context.h" +#include "hybrid/executor/subgraph_executor.h" namespace ge { namespace hybrid { -TaskContext::TaskContext(GraphExecutionContext *execution_context) : execution_context_(execution_context) {} +TaskContext::TaskContext(GraphExecutionContext *execution_context, const NodeItem *node_item, + SubgraphContext *subgraph_context) + : node_item_(node_item), execution_context_(execution_context), subgraph_context_(subgraph_context) {} + TaskContext::~TaskContext() { - GELOGD("To execute ~TaskContext(). node = %s", node_item_->NodeName().c_str()); + GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); for (auto ws_addr : workspaces_) { execution_context_->allocator->Deallocate(ws_addr); } @@ -38,19 +42,28 @@ TaskContext::~TaskContext() { } } -std::unique_ptr TaskContext::Create(const NodeItem &node_item, GraphExecutionContext *graph_context) { - GELOGI("To create task context for node %s, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d", +std::unique_ptr TaskContext::Create(const NodeItem &node_item, GraphExecutionContext *execution_context, + SubgraphContext *subgraph_context) { + GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", node_item.NodeName().c_str(), node_item.input_start, node_item.num_inputs, node_item.output_start, node_item.num_outputs); - auto task_context = std::unique_ptr(new (std::nothrow) TaskContext(graph_context)); + if (node_item.input_start < 0 || node_item.output_start < 0) { + GELOGE(INTERNAL_ERROR, "NodeItem not property initialized. input_start = %d, output_start = %d", + node_item.input_start, node_item.output_start); + return nullptr; + } + + auto task_context = + std::unique_ptr(new (std::nothrow) TaskContext(execution_context, &node_item, subgraph_context)); if (task_context == nullptr) { - GELOGE(MEMALLOC_FAILED, "Failed to create instance of TaskContext. node = %s", node_item.NodeName().c_str()); + GELOGE(MEMALLOC_FAILED, "[%s] Failed to create instance of TaskContext.", node_item.NodeName().c_str()); return nullptr; } task_context->node_item_ = &node_item; - task_context->inputs_start_ = graph_context->all_inputs.data() + node_item.input_start; - task_context->outputs_start_ = graph_context->all_outputs.data() + node_item.output_start; + task_context->inputs_start_ = subgraph_context->all_inputs_.data() + node_item.input_start; + task_context->outputs_start_ = subgraph_context->all_outputs_.data() + node_item.output_start; + task_context->iteration_ = execution_context->iteration; return task_context; } @@ -59,7 +72,7 @@ int TaskContext::NumInputs() const { return node_item_->num_inputs; } int TaskContext::NumOutputs() const { return node_item_->num_outputs; } TensorValue *TaskContext::MutableInput(int index) { - if (index < 0 || index > node_item_->num_inputs) { + if (index < 0 || index >= node_item_->num_inputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_inputs = %d", index, node_item_->num_inputs); return nullptr; } @@ -68,7 +81,7 @@ TensorValue *TaskContext::MutableInput(int index) { } const TensorValue *TaskContext::GetOutput(int index) const { - if (index < 0 || index > node_item_->num_outputs) { + if (index < 0 || index >= node_item_->num_outputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_outputs = %d", index, node_item_->num_outputs); return nullptr; } @@ -77,7 +90,7 @@ const TensorValue *TaskContext::GetOutput(int index) const { } TensorValue *TaskContext::MutableOutput(int index) { - if (index < 0 || index > node_item_->num_outputs) { + if (index < 0 || index >= node_item_->num_outputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_outputs = %d", index, node_item_->num_outputs); return nullptr; } @@ -97,7 +110,7 @@ void *TaskContext::MutableWorkspace(int index) { } const TensorValue *TaskContext::GetInput(int index) const { - if (index < 0 || index > node_item_->num_inputs) { + if (index < 0 || index >= node_item_->num_inputs) { GELOGE(PARAM_INVALID, "Index out of range. index = %d, num_inputs = %d", index, node_item_->num_inputs); return nullptr; } @@ -120,7 +133,14 @@ Status TaskContext::AllocateWorkspaces() { } Status TaskContext::RegisterCallback(const std::function &callback_fun) const { - return execution_context_->callback_manager->RegisterCallback(callback_fun); + auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); + if (ret != SUCCESS) { + GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); + execution_context_->callback_manager->Destroy(); + return ret; + } + + return SUCCESS; } string TaskContext::TensorDesc2String(const GeTensorDesc &desc) { @@ -137,7 +157,7 @@ string TaskContext::TensorDesc2String(const GeTensorDesc &desc) { return ss.str(); } -Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor) { +Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr) { int64_t size = 0; if (ge::TensorUtils::GetSize(tensor_desc, size) != GRAPH_SUCCESS) { GELOGE(INTERNAL_ERROR, "Failed to get tensor size"); @@ -148,13 +168,14 @@ Status TaskContext::AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue GELOGW("size from tensor_desc == 0"); } - auto buffer = TensorBuffer::Create(execution_context_->allocator, size); + auto buffer = TensorBuffer::Create(execution_context_->allocator, size, attr); GE_CHECK_NOTNULL(buffer); tensor = TensorValue(shared_ptr(buffer.release())); return SUCCESS; } -Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor) { +Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor, + AllocationAttr *attr) { GELOGI("To allocate output for node: %s. index = %d, tensor desc = %s", node_item_->NodeName().c_str(), index, TensorDesc2String(tensor_desc).c_str()); @@ -178,9 +199,29 @@ Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, T GE_CHECK_NOTNULL(ref_tensor); outputs_start_[index] = *ref_tensor; } else { - GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index])); - GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", node_item_->NodeName().c_str(), index, - outputs_start_[index].GetSize()); + auto reuse_input = node_item_->reuse_inputs.find(index); + if (reuse_input != node_item_->reuse_inputs.end()) { + GELOGD("[%s] Output[%d] is referenced to input[%d]", GetNodeName(), index, reuse_input->second); + outputs_start_[index] = inputs_start_[reuse_input->second]; + } else { + GE_CHK_STATUS_RET_NOLOG(AllocateTensor(tensor_desc, outputs_start_[index], attr)); + GELOGD("Allocating output successfully. node: %s. index = %d, size = %zu", node_item_->NodeName().c_str(), index, + outputs_start_[index].GetSize()); + } + } + + // Temp modification + if (node_item_->node_type == "UnsortedSegmentSum" || node_item_->node_type == "UnsortedSegmentSumD" || + node_item_->node_type == "ScatterNd") { + auto &out_tensor = outputs_start_[index]; + GELOGD("[%s] clear output tensor: %s", GetNodeName(), out_tensor.DebugString().c_str()); + auto *ctx = GetExecutionContext(); + string name = "rtMemsetAsync" + node_item_->node_name; + RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] Start"); }); + RECORD_EXECUTION_EVENT(GetExecutionContext(), node_item_->node_name.c_str(), "[rtMemsetAsync] Start"); + GE_CHK_RT_RET(rtMemsetAsync(out_tensor.MutableData(), out_tensor.GetSize(), 0, out_tensor.GetSize(), GetStream())); + RECORD_EXECUTION_EVENT(GetExecutionContext(), node_item_->node_name.c_str(), "[rtMemsetAsync] End"); + RegisterCallback([ctx, name]() { RECORD_CALLBACK_EVENT(ctx, name.c_str(), "[Compute] End"); }); } if (execution_context_->trace_enabled) { @@ -194,11 +235,11 @@ Status TaskContext::AllocateOutput(int index, const GeTensorDesc &tensor_desc, T return SUCCESS; } -Status TaskContext::AllocateOutputs() { +Status TaskContext::AllocateOutputs(AllocationAttr *attr) { for (int i = 0; i < node_item_->num_outputs; ++i) { const auto &output_desc = node_item_->op_desc->MutableOutputDesc(i); GE_CHECK_NOTNULL(output_desc); - GE_CHK_STATUS_RET_NOLOG(AllocateOutput(i, *output_desc, nullptr)); + GE_CHK_STATUS_RET_NOLOG(AllocateOutput(i, *output_desc, nullptr, attr)); } return SUCCESS; @@ -230,7 +271,7 @@ Status TaskContext::SetOutput(int index, const TensorValue &tensor) { rtStream_t TaskContext::GetStream() { return execution_context_->stream; } -int64_t TaskContext::GetSessionId() { return execution_context_->session_id; } +int64_t TaskContext::GetSessionId() const { return execution_context_->session_id; } Status TaskContext::GetStatus() const { return status_; } @@ -238,7 +279,13 @@ void TaskContext::SetStatus(Status status) { status_ = status; } Status TaskContext::AllocateWorkspace(size_t size, void **buffer, void *ori_addr) { GE_CHECK_NOTNULL(buffer); - *buffer = execution_context_->allocator->Allocate(size, ori_addr); + if (ori_addr == nullptr) { + *buffer = execution_context_->allocator->Allocate(size, nullptr); + } else { + AllocationAttr attr(ori_addr); + *buffer = execution_context_->allocator->Allocate(size, &attr); + } + if (*buffer == nullptr) { GELOGE(MEMALLOC_FAILED, "Failed to allocate workspace of size = %zu", size); return MEMALLOC_FAILED; @@ -261,16 +308,21 @@ Status TaskContext::PropagateOutputs() { for (auto &dst_input_index_and_node : output_nodes) { auto dst_input_idx = dst_input_index_and_node.first; auto dst_node_item = dst_input_index_and_node.second; + auto input_offset = dst_node_item->input_start + dst_input_idx; GELOGI( "Propagate output of node %s, output index = %d, dst node = %s, " - "dst_input_index = %d, dst_input_offset = %d, addr = %p", - node_item_->NodeName().c_str(), i, dst_node_item->NodeName().c_str(), dst_input_idx, - dst_node_item->input_start + dst_input_idx, - execution_context_->all_inputs.data() + dst_node_item->input_start + dst_input_idx); - execution_context_->all_inputs[dst_node_item->input_start + dst_input_idx] = *tensor; + "dst_input_index = %d, dst_input_offset = %d.", + node_item_->NodeName().c_str(), i, dst_node_item->NodeName().c_str(), dst_input_idx, input_offset); + + if (subgraph_context_->all_inputs_.size() <= static_cast(input_offset)) { + GELOGE(INTERNAL_ERROR, "[%s] input index out of range. index = %d, total input num = %zu", GetNodeName(), + input_offset, subgraph_context_->all_inputs_.size()); + return INTERNAL_ERROR; + } + + subgraph_context_->all_inputs_[input_offset] = *tensor; if (execution_context_->trace_enabled) { - execution_context_->all_inputs[dst_node_item->input_start + dst_input_idx].SetName(node_item_->NodeName() + - "_in_" + std::to_string(i)); + subgraph_context_->all_inputs_[input_offset].SetName(node_item_->NodeName() + "_in_" + std::to_string(i)); } } } @@ -289,5 +341,37 @@ void TaskContext::ReleaseInput(int index) { GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index); } } + +ConstGeTensorDescPtr TaskContext::GetOutputDesc(int index) { + return node_item_->op_desc->MutableOutputDesc(static_cast(index)); +} + +ConstGeTensorDescPtr TaskContext::GetInputDesc(int index) { + return node_item_->op_desc->MutableInputDesc(static_cast(index)); +} + +GeTensorDescPtr TaskContext::MutableInputDesc(int index) { + return node_item_->op_desc->MutableInputDesc(static_cast(index)); +} + +GeTensorDescPtr TaskContext::MutableOutputDesc(int index) { + return node_item_->op_desc->MutableOutputDesc(static_cast(index)); +} + +bool TaskContext::IsForceInferShape() const { return force_infer_shape_; } + +void TaskContext::SetForceInferShape(bool force_infer_shape) { force_infer_shape_ = force_infer_shape; } + +void TaskContext::NodeDone() { subgraph_context_->NodeDone(node_item_->node); } + +void TaskContext::OnError(Status error) { subgraph_context_->OnError(error); } + +bool TaskContext::IsTraceEnabled() const { return execution_context_->trace_enabled; } + +TensorValue *TaskContext::GetVariable(const std::string &name) { return execution_context_->model->GetVariable(name); } + +uint64_t TaskContext::GetIterationNumber() const { return iteration_; } + +bool TaskContext::IsDumpEnabled() const { return execution_context_->dump_enabled; } } // namespace hybrid } // namespace ge diff --git a/src/ge/hybrid/node_executor/task_context.h b/src/ge/hybrid/node_executor/task_context.h index 841dcb17..5c42a347 100644 --- a/src/ge/hybrid/node_executor/task_context.h +++ b/src/ge/hybrid/node_executor/task_context.h @@ -22,16 +22,19 @@ #include #include "external/ge/ge_api_error_codes.h" #include "hybrid/common/tensor_value.h" +#include "hybrid/common/npu_memory_allocator.h" #include "hybrid/executor/rt_callback_manager.h" #include "hybrid/model/node_item.h" namespace ge { namespace hybrid { class GraphExecutionContext; +class SubgraphContext; class TaskContext { public: - static std::unique_ptr Create(const NodeItem &node_item, GraphExecutionContext *graph_context); + static std::unique_ptr Create(const NodeItem &node_item, GraphExecutionContext *execution_context, + SubgraphContext *subgraph_context); ~TaskContext(); @@ -41,19 +44,33 @@ class TaskContext { const NodeItem &GetNodeItem() const; const char *GetNodeName() const; TensorValue *MutableInput(int index); + ConstGeTensorDescPtr GetInputDesc(int index); + ConstGeTensorDescPtr GetOutputDesc(int index); + GeTensorDescPtr MutableInputDesc(int index); + GeTensorDescPtr MutableOutputDesc(int index); void ReleaseInput(int index); const TensorValue *GetInput(int index) const; const TensorValue *GetOutput(int index) const; TensorValue *MutableOutput(int index); + TensorValue *GetVariable(const std::string &name); rtStream_t GetStream(); - int64_t GetSessionId(); + int64_t GetSessionId() const; + uint64_t GetIterationNumber() const; + + void NodeDone(); + void OnError(Status error); Status SetOutput(int index, const TensorValue &tensor); - Status AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor); - Status AllocateOutputs(); + Status AllocateOutput(int index, const GeTensorDesc &tensor_desc, TensorValue **tensor, + AllocationAttr *attr = nullptr); + Status AllocateOutputs(AllocationAttr *attr = nullptr); Status AllocateWorkspaces(); Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr); + bool IsTraceEnabled() const; + + bool IsDumpEnabled() const; + const GraphExecutionContext *GetExecutionContext() { return execution_context_; } Status AllocateTemp(size_t size, TensorValue &tensor); @@ -68,17 +85,25 @@ class TaskContext { void SetStatus(Status status); + bool IsForceInferShape() const; + void SetForceInferShape(bool force_infer_shape); + void *handle_ = nullptr; + private: - explicit TaskContext(GraphExecutionContext *execution_context); - TensorValue *inputs_start_ = nullptr; - TensorValue *outputs_start_ = nullptr; + TaskContext(GraphExecutionContext *execution_context, const NodeItem *node_item, SubgraphContext *subgraph_context); + static string TensorDesc2String(const GeTensorDesc &desc); - Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor); + Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr); - GraphExecutionContext *execution_context_; const NodeItem *node_item_ = nullptr; + bool force_infer_shape_ = false; + GraphExecutionContext *execution_context_; + SubgraphContext *subgraph_context_; + TensorValue *inputs_start_ = nullptr; + TensorValue *outputs_start_ = nullptr; Status status_ = SUCCESS; std::vector workspaces_; + uint64_t iteration_ = 0; }; } // namespace hybrid } // namespace ge diff --git a/src/ge/inc/kernel_factory.h b/src/ge/inc/kernel_factory.h index c0624e14..61455836 100644 --- a/src/ge/inc/kernel_factory.h +++ b/src/ge/inc/kernel_factory.h @@ -103,5 +103,5 @@ class KernelFactory { return ptr; \ } \ KernelFactory::Registerar g_##type##_Kernel_Creator(type, Creator_##type##_Kernel) -}; // end namespace ge +} // namespace ge #endif // GE_INC_KERNEL_FACTORY_H_ diff --git a/src/ge/init/gelib.cc b/src/ge/init/gelib.cc index 5fcb0cd7..f7740a3c 100644 --- a/src/ge/init/gelib.cc +++ b/src/ge/init/gelib.cc @@ -37,6 +37,7 @@ #include "graph/load/new_model_manager/model_manager.h" #include "graph/manager/graph_mem_allocator.h" #include "graph/manager/graph_var_manager.h" +#include "graph/common/ge_call_wrapper.h" #include "omm/csa_interact.h" #include "runtime/kernel.h" @@ -46,6 +47,9 @@ namespace ge { namespace { const int kDecimal = 10; const int kSocVersionLen = 50; +const uint32_t kAicoreOverflow = (0x1 << 0); +const uint32_t kAtomicOverflow = (0x1 << 1); +const uint32_t kAllOverflow = (kAicoreOverflow | kAtomicOverflow); } // namespace static std::shared_ptr instancePtr_ = nullptr; @@ -75,7 +79,7 @@ Status GELib::Initialize(const map &options) { instancePtr_ = nullptr; return ret; } - GE_TIMESTAMP_END(Init, "GELib::Initialize"); + GE_TIMESTAMP_EVENT_END(Init, "GELib::Initialize"); return SUCCESS; } @@ -126,16 +130,6 @@ Status GELib::InnerInitialize(const map &options) { return initSmStatus; } - GELOGI("memoryMallocSize initial."); - GE_TIMESTAMP_START(SetMemoryMallocSize); - Status initMemStatus = VarManager::Instance(0)->SetMemoryMallocSize(options); - GE_TIMESTAMP_END(SetMemoryMallocSize, "InnerInitialize::SetMemoryMallocSize"); - if (initMemStatus != SUCCESS) { - GELOGE(initMemStatus, "failed to set malloc size"); - RollbackInit(); - return initMemStatus; - } - GELOGI("Start to initialize HostCpuEngine"); GE_TIMESTAMP_START(HostCpuEngineInitialize); Status initHostCpuEngineStatus = HostCpuEngine::GetInstance().Initialize(); @@ -160,37 +154,6 @@ Status GELib::SystemInitialize(const map &options) { } } - iter = options.find(HEAD_STREAM); - head_stream_ = (iter != options.end()) ? std::strtol(iter->second.c_str(), nullptr, kDecimal) : false; - - iter = options.find(OPTION_EXEC_ENABLE_DUMP); - if (iter != options.end()) { - int32_t enable_dump_flag = 1; - auto path_iter = options.find(OPTION_EXEC_DUMP_PATH); - if (iter->second == std::to_string(enable_dump_flag) && path_iter != options.end()) { - std::string dump_path = path_iter->second; - if (!dump_path.empty() && dump_path[dump_path.size() - 1] != '/') { - dump_path = dump_path + "/" + CurrentTimeInStr() + "/"; - } - - PropertiesManager::Instance().AddDumpPropertyValue(DUMP_ALL_MODEL, {}); - GELOGD("Get dump path %s successfully", dump_path.c_str()); - PropertiesManager::Instance().SetDumpOutputPath(dump_path); - } - auto step_iter = options.find(OPTION_EXEC_DUMP_STEP); - if (step_iter != options.end()) { - std::string dump_step = step_iter->second; - GELOGD("Get dump step %s successfully", dump_step.c_str()); - PropertiesManager::Instance().SetDumpStep(dump_step); - } - auto mode_iter = options.find(OPTION_EXEC_DUMP_MODE); - if (mode_iter != options.end()) { - std::string dump_mode = mode_iter->second; - GELOGD("Get dump mode %s successfully", dump_mode.c_str()); - PropertiesManager::Instance().SetDumpMode(dump_mode); - } - } - // In train and infer, profiling is always needed. InitOptions(options); InitProfiling(this->options_); diff --git a/src/ge/init/gelib.h b/src/ge/init/gelib.h index 0dfec391..b5621dfd 100644 --- a/src/ge/init/gelib.h +++ b/src/ge/init/gelib.h @@ -62,9 +62,6 @@ class GELib { // get TrainMode flag bool isTrainMode() { return is_train_mode_; } - // add head stream to model - bool HeadStream() const { return head_stream_; } - // get incre build flag bool IsIncreBuild() const { return is_incre_build_; } @@ -86,6 +83,8 @@ class GELib { Status SetRTSocVersion(const map &options, map &new_options); void RollbackInit(); void InitOptions(const map &options); + void SetDumpModelOptions(const map &options); + void SetOpDebugOptions(const map &options); DNNEngineManager engineManager_; OpsKernelManager opsManager_; @@ -98,7 +97,6 @@ class GELib { bool is_shutdown = false; bool is_use_hcom = false; bool is_incre_build_ = false; - bool head_stream_ = false; std::string incre_build_cache_path_; }; } // namespace ge diff --git a/src/ge/ir_build/atc_ir_common.cc b/src/ge/ir_build/atc_ir_common.cc index 12c85bc0..91fa17d4 100644 --- a/src/ge/ir_build/atc_ir_common.cc +++ b/src/ge/ir_build/atc_ir_common.cc @@ -16,6 +16,7 @@ #include "atc_ir_common.h" #include "common/util/error_manager/error_manager.h" +#include "common/model_parser/graph_parser_util.h" #include "external/ge/ge_api_types.h" #include "framework/common/string_util.h" #include "framework/common/types.h" @@ -29,11 +30,23 @@ namespace ge { namespace { const int64_t kDynamicInputDim = -1; const int64_t kDynamicImageSizeNum = 2; +const size_t kMaxDynamicDimNum = 100; +const size_t kMaxNDDimNum = 4; +const size_t kMinNDDimNum = 1; // datatype/formats from user to GE, Unified to util interface file later const std::map kOutputTypeSupportDatatype = { {"FP32", ge::DT_FLOAT}, {"FP16", ge::DT_FLOAT16}, {"UINT8", ge::DT_UINT8}}; -const std::set kBufferOptimizeSupportOption = {"l2_optimize", "off_optimize"}; -const std::string IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const std::set kBufferOptimizeSupportOption = {"l1_optimize", "l2_optimize", "off_optimize", + "l1_and_l2_optimize"}; +// The function is incomplete. Currently, only l2_optimize, off_optimize is supported. +const char *const kBufferOptimizeSupport = "only support l2_optimize, off_optimize"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT = "high_performance"; +const char *const IR_OPTION_OP_SELECT_IMPLMODE_PRECISON = "high_precision"; +const char *const kCompressWeightError = "it must be appointed when appoint parameter[--optypelist_for_implmode]"; +const char *const kSelectImplmodeError = "only support high_performance, high_precision"; +const char *const kDynamicBatchSizeError = "It can only contains digit, \",\", \" \""; + } // namespace bool CheckDynamicBatchSizeInputShapeValid(unordered_map> shape_map, @@ -42,7 +55,7 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> for (auto iter = shape_map.begin(); iter != shape_map.end(); ++iter) { vector shape = iter->second; if (shape.size() < 1) { - ErrorManager::GetInstance().ATCReportErrMessage("E10017"); + ErrorManager::GetInstance().ATCReportErrMessage("E10012"); GELOGE(ge::PARAM_INVALID, "--input_shape's shape size can not be less than 1 when set --dynamic_batch_size."); return false; } @@ -61,16 +74,17 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> } if (size == 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10043"); + ErrorManager::GetInstance().ATCReportErrMessage("E10031"); GELOGE(ge::PARAM_INVALID, "At least one batch n must be equal to -1 when set --dynamic_batch_size."); return false; } for (char c : dynamic_batch_size) { if (!isdigit(c) && (c != ',') && (c != ' ')) { - ErrorManager::GetInstance().ATCReportErrMessage("E10047", {"value"}, {dynamic_batch_size}); - GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid.", - dynamic_batch_size.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"value", "reason"}, + {dynamic_batch_size, kDynamicBatchSizeError}); + GELOGE(ge::PARAM_INVALID, "Input parameter[--dynamic_batch_size]'s value[%s] is invalid. reason: %s", + dynamic_batch_size.c_str(), kDynamicBatchSizeError); return false; } } @@ -90,7 +104,7 @@ bool CheckDynamicImagesizeInputShapeValid(unordered_map> if (std::count(shape.begin(), shape.end(), kDynamicInputDim) > 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10019"); GELOGE(ge::PARAM_INVALID, - "--input_shape's shape is invalid, only height or width can be -1 when set --dynamic_image_size."); + "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size."); return false; } continue; @@ -116,21 +130,18 @@ bool CheckDynamicImagesizeInputShapeValid(unordered_map> } else { ErrorManager::GetInstance().ATCReportErrMessage("E10019"); GELOGE(ge::PARAM_INVALID, - "--input_shape's shape is invalid, only height or width can be -1 when set --dynamic_image_size."); + "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size."); return false; } } if (size == 0) { ErrorManager::GetInstance().ATCReportErrMessage("E10019"); GELOGE(ge::PARAM_INVALID, - "--input_shape's shape is invalid, only height or width can be -1 when set --dynamic_image_size."); + "--input_shape's shape is invalid, only height and width can be -1 when set --dynamic_image_size."); return false; } - if (dynamic_image_size.back() == ';') { - dynamic_image_size.erase(dynamic_image_size.end() - 1); - } - + EraseEndSemicolon(dynamic_image_size); // Different parameter sets are split string by ';' std::vector split_set = StringUtils::Split(dynamic_image_size, ';'); // Different dimensions are split by ',' @@ -151,17 +162,106 @@ bool CheckDynamicImagesizeInputShapeValid(unordered_map> return true; } -Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, - const std::string input_shape, const std::string input_format, - bool &is_dynamic_input) { - if (!dynamic_batch_size.empty() && !dynamic_image_size.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1"}, - {"dynamic_batch_size", "dynamic_image_size"}); - GELOGE(ge::PARAM_INVALID, "dynamic_batch_size and dynamic_image_size can not both exist"); +bool CheckDynamicDimsInputShapeValid(const unordered_map> &shape_map, string input_format, + string &dynamic_dims) { + if (input_format != "ND") { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_format", input_format.c_str(), "input_format must be ND when set dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "input_format must be ND when set dynamic_dims."); + return false; + } + + int32_t dynamic_dim = 0; + for (auto &info_shapes : shape_map) { + auto &shapes = info_shapes.second; + if (shapes.size() > kMaxNDDimNum || shapes.size() < kMinNDDimNum) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_shape's dim", std::to_string(shapes.size()), "Dim num must within [1, 4] when set dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "Dim num must within [%zu, %zu] when set dynamic_dims.", kMinNDDimNum, kMaxNDDimNum); + return false; + } + int tmp = std::count(shapes.begin(), shapes.end(), kDynamicInputDim); + if (dynamic_dim != 0 && dynamic_dim != tmp) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_shape's -1 num", std::to_string(tmp), "Every set's num of -1 must be same"}); + GELOGE(ge::PARAM_INVALID, "input_shape's shape is invalid, every set's num of -1 must be same."); + return false; + } + dynamic_dim = tmp; + } + if (dynamic_dim == 0) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--input_shape's dynamic dim num", "0", "at least one dim should be -1 when set dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "input_shape's shape is invalid, at least one dim should be -1 when set dynamic_dims."); + return false; + } + + if (!CheckAndParseDynamicDims(dynamic_dim, dynamic_dims)) { + GELOGE(ge::PARAM_INVALID, "Check and parse dynamic dims: %s failed.", dynamic_dims.c_str()); + return false; + } + + return true; +} + +bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims) { + EraseEndSemicolon(dynamic_dims); + if (dynamic_dims.empty()) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--dynamic_dims", dynamic_dims.c_str(), "dynamic_dims can not be empty"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims can not be empty."); + return false; + } + // Different parameter sets are split by ';' + vector split_set = StringUtils::Split(dynamic_dims, ';'); + if (split_set.size() > kMaxDynamicDimNum) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10042", {"parameter", "reason"}, {"dynamic_dims", "dynamic_dims's num of parameter set can not exceed 100"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims's num of parameter set can not exceed %zu.", kMaxDynamicDimNum); + return false; + } + for (auto split_dim : split_set) { + vector one_set = StringUtils::Split(split_dim, ','); + if (one_set.size() != static_cast(dynamic_dim_num)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--dynamic_dims's parameter num of each set", std::to_string(one_set.size()), + "must be same as input_shape's num of -1"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims's parameter num of each set must be same as input_shape's num of -1."); + return false; + } + for (auto dim : one_set) { + for (auto c : dim) { + if (!isdigit(c)) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--dynamic_dims's parameter", dim.c_str(), "must be positive integer"}); + GELOGE(ge::PARAM_INVALID, "dynamic_dims's parameter must be positive integer."); + return false; + } + } + } + } + return true; +} + +Status CheckDynamicInputParamValid(string &dynamic_batch_size, string &dynamic_image_size, string &dynamic_dims, + const string input_shape, const string input_format, bool &is_dynamic_input) { + int32_t param_size = static_cast(!dynamic_batch_size.empty()) + + static_cast(!dynamic_image_size.empty()) + static_cast(!dynamic_dims.empty()); + if (param_size > 1) { + ErrorManager::GetInstance().ATCReportErrMessage("E10009", {"parameter0", "parameter1", "parameter2"}, + {"dynamic_batch_size", "dynamic_image_size", "dynamic_dims"}); + GELOGE(ge::PARAM_INVALID, "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one"); return ge::PARAM_INVALID; } - if (dynamic_batch_size.empty() && dynamic_image_size.empty()) { + if (param_size == 0) { return ge::SUCCESS; } @@ -169,8 +269,8 @@ Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_siz vector>> user_shape_map; is_dynamic_input = true; if (input_shape.empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"input_shape"}); - GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic batchsize scenario."); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"input_shape"}); + GELOGE(ge::PARAM_INVALID, "The input_shape can not be empty in dynamic input size scenario."); return ge::PARAM_INVALID; } @@ -192,86 +292,22 @@ Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_siz return ge::PARAM_INVALID; } } - return ge::SUCCESS; -} - -bool ParseInputShape(const string &input_shape, unordered_map> &shape_map, - vector>> &user_shape_map, bool is_dynamic_input) { - vector shape_vec = StringUtils::Split(input_shape, ';'); - const int DEFAULT_SHAPE_PAIR_SIZE = 2; - for (const auto &shape : shape_vec) { - vector shape_pair_vec = StringUtils::Split(shape, ':'); - if (shape_pair_vec.size() != DEFAULT_SHAPE_PAIR_SIZE) { - ErrorManager::GetInstance().ATCReportErrMessage("E10010", {"shape"}, {shape}); - GELOGW( - "Input parameter[--input_shape]’s shape is [%s], " - "correct sample is input_name1:n1,c1,h1,w1", - shape.c_str()); - return false; - } - if (shape_pair_vec[1].empty()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10011", {"shape"}, {shape}); - GELOGW( - "Input parameter[--input_shape]’s shape is [%s], can not empty, " - "correct sample is input_name1:n1,c1,h1,w1", - shape.c_str()); - return false; - } - - vector shape_value_strs = StringUtils::Split(shape_pair_vec[1], ','); - vector shape_values; - for (auto &shape_value_str : shape_value_strs) { - // stoul: The method may throw an exception: invalid_argument/out_of_range - if (std::string::npos != shape_value_str.find('.')) { - ErrorManager::GetInstance().ATCReportErrMessage("E10012", {"shape"}, {shape_value_str}); - GELOGW("--input_shape's shape value[%s] exist float number the correct sample is \"input_name1:1,3,224,224\"", - shape_value_str.c_str()); - return false; - } - long left_result = 0; - try { - left_result = stol(StringUtils::Trim(shape_value_str)); - } catch (const std::out_of_range &) { - ErrorManager::GetInstance().ATCReportErrMessage("E10013", {"parameter", "shape"}, {"input_shape", shape}); - GELOGW("--input_shape’s shape_value_str[%s] cause out of range execption!", shape_value_str.c_str()); - return false; - } catch (const std::invalid_argument &) { - ErrorManager::GetInstance().ATCReportErrMessage("E10014", {"parameter", "shape"}, - {"input_shape", shape_value_str}); - GELOGW("--input_shape’s shape_value_str[%s] cause invalid argument!", shape_value_str.c_str()); - return false; - } catch (...) { - ErrorManager::GetInstance().ATCReportErrMessage("E10015", {"parameter", "shape"}, - {"input_shape", shape_value_str}); - GELOGW("--input_shape’s shape_value_str[%s] stol fail!", shape_value_str.c_str()); - return false; - } - int64_t result = left_result; - // - 1 is not currently supported - if (!is_dynamic_input && result <= 0) { - ErrorManager::GetInstance().ATCReportErrMessage("E10057", {"shape", "result"}, {shape, std::to_string(result)}); - GELOGW( - "Input parameter[--input_shape]’s shape value[%s] is invalid, " - "expect positive integer, but value is %ld.", - shape.c_str(), result); - return false; - } - shape_values.push_back(result); + if (!dynamic_dims.empty()) { + if (!CheckDynamicDimsInputShapeValid(shape_map, input_format, dynamic_dims)) { + GELOGE(ge::PARAM_INVALID, "Check dynamic dims: %s of input shape: %s failed.", dynamic_dims.c_str(), + input_shape.c_str()); + return ge::PARAM_INVALID; } - - shape_map.emplace(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); - user_shape_map.push_back(make_pair(StringUtils::Trim(shape_pair_vec[0]), shape_values)); } - - return true; + return ge::SUCCESS; } Status CheckOutputTypeParamValid(const std::string output_type) { if ((!output_type.empty()) && (kOutputTypeSupportDatatype.find(output_type) == kOutputTypeSupportDatatype.end())) { - ErrorManager::GetInstance().ATCReportErrMessage("E10042", {"value"}, {output_type}); - GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], only support DT_FLOAT, DT_FLOAT16, DT_UINT8!!", - output_type.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--output_type", output_type, kOutputTypeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], %s.", output_type.c_str(), kOutputTypeSupport); return ge::PARAM_INVALID; } return ge::SUCCESS; @@ -280,23 +316,23 @@ Status CheckOutputTypeParamValid(const std::string output_type) { Status CheckBufferOptimizeParamValid(const std::string buffer_optimize) { if ((!buffer_optimize.empty()) && (kBufferOptimizeSupportOption.find(buffer_optimize) == kBufferOptimizeSupportOption.end())) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10068", {"parameter", "value", "reason"}, - {"buffer_optimize", buffer_optimize, "only support l2_optimize or off_optimize"}); - GELOGE(ge::PARAM_INVALID, "buffer_optimize flag %s is invalid, only support l2_optimize or off_optimize", - buffer_optimize.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--buffer_optimize", buffer_optimize, kBufferOptimizeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --buffer_optimize[%s], %s.", buffer_optimize.c_str(), + kBufferOptimizeSupport); return ge::PARAM_INVALID; } return ge::SUCCESS; } Status CheckCompressWeightParamValid(const std::string enable_compress_weight, const std::string compress_weight_conf) { - if ((!compress_weight_conf.empty()) && - (!CheckInputPathValid(compress_weight_conf, ge::ir_option::COMPRESS_WEIGHT_CONF))) { + if ((!compress_weight_conf.empty()) && (!CheckInputPathValid(compress_weight_conf, "--compress_weight_conf"))) { GELOGE(ge::PARAM_INVALID, "compress weight config file not found, file_name:%s", compress_weight_conf.c_str()); return ge::PARAM_INVALID; } if ((enable_compress_weight != "") && (enable_compress_weight != "true") && (enable_compress_weight != "false")) { + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, + {"enable_compress_weight", enable_compress_weight}); GELOGE(ge::PARAM_INVALID, "Input parameter[--enable_compress_weight]'s value[%s] must be true or false.", enable_compress_weight.c_str()); return ge::PARAM_INVALID; @@ -336,7 +372,7 @@ int CheckLogParamValidAndSetLogLevel(const std::string log) { } Status CheckInsertOpConfParamValid(const std::string insert_op_conf) { - if ((!insert_op_conf.empty()) && (!CheckInputPathValid(insert_op_conf, ge::ir_option::INSERT_OP_FILE))) { + if ((!insert_op_conf.empty()) && (!CheckInputPathValid(insert_op_conf, "--insert_op_conf"))) { GELOGE(ge::PARAM_INVALID, "insert op config file not found: %s", insert_op_conf.c_str()); return ge::PARAM_INVALID; } @@ -361,7 +397,7 @@ Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream) { if ((enable_single_stream != "") && (enable_single_stream != "true") && (enable_single_stream != "false")) { - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {"enable_single_stream", enable_single_stream}); GELOGE(ge::PARAM_INVALID, "Input parameter[--enable_single_stream]'s value[%s] must be true or false.", enable_single_stream.c_str()); @@ -373,16 +409,28 @@ Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream) Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode) { // only appointed op_select_implmode, can user appoint optypelist_for_implmode if (optypelist_for_implmode != "" && op_select_implmode == "") { - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"op_select_implmode"}); - GELOGE( - ge::FAILED, - "Input parameter[--op_select_implmode] must be appointed when appoint parameter[--optypelist_for_implmode]."); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--op_select_implmode", op_select_implmode.c_str(), kCompressWeightError}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.", op_select_implmode.c_str(), + kCompressWeightError); return ge::PARAM_INVALID; } // op_select_implmode default value is high_performance if (op_select_implmode == "") { op_select_implmode = IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT; + } else { + if (op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_DEFAULT && + op_select_implmode != IR_OPTION_OP_SELECT_IMPLMODE_PRECISON) { + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--op_select_implmode", op_select_implmode.c_str(), kSelectImplmodeError}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --op_select_implmode[%s], %s.", op_select_implmode.c_str(), + kSelectImplmodeError); + return ge::PARAM_INVALID; + } } + return ge::SUCCESS; } @@ -393,4 +441,13 @@ void PrintOptionMap(std::map &options, std::string tip GELOGI("%s set successfully, key=%s, value=%s", tips.c_str(), key.c_str(), option_name.c_str()); } } + +void EraseEndSemicolon(string ¶m) { + if (param.empty()) { + return; + } + if (param.back() == ';') { + param.erase(param.end() - 1); + } +} } // namespace ge diff --git a/src/ge/ir_build/atc_ir_common.h b/src/ge/ir_build/atc_ir_common.h index b0a2b08b..e4d3103b 100644 --- a/src/ge/ir_build/atc_ir_common.h +++ b/src/ge/ir_build/atc_ir_common.h @@ -29,10 +29,12 @@ #include "framework/omg/omg_inner_types.h" namespace ge { - static std::set caffe_support_input_format = {"NCHW", "ND"}; static std::set tf_support_input_format = {"NCHW", "NHWC", "ND", "NCDHW", "NDHWC"}; static std::set onnx_support_input_format = {"NCHW", "ND"}; +static const char *const kCaffeFormatSupport = "only support NCHW, ND in Caffe model"; +static const char *const kTFFormatSupport = "only support NCHW, NHWC, ND, NCDHW, NDHWC in TF model"; +static const char *const kONNXFormatSupport = "only support NCHW, ND in ONNX model"; static std::map input_format_str_to_geformat = { {"ND", domi::DOMI_TENSOR_ND}, {"NCHW", domi::DOMI_TENSOR_NCHW}, {"NHWC", domi::DOMI_TENSOR_NHWC}, @@ -47,12 +49,14 @@ bool CheckDynamicBatchSizeInputShapeValid(unordered_map> bool CheckDynamicImagesizeInputShapeValid(unordered_map> shape_map, const std::string input_format, std::string &dynamic_image_size); -Status CheckDynamicBatchSizeOrImageSizeParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, - const std::string input_shape, const std::string input_format, - bool &is_dynamic_input); +bool CheckDynamicDimsInputShapeValid(const std::unordered_map> &shape_map, + std::string input_format, std::string &dynamic_dims); + +bool CheckAndParseDynamicDims(int32_t dynamic_dim_num, std::string &dynamic_dims); -bool ParseInputShape(const std::string &input_shape, std::unordered_map> &shape_map, - std::vector>> &user_shape_map, bool is_dynamic_input = false); +Status CheckDynamicInputParamValid(std::string &dynamic_batch_size, std::string &dynamic_image_size, + std::string &dynamic_dims, const std::string input_shape, + const std::string input_format, bool &is_dynamic_input); Status CheckOutputTypeParamValid(const std::string output_type); Status CheckBufferOptimizeParamValid(const std::string buffer_optimize); @@ -63,5 +67,6 @@ Status CheckDisableReuseMemoryParamValid(const std::string disable_reuse_memory) Status CheckEnableSingleStreamParamValid(const std::string enable_single_stream); Status CheckImplmodeParamValid(const std::string &optypelist_for_implmode, std::string &op_select_implmode); void PrintOptionMap(std::map &options, std::string tips); +void EraseEndSemicolon(std::string ¶m); } // namespace ge #endif // FRAMEWORK_DOMI_ATC_IR_COMMON_H_ diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index 0be75b51..a9ff1ab5 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -26,6 +26,7 @@ #include "framework/common/util.h" #include "framework/omg/omg_inner_types.h" #include "framework/omg/omg_inner_types.h" +#include "common/model_parser/graph_parser_util.h" #include "ge/ge_api_types.h" #include "generator/ge_generator.h" #include "graph/compute_graph.h" @@ -151,6 +152,7 @@ class Impl { GetContext().is_dynamic_input = false; GetContext().dynamic_batch_size.clear(); GetContext().dynamic_image_size.clear(); + GetContext().dynamic_dims.clear(); }; ~Impl() { (void)generator_.Finalize(); }; graphStatus CheckOptions(const std::map &options); @@ -200,17 +202,20 @@ graphStatus Impl::Init(const std::map &options) { string dynamic_image_size = options_.find(ge::ir_option::DYNAMIC_IMAGE_SIZE) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_IMAGE_SIZE]; + string dynamic_dims = + options_.find(ge::ir_option::DYNAMIC_DIMS) == options_.end() ? "" : options_[ge::ir_option::DYNAMIC_DIMS]; - auto status = CheckDynamicBatchSizeOrImageSizeParamValid(dynamic_batch_size, dynamic_image_size, input_shape, - input_format, is_dynamic_input_); + auto status = CheckDynamicInputParamValid(dynamic_batch_size, dynamic_image_size, dynamic_dims, input_shape, + input_format, is_dynamic_input_); if (status != ge::SUCCESS) { - GELOGE(GRAPH_PARAM_INVALID, "check dynamic batch size or image size failed!"); + GELOGE(GRAPH_PARAM_INVALID, "Check dynamic input size failed!"); return GRAPH_PARAM_INVALID; } - GELOGD("user input dynamic_batch_size:%s,dynamic_image_size:%s", dynamic_batch_size.c_str(), - dynamic_image_size.c_str()); + GELOGD("User input dynamic_batch_size:%s, dynamic_image_size:%s, dynamic_dims:%s.", dynamic_batch_size.c_str(), + dynamic_image_size.c_str(), dynamic_dims.c_str()); GetContext().dynamic_batch_size = dynamic_batch_size; GetContext().dynamic_image_size = dynamic_image_size; + GetContext().dynamic_dims = dynamic_dims; // check output_type std::string output_type = options_.find(ge::ir_option::OUTPUT_TYPE) == options_.end() ? "" : options_[ge::ir_option::OUTPUT_TYPE]; @@ -243,11 +248,13 @@ graphStatus Impl::Init(const std::map &options) { graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs) { auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); + int64_t index = 0; for (ge::NodePtr &input_node : compute_graph->GetDirectNode()) { GE_CHECK_NOTNULL(input_node); ge::OpDescPtr op = input_node->GetOpDesc(); GE_CHECK_NOTNULL(op); if (op->GetType() == DATA) { + (void)AttrUtils::SetInt(op, ATTR_NAME_INDEX, index++); GELOGI("Data op inputDesc size is: %zu", op->GetAllInputsDesc().size()); ge::GeTensorDesc tensor = op->GetInputDesc(0); string data_op_name = op->GetName(); @@ -259,7 +266,7 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector(model.data.get()), static_cast(model.length)); } - } // namespace ge diff --git a/src/ge/model/ge_model.h b/src/ge/model/ge_model.h index 6305211a..be4b65bc 100644 --- a/src/ge/model/ge_model.h +++ b/src/ge/model/ge_model.h @@ -87,6 +87,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeModel : public AttrHolder uint8_t platform_type_ = {0}; uint32_t model_id_ = INVALID_MODEL_ID; }; -}; // namespace ge +} // namespace ge using GeModelPtr = std::shared_ptr; #endif // GE_MODEL_GE_MODEL_H_ diff --git a/src/ge/offline/main.cc b/src/ge/offline/main.cc index 61a843c3..214e495a 100644 --- a/src/ge/offline/main.cc +++ b/src/ge/offline/main.cc @@ -26,6 +26,7 @@ #include "common/gflags_util.h" #include "common/util.h" #include "common/util/error_manager/error_manager.h" +#include "common/model_parser/graph_parser_util.h" #include "framework/common/debug/ge_log.h" #include "ge/ge_api.h" #include "generator/ge_generator.h" @@ -66,6 +67,10 @@ static bool is_dynamic_input = false; // 310 limited 8G size const char *const kGraphMemoryManagerMallocMaxSize = "8*1024*1024*1024"; +const char *const kModeSupport = + "only support 0(model to framework model), " + "1(framework model to json), 3(only pre-check), 5(pbtxt to json)"; +const char *const kModelToJsonSupport = "only support 0(Caffe) 3(TensorFlow)"; DEFINE_string(model, "", "The model file."); DEFINE_string(output, "", "The output file path&name."); @@ -138,10 +143,6 @@ DEFINE_string(optypelist_for_implmode, "", "Optional; Nodes need use implmode selected in op_select_implmode " "Format:\"node_name1,node_name2\""); -DEFINE_string(head_stream, "0", - "Optional; Is need head stream, default is not need." - "Format: \"0: no head stream; 1: add head stream;\""); - DEFINE_string(singleop, "", "Optional; If set, generate single op model with the given json file."); DEFINE_int32(disable_reuse_memory, 0, "Optional; If set to 1, disable reuse memory when generating if."); @@ -163,26 +164,36 @@ DEFINE_string(save_original_model, "", "Optional; enable output original offline DEFINE_string(dynamic_batch_size, "", "Optional; If set, generate dynamic multi batch model. " "Different batch sizes are split by ','." - "dynamic_batch_size and dynamic_imagesize can only be set one."); + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); DEFINE_string(dynamic_image_size, "", "Optional; If set, generate dynamic multi image size model." "Different groups of image size are split by ';'," "while different dimensions of each group are split by ','." - "dynamic_batch_size and dynamic_imagesize can only be set one."); + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); + +DEFINE_string(dynamic_dims, "", + "Optional; If set, generate dynamic input size model. " + "Different groups of size are split by ';', while different dimensions of each group are split by ','." + "dynamic_batch_size, dynamic_image_size and dynamic_dims can only be set one."); DEFINE_string(enable_small_channel, "0", "Optional; If set to 1, small channel is enabled."); -DEFINE_bool(enable_compress_weight, false, "Optional; enable compress weight. true: enable; false(default): disable"); +DEFINE_string(enable_compress_weight, "false", + "Optional; enable compress weight. true: enable; false(default): disable"); DEFINE_string(compress_weight_conf, "", "Optional; the config file to compress weight"); DEFINE_string(enable_single_stream, "", "Optional; enable single stream. true: enable; false(default): disable"); -DEFINE_string(log, "default", "Optional; generate atc log. Support debug, info, warning, error, null"); +DEFINE_string(log, "null", "Optional; generate atc log. Support debug, info, warning, error, null"); DEFINE_string(dump_mode, "0", "Optional; generate infershape json,only support 1 , 0."); +DEFINE_int32(op_debug_level, 0, + "Optional; configure debug level of compiler. 0(default): close debug;" + "1: open TBE compiler, export ccec file and TBE instruction mapping file; 2: open ccec compiler"); + class GFlagUtils { public: /** @@ -232,7 +243,7 @@ class GFlagUtils { "\"check_result.json\"\n" " --disable_reuse_memory The switch of reuse memory. Default value is : 0." "0 means reuse memory, 1 means do not reuse memory.\n" - " --input_fp16_nodes Input node datatype is fp16 and format is NCHW. Separate multiple nodes with semicolons " + " --input_fp16_nodes Input node datatype is fp16. Separate multiple nodes with semicolons " "(;)." "Use double quotation marks (\") to enclose each argument." "E.g.: \"node_name1;node_name2\"\n" @@ -252,8 +263,7 @@ class GFlagUtils { " --optypelist_for_implmode Appoint which op to use op_select_implmode, used with op_select_implmode ." "Separate multiple nodes with commas (,). Use double quotation marks (\") to enclose each argument." "E.g.: \"node_name1,node_name2\"\n" - " --head_stream Add head stream. 0(default): disable; 1: enable\n" - " --soc_version The soc version. E.g.: \"Ascend310\"\n" + " --soc_version The soc version.\n" " --core_type Set core type AiCore or VectorCore. VectorCore: use vector core. " "Default value is: AiCore\n" " --enable_compress_weight Enable compress weight. true: enable; false(default): disable\n" @@ -280,7 +290,7 @@ class GFlagUtils { static Status CheckDumpInfershapeJsonFlags() { Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "check custom aicpu run so failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "weight"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), return domi::FAILED, "Input parameter[--weight]'s value[%s] is invalid!", FLAGS_weight.c_str()); return domi::SUCCESS; @@ -289,7 +299,7 @@ class GFlagUtils { static Status CheckFlags() { // No model file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_model == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"model"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"model"}); return domi::PARAM_INVALID, "Input parameter[--model]'s value is empty!"); // check param disable_reuse_memory GE_CHK_BOOL_EXEC(ge::CheckDisableReuseMemoryParamValid(to_string(FLAGS_disable_reuse_memory)) == ge::SUCCESS, @@ -301,16 +311,16 @@ class GFlagUtils { return ge::FAILED, "check optypelist_for_implmode and op_select_implmode failed!"); // No output file information passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && FLAGS_output == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"output"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"output"}); return domi::PARAM_INVALID, "Input parameter[--output]'s value is empty!"); Status ret = CheckFrameWorkValid(FLAGS_framework, FLAGS_weight); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "CheckFrameWorkValid failed"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ge::CheckDynamicBatchSizeOrImageSizeParamValid( - FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, FLAGS_input_shape, - FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, - return ge::FAILED, "check dynamic batch size or image size failed!"); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( + ge::CheckDynamicInputParamValid(FLAGS_dynamic_batch_size, FLAGS_dynamic_image_size, FLAGS_dynamic_dims, + FLAGS_input_shape, FLAGS_input_format, is_dynamic_input) != ge::SUCCESS, + return ge::FAILED, "check dynamic size(batch size, image size or dims) failed!"); #if !defined(__ANDROID__) && !defined(ANDROID) GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!CheckEncryptModeValid(FLAGS_encrypt_mode), return domi::FAILED, @@ -320,16 +330,16 @@ class GFlagUtils { GELOGI("domi will run with encrypt!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_encrypt_key), return domi::FAILED, - "encrypt_key file %s not found!!", FLAGS_encrypt_key.c_str()); + "encrypt_key file not found!!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_certificate), return domi::FAILED, - "certificate file %s not found!!", FLAGS_certificate.c_str()); + "certificate file not found!!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_hardware_key), return domi::FAILED, - "hardware_key file %s not found!!", FLAGS_hardware_key.c_str()); + "hardware_key file not found!!"); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_private_key), return domi::FAILED, - "private_key file %s not found!!", FLAGS_private_key.c_str()); + "private_key file not found!!"); } else { // No encryption GELOGI("domi will run without encrypt!"); } @@ -338,43 +348,37 @@ class GFlagUtils { /** * Check the validity of the I / O file path */ - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "model"), return domi::FAILED, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_model, "--model"), return domi::FAILED, "model file %s not found!!", FLAGS_model.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "weight"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_weight != "" && !ge::CheckInputPathValid(FLAGS_weight, "--weight"), return domi::FAILED, "weight file %s not found!!", FLAGS_weight.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "cal_conf"), + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_cal_conf != "" && !ge::CheckInputPathValid(FLAGS_cal_conf, "--cal_conf"), return domi::FAILED, "calibration config file %s not found!!", FLAGS_cal_conf.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "op_name_map"), return domi::FAILED, + FLAGS_op_name_map != "" && !ge::CheckInputPathValid(FLAGS_op_name_map, "--op_name_map"), return domi::FAILED, "op config file %s not found!!", FLAGS_op_name_map.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_head_stream != "" && FLAGS_head_stream != "0" && FLAGS_head_stream != "1", - ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"head_stream"}); - return domi::FAILED, "Input parameter[--head_stream] must be 0 or 1!!"); - GE_CHK_BOOL_EXEC(ge::CheckInsertOpConfParamValid(std::string(FLAGS_insert_op_conf)) == ge::SUCCESS, return ge::FAILED, "check insert op conf failed!"); GE_CHK_BOOL_EXEC( - ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight ? std::string("true") : std::string("false"), - FLAGS_compress_weight_conf) == ge::SUCCESS, + ge::CheckCompressWeightParamValid(FLAGS_enable_compress_weight, FLAGS_compress_weight_conf) == ge::SUCCESS, return ge::FAILED, "check compress weight failed!"); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "check_report"), return domi::FAILED, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_check_report, "--check_report"), return domi::FAILED, "check_report file %s not found!!", FLAGS_check_report.c_str()); - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( - FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output) || !CheckPathWithName(FLAGS_output)), - return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_mode == GEN_OM_MODEL && (!ge::CheckOutputPathValid(FLAGS_output, "--output") || + !CheckPathWithName(FLAGS_output)), + return domi::FAILED, "output path %s is not valid!!", FLAGS_output.c_str()); GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( FLAGS_save_original_model != "" && FLAGS_save_original_model != "true" && FLAGS_save_original_model != "false", - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {"save_original_model", FLAGS_save_original_model}); return domi::FAILED, "Input parameter[--save_original_model]'s value[%s] must be true or false.", FLAGS_save_original_model.c_str()); @@ -395,18 +399,18 @@ class GFlagUtils { static Status CheckConverJsonParamFlags() { // No model path passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_om == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"om"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"om"}); return domi::PARAM_INVALID, "Input parameter[--om]'s value is empty!!"); // JSON path not passed in GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(FLAGS_json == "", - ErrorManager::GetInstance().ATCReportErrMessage("E10000", {"parameter"}, {"json"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"parameter"}, {"json"}); return domi::PARAM_INVALID, "Input parameter[--json]'s value is empty!!"); // Check if the model path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "om"), return domi::PARAM_INVALID, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckInputPathValid(FLAGS_om, "--om"), return domi::PARAM_INVALID, "model file path is invalid: %s.", FLAGS_om.c_str()); // Check whether the JSON path is valid - GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "om"), return domi::PARAM_INVALID, + GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(!ge::CheckOutputPathValid(FLAGS_json, "--json"), return domi::PARAM_INVALID, "json file path is invalid: %s.", FLAGS_json.c_str()); return domi::SUCCESS; @@ -443,7 +447,8 @@ class GFlagUtils { if (framework != (int32_t)domi::CAFFE && framework != (int32_t)domi::TENSORFLOW && framework != (int32_t)domi::MINDSPORE && framework != (int32_t)domi::ONNX) { // No framework information was passed in or the entered framework is illegal - ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter"}, {"framework"}); + ErrorManager::GetInstance().ATCReportErrMessage("E10007", {"parameter", "support"}, + {"framework", "0(Caffe) or 1(MindSpore) or 3(TensorFlow)"}); DOMI_LOGE( "Input parameter[--framework] is mandatory and it's value must be: " "0(Caffe) or 1(MindSpore) or 3(TensorFlow)."); @@ -494,13 +499,16 @@ class GFlagUtils { } }; -void SetDynamicBatchSizeOrImagesizeOptions() { +void SetDynamicInputSizeOptions() { if (!FLAGS_dynamic_batch_size.empty()) { domi::GetContext().dynamic_batch_size = FLAGS_dynamic_batch_size; } if (!FLAGS_dynamic_image_size.empty()) { domi::GetContext().dynamic_image_size = FLAGS_dynamic_image_size; } + if (!FLAGS_dynamic_dims.empty()) { + domi::GetContext().dynamic_dims = FLAGS_dynamic_dims; + } } static bool CheckInputFormat() { @@ -516,31 +524,29 @@ static bool CheckInputFormat() { if (ge::caffe_support_input_format.find(FLAGS_input_format) != ge::caffe_support_input_format.end()) { return true; } - ErrorManager::GetInstance().ATCReportErrMessage("E10031", {"value"}, {FLAGS_input_format}); // only support NCHW ND - GELOGE(ge::FAILED, - "Input parameter[--input_format]'s value[%s] is wrong, " - "only support NCHW, ND in Caffe model.", - FLAGS_input_format.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--input_format", FLAGS_input_format, ge::kCaffeFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), + ge::kCaffeFormatSupport); return false; } else if ((FLAGS_framework == static_cast(domi::TENSORFLOW))) { // tf if (ge::tf_support_input_format.find(FLAGS_input_format) != ge::tf_support_input_format.end()) { return true; } - ErrorManager::GetInstance().ATCReportErrMessage("E10032", {"value"}, {FLAGS_input_format}); // only support NCHW NHWC ND NCDHW NDHWC - GELOGE(ge::FAILED, - "Input parameter[--input_format]'s value[%s] is wrong, " - "only support NCHW, NHWC, ND, NCDHW, NDHWC in tf model", - FLAGS_input_format.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--input_format", FLAGS_input_format, ge::kTFFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kTFFormatSupport); return false; } else if (FLAGS_framework == static_cast(domi::ONNX)) { if (ge::onnx_support_input_format.find(FLAGS_input_format) != ge::onnx_support_input_format.end()) { return true; } // only support NCHW ND - GELOGE(ge::FAILED, "Input parameter[--input_format]'s value[%s] is error, Only support NCHW, ND in onnx model", - FLAGS_input_format.c_str()); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--input_format", FLAGS_input_format, ge::kONNXFormatSupport}); + GELOGE(ge::FAILED, "Invalid value for --input_format[%s], %s.", FLAGS_input_format.c_str(), ge::kONNXFormatSupport); return false; } return true; @@ -579,12 +585,12 @@ void GetPluginSoFileList(const string &path, vector &fileList, string &c void LoadModelParserLib(std::string caffe_parser_path) { if (FLAGS_framework == static_cast(domi::TENSORFLOW)) { - void *tf_handle = dlopen("libfmk_tensorflow_parser.so", RTLD_NOW | RTLD_GLOBAL); + void *tf_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); if (tf_handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_tensorflow_parser.so] failed."); + GELOGW("dlopen fmk library [libfmk_parser.so] failed."); return; } - GELOGI("plugin load libfmk_tensorflow_parser.so success."); + GELOGI("plugin load libfmk_parser.so success."); } else if (FLAGS_framework == static_cast(domi::CAFFE)) { // What we are dealing with here is that the user modifies the caffe.proto scenario. // If no lib_Caffe_Parser.so is found under the plugin path, use the default lib_Caffe_Parser.so path. @@ -596,17 +602,17 @@ void LoadModelParserLib(std::string caffe_parser_path) { return; } GELOGI("plugin load %s success.", caffe_parser_path.c_str()); - // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_caffe_parser.so). + // According to the dependency, the Caffe parsing module of the framework is loaded here( libfmk_parser.so). // (depend on the lib_caffe_parser.so) - void *fmk_handle = dlopen("libfmk_caffe_parser.so", RTLD_NOW | RTLD_GLOBAL); + void *fmk_handle = dlopen("libfmk_parser.so", RTLD_NOW | RTLD_GLOBAL); if (fmk_handle == nullptr) { - GELOGW("dlopen fmk library [libfmk_caffe_parser.so] failed."); + GELOGW("dlopen fmk library [libfmk_parser.so] failed."); if (dlclose(handle) != 0) { GELOGW("dlclose lib_caffe_parser.so failed."); } return; } - GELOGI("plugin load libfmk_caffe_parser.so success."); + GELOGI("plugin load libfmk_parser.so success."); } else if (FLAGS_framework == static_cast(domi::ONNX)) { void *handle = dlopen("libfmk_onnx_parser.so", RTLD_NOW | RTLD_GLOBAL); if (handle == nullptr) { @@ -622,8 +628,7 @@ void LoadModelParserLib(std::string caffe_parser_path) { return; } -void LoadCustomOpLib() { - OpRegistry::Instance()->registrationDatas.clear(); +void LoadCustomOpLib(bool need_load_ops_plugin) { std::string plugin_path; GetCustomOpPath(plugin_path); @@ -639,7 +644,11 @@ void LoadCustomOpLib() { } LoadModelParserLib(caffe_parser_path); - + if (!need_load_ops_plugin) { + GELOGI("No need to load ops plugin so."); + return; + } + OpRegistry::Instance()->registrationDatas.clear(); // load other so files except lib_caffe_parser.so in the plugin so path for (auto elem : fileList) { ge::StringUtils::Trim(elem); @@ -654,17 +663,23 @@ void LoadCustomOpLib() { std::vector registrationDatas = OpRegistry::Instance()->registrationDatas; for (OpRegistrationData reg_data : registrationDatas) { - bool ret = ge::OpRegistrationTbe::Instance()->Finalize(reg_data); - if (ret) { - OpRegistry::Instance()->Register(reg_data); + if (reg_data.GetFrameworkType() == static_cast(FLAGS_framework)) { + (void)ge::OpRegistrationTbe::Instance()->Finalize(reg_data); + (void)OpRegistry::Instance()->Register(reg_data); } } } void SaveCustomCaffeProtoPath() { GELOGI("Enter save custom caffe proto path."); - string customop_path; + std::string path_base = ge::GELib::GetPath(); + GELOGI("path_base is %s", path_base.c_str()); + path_base = path_base.substr(0, path_base.rfind('/')); + path_base = path_base.substr(0, path_base.rfind('/') + 1); + ge::GetParserContext().caffe_proto_path = path_base + "include/proto/"; + + string customop_path; const char *path_env = std::getenv("ASCEND_OPP_PATH"); if (path_env != nullptr) { std::string path = path_env; @@ -673,10 +688,6 @@ void SaveCustomCaffeProtoPath() { ge::GetParserContext().custom_proto_path = customop_path; return; } - std::string path_base = ge::GELib::GetPath(); - GELOGI("path_base is %s", path_base.c_str()); - path_base = path_base.substr(0, path_base.rfind('/')); - path_base = path_base.substr(0, path_base.rfind('/') + 1); customop_path = path_base + "ops/framework/custom/caffe/"; ge::GetParserContext().custom_proto_path = customop_path; return; @@ -720,15 +731,6 @@ Status CreateInputsForInference(const ge::Graph &graph, vector &in return ge::SUCCESS; } -void ChangeStringToBool(std::string &arg_s, bool arg_b) { - if (arg_s == "true") { - arg_b = true; - } else { - arg_b = false; - } - return; -} - domi::Status GenerateInfershapeJson() { if (!CheckInputFormat()) { GELOGE(ge::FAILED, "Check input_format failed"); @@ -737,8 +739,6 @@ domi::Status GenerateInfershapeJson() { Status ret = GFlagUtils::CheckDumpInfershapeJsonFlags(); GE_CHK_BOOL_EXEC(ret == domi::SUCCESS, return domi::FAILED, "Check flags failed!"); - // Load custom operator Library - LoadCustomOpLib(); ge::GeGenerator ge_generator; std::map options; ge::Status geRet = ge_generator.Initialize(options); @@ -780,24 +780,25 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s return ret; } - if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE)) { - ErrorManager::GetInstance().ATCReportErrMessage( - "E10068", {"param", "value", "supports"}, - {"framework", std::to_string(fwk_type), "only support 0(Caffe) 3(TensorFlow)"}); - GELOGE(ge::FAILED, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow)."); + if ((fwk_type != domi::TENSORFLOW) && (fwk_type != domi::CAFFE) && (fwk_type != domi::ONNX)) { + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(fwk_type), kModelToJsonSupport}); + GELOGE(ge::FAILED, "Invalid value for --framework[%d], %s.", fwk_type, kModelToJsonSupport); return ge::FAILED; } - // Since the Caffe model's conversion to JSON file depends on lib_caffe_parser.so, loadcustomoplib is called here. - LoadCustomOpLib(); - if (FLAGS_dump_mode == "0") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so. + LoadCustomOpLib(false); ret = ge::ConvertFwkModelToJson((domi::FrameworkType)fwk_type, model_file.c_str(), json_file.c_str()); return ret; } else if (FLAGS_dump_mode == "1") { + // Caffe or tf model to json depend on lib_caffe_parser.so or libfmk_parser.so and ops plugin so. + LoadCustomOpLib(true); ret = GenerateInfershapeJson(); return ret; } else { + ErrorManager::GetInstance().ATCReportErrMessage("E10006", {"parameter"}, {"dump_mode"}); GELOGE(ge::FAILED, "Input parameter[--dump_mode]'s value must be 1 or 0."); return ge::FAILED; } @@ -828,7 +829,7 @@ domi::Status GenerateModel(std::map &options, std::string output ge::Model load_model = ge::Model("loadmodel", "version2"); auto ret1 = load_model.LoadFromFile(FLAGS_model); if (ret1 != ge::GRAPH_SUCCESS) { - ErrorManager::GetInstance().ATCReportErrMessage("E10056", {"parameter"}, {FLAGS_model}); + ErrorManager::GetInstance().ATCReportErrMessage("E10041", {"parameter"}, {FLAGS_model}); DOMI_LOGE( "Load model from %s failed, please check model file or " "input parameter[--framework] is correct", @@ -893,7 +894,7 @@ domi::Status GenerateModel(std::map &options, std::string output (void)ge::GELib::GetInstance()->Finalize(); return domi::FAILED; } - if (SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { + if (ge::SetOutputNodeInfo(graph, FLAGS_output_type, "") != domi::SUCCESS) { DOMI_LOGE("Set output node info fail."); (void)ge_generator.Finalize(); (void)ge::GELib::GetInstance()->Finalize(); @@ -931,10 +932,11 @@ static void SetEnvForSingleOp(std::map &options) { options.emplace(ge::OPTYPELIST_FOR_IMPLMODE, FLAGS_optypelist_for_implmode); options.emplace(ge::AUTO_TUNE_MODE, FLAGS_auto_tune_mode); options.emplace(ge::GRAPH_MEMORY_MAX_SIZE, kGraphMemoryManagerMallocMaxSize); + options.emplace(ge::OP_DEBUG_LEVEL, to_string(FLAGS_op_debug_level)); } domi::Status GenerateSingleOp(const std::string &json_file_path) { - if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output)) { + if (!FLAGS_output.empty() && !ge::CheckOutputPathValid(FLAGS_output, "--output")) { DOMI_LOGE("output path %s is not valid!", FLAGS_output.c_str()); return domi::FAILED; } @@ -947,12 +949,6 @@ domi::Status GenerateSingleOp(const std::string &json_file_path) { // need to be changed when ge.ini plan is done SetEnvForSingleOp(options); - vector build_params; - if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { - DOMI_LOGE("parse single op json file failed"); - return domi::FAILED; - } - auto ret = ge::GELib::Initialize(options); if (ret != ge::SUCCESS) { DOMI_LOGE("GE initialize failed!"); @@ -967,6 +963,14 @@ domi::Status GenerateSingleOp(const std::string &json_file_path) { return domi::FAILED; } + vector build_params; + if (ge::SingleOpParser::ParseSingleOpList(json_file_path, build_params) != ge::SUCCESS) { + DOMI_LOGE("parse single op json file failed"); + (void)generator.Finalize(); + (void)ge::GELib::GetInstance()->Finalize(); + return domi::FAILED; + } + int index = 0; for (auto ¶m : build_params) { string output_path; @@ -1000,7 +1004,7 @@ domi::Status GenerateOmModel() { "quotation marks (\") to enclose each argument such as out_nodes, input_shape, dynamic_image_size"); #if !defined(__ANDROID__) && !defined(ANDROID) // Load custom operator Library - LoadCustomOpLib(); + LoadCustomOpLib(true); SaveCustomCaffeProtoPath(); @@ -1038,8 +1042,6 @@ domi::Status GenerateOmModel() { options.insert(std::pair(ge::INPUT_FP16_NODES, FLAGS_input_fp16_nodes)); } - options.insert(std::pair(string(ge::HEAD_STREAM), FLAGS_head_stream)); - options.insert(std::pair(string(ge::AUTO_TUNE_MODE), FLAGS_auto_tune_mode)); options.insert( @@ -1057,7 +1059,7 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::FUSION_SWITCH_FILE), FLAGS_fusion_switch_file)); - options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), FLAGS_enable_compress_weight + options.insert(std::pair(string(ge::ENABLE_COMPRESS_WEIGHT), (FLAGS_enable_compress_weight == "true") ? ge::kEnableCompressWeightTrue : ge::kEnableCompressWeightFalse)); @@ -1065,13 +1067,15 @@ domi::Status GenerateOmModel() { options.insert(std::pair(string(ge::ENABLE_SINGLE_STREAM), FLAGS_enable_single_stream)); - SetDynamicBatchSizeOrImagesizeOptions(); + SetDynamicInputSizeOptions(); if (!FLAGS_save_original_model.empty()) { options.insert(std::pair(string(ge::SAVE_ORIGINAL_MODEL), FLAGS_save_original_model)); options.insert(std::pair(string(ge::ORIGINAL_MODEL_FILE), FLAGS_output + "_original.om")); } + options.insert(std::pair(string(ge::OP_DEBUG_LEVEL), to_string(FLAGS_op_debug_level))); + // print atc option map ge::PrintOptionMap(options, "atc option"); @@ -1095,8 +1099,8 @@ domi::Status ConvertModelToJson() { return domi::SUCCESS; } -bool CheckRet(domi::Status ret, ge::Status geRet) { - if (ret != domi::SUCCESS || geRet != ge::SUCCESS) { +bool CheckRet(domi::Status ret) { + if (ret != domi::SUCCESS) { if (FLAGS_mode == ONLY_PRE_CHECK) { GELOGW("ATC precheck failed."); } else if (FLAGS_mode == GEN_OM_MODEL) { @@ -1143,9 +1147,9 @@ int init(int argc, char *argv[]) { GFlagUtils::InitGFlag(argc, argv); // set log level int ret = -1; - const std::set log_level = {"default", "null", "debug", "info", "warning", "error"}; + const std::set log_level = {"null", "debug", "info", "warning", "error"}; if (log_level.count(FLAGS_log) == 0) { - std::cout << "E10016: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, error, null" + std::cout << "E10010: invalid value for --log:" << FLAGS_log << ", only support debug, info, warning, error, null" << std::endl; return ret; } @@ -1155,12 +1159,18 @@ int init(int argc, char *argv[]) { return ret; } + std::string path_base = ge::GELib::GetPath(); + ret = ErrorManager::GetInstance().Init(path_base); + if (ret != 0) { + DOMI_LOGE("ErrorManager init fail !"); + return ret; + } + return 0; } int main(int argc, char *argv[]) { Status ret = domi::SUCCESS; - ge::Status geRet = ge::SUCCESS; std::cout << "ATC start working now, please wait for a moment." << std::endl; try { // Initialize @@ -1185,12 +1195,9 @@ int main(int argc, char *argv[]) { GE_CHK_BOOL_EXEC(ConvertPbtxtToJson() == domi::SUCCESS, ret = domi::FAILED; break, "ATC convert pbtxt to json execute failed!!"); } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10048", {"value"}, {std::to_string(FLAGS_mode)}); - DOMI_LOGE( - "Invalid value for --mode[%d], only support " - "0(model to framework model), 1(framework model to json), 3(only pre-check), " - "5(pbtxt to json)!", - FLAGS_mode); + ErrorManager::GetInstance().ATCReportErrMessage("E10001", {"parameter", "value", "reason"}, + {"--mode", std::to_string(FLAGS_mode), kModeSupport}); + GELOGE(ge::PARAM_INVALID, "Invalid value for --mode[%d], %s.", FLAGS_mode, kModeSupport); ret = domi::FAILED; break; } @@ -1205,11 +1212,16 @@ int main(int argc, char *argv[]) { std::cout << "ATC run failed, some exceptions occur !" << std::endl; } - if (!CheckRet(ret, geRet)) { + if (!CheckRet(ret)) { std::cout << "ATC run failed, Please check the detail log, Try \'atc --help\' for more information" << std::endl; + int result = ErrorManager::GetInstance().OutputErrMessage(STDOUT_FILENO); + if (result != 0) { + DOMI_LOGE("ErrorManager outputErrMessage fail !"); + } return ret; } else { std::cout << "ATC run success, welcome to the next use." << std::endl; + (void)ErrorManager::GetInstance().OutputMessage(STDOUT_FILENO); return 0; } -} +} /*lint +e530*/ diff --git a/src/ge/offline/module.mk b/src/ge/offline/module.mk index c97e7813..a347362a 100644 --- a/src/ge/offline/module.mk +++ b/src/ge/offline/module.mk @@ -42,7 +42,7 @@ LOCAL_SHARED_LIBRARIES := \ libge_compiler \ libruntime_compile \ libparser_common \ - libfmk_tensorflow_parser \ + libfmk_parser \ liberror_manager \ LOCAL_STATIC_LIBRARIES := libgflags diff --git a/src/ge/offline/single_op_parser.cc b/src/ge/offline/single_op_parser.cc index 4d589565..54b6df69 100644 --- a/src/ge/offline/single_op_parser.cc +++ b/src/ge/offline/single_op_parser.cc @@ -28,6 +28,8 @@ #include "common/ge_inner_error_codes.h" #include "framework/common/util.h" #include "graph/utils/tensor_utils.h" +#include "graph/utils/op_desc_utils.h" +#include "graph/operator_factory_impl.h" using Json = nlohmann::json; using std::map; @@ -43,10 +45,14 @@ constexpr char const *kKeyAttr = "attr"; constexpr char const *kKeyName = "name"; constexpr char const *kKeyType = "type"; constexpr char const *kKeyShape = "shape"; +constexpr char const *kKeyShapeRange = "shape_range"; constexpr char const *kKeyValue = "value"; constexpr char const *kKeyFormat = "format"; constexpr char const *kFileSuffix = ".om"; constexpr int kDumpJsonIndent = 2; +constexpr int kShapeRangePairSize = 2; +constexpr int kShapeRangeLow = 0; +constexpr int kShapeRangeHigh = 1; map kAttrTypeDict = { {"bool", GeAttrValue::VT_BOOL}, @@ -90,6 +96,10 @@ T GetValue(const map &dict, string &key, T default_val) { void from_json(const Json &j, SingleOpTensorDesc &desc) { desc.dims = j.at(kKeyShape).get>(); + auto it = j.find(kKeyShapeRange); + if (it != j.end()) { + desc.dim_ranges = j.at(kKeyShapeRange).get>>(); + } string format_str = j.at(kKeyFormat).get(); string type_str = j.at(kKeyType).get(); desc.format = GetValue(kFormatDict, format_str, FORMAT_RESERVED); @@ -200,13 +210,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { for (auto &tensor_desc : op_desc.input_desc) { if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(false, "Input index[%d]'s dataType is invalid", index); + GELOGE(false, "Input's dataType is invalid when the index is %d", index); return false; } if (tensor_desc.format == FORMAT_RESERVED) { ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"input", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Input index[%d]'s format is invalid", index); + GELOGE(PARAM_INVALID, "Input's format is invalid when the index is %d", index); return false; } ++index; @@ -216,13 +226,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { for (auto &tensor_desc : op_desc.output_desc) { if (tensor_desc.type == DT_UNDEFINED) { ErrorManager::GetInstance().ATCReportErrMessage("E10027", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output[%d] dataType is invalid", index); + GELOGE(PARAM_INVALID, "Output's dataType is invalid when the index is %d", index); return false; } if (tensor_desc.format == FORMAT_RESERVED) { ErrorManager::GetInstance().ATCReportErrMessage("E10028", {"input", "index"}, {"output", std::to_string(index)}); - GELOGE(PARAM_INVALID, "Output[%d] format is invalid", index); + GELOGE(PARAM_INVALID, "Output's format is invalid when the index is %d", index); return false; } ++index; @@ -245,11 +255,13 @@ bool SingleOpParser::Validate(const SingleOpDesc &op_desc) { return true; } -OpDesc *SingleOpParser::CreateOpDesc(const string &op_type) { return new (std::nothrow) OpDesc(op_type, op_type); } +std::unique_ptr SingleOpParser::CreateOpDesc(const string &op_type) { + return std::unique_ptr(new (std::nothrow) OpDesc(op_type, op_type)); +} Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param) { - auto *op_desc = CreateOpDesc(single_op_desc.op); + auto op_desc = CreateOpDesc(single_op_desc.op); if (op_desc == nullptr) { GELOGE(MEMALLOC_FAILED, "Failed to create instance of opDesc"); return MEMALLOC_FAILED; @@ -265,6 +277,7 @@ Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single } GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type); ge_tensor_desc.SetOriginFormat(desc.format); + GE_CHK_STATUS_RET_NOLOG(SetShapeRange(desc, ge_tensor_desc)); TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); TensorUtils::SetInputTensor(ge_tensor_desc, true); TensorUtils::SetOutputTensor(ge_tensor_desc, false); @@ -284,6 +297,7 @@ Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single GeTensorDesc ge_tensor_desc(GeShape(desc.dims), desc.format, desc.type); ge_tensor_desc.SetOriginFormat(desc.format); + GE_CHK_STATUS_RET_NOLOG(SetShapeRange(desc, ge_tensor_desc)); TensorUtils::SetRealDimCnt(ge_tensor_desc, desc.dims.size()); TensorUtils::SetInputTensor(ge_tensor_desc, false); TensorUtils::SetOutputTensor(ge_tensor_desc, true); @@ -295,10 +309,78 @@ Status SingleOpParser::ConvertToBuildParam(int index, const SingleOpDesc &single op_desc->SetAttr(attr.name, attr.value); } + if (VerifyOpInputOutputSizeByIr(*op_desc) != SUCCESS) { + GELOGE(PARAM_INVALID, "Verify op [%s] input or output size failed.", op_desc->GetType().c_str()); + return PARAM_INVALID; + } + file_name << kFileSuffix; build_param.file_name = file_name.str(); + build_param.op_desc.reset(op_desc.release()); + return SUCCESS; +} + +Status SingleOpParser::VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc) { + ge::Operator operator_ir = ge::OperatorFactory::CreateOperator("tmp_operator", current_op_desc.GetType()); + if (!operator_ir.IsEmpty()) { + auto opdesc_ir = ge::OpDescUtils::GetOpDescFromOperator(operator_ir); + GE_CHECK_NOTNULL(opdesc_ir); + size_t current_opdesc_inputs_num = current_op_desc.GetInputsSize(); + size_t ir_opdesc_inputs_num = opdesc_ir->GetInputsSize(); + if (current_opdesc_inputs_num < ir_opdesc_inputs_num) { + string reason = "is smaller than the ir needed input size " + std::to_string(ir_opdesc_inputs_num); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {current_op_desc.GetName(), "input size " + std::to_string(current_opdesc_inputs_num), reason}); + GELOGE(PARAM_INVALID, "This op [%s] input size %zu is smaller than the ir needed input size %zu", + current_op_desc.GetName().c_str(), current_opdesc_inputs_num, ir_opdesc_inputs_num); + return PARAM_INVALID; + } + size_t current_opdesc_outputs_num = current_op_desc.GetOutputsSize(); + size_t ir_opdesc_outputs_num = opdesc_ir->GetOutputsSize(); + if (current_opdesc_outputs_num < ir_opdesc_outputs_num) { + string reason = "is smaller than the ir needed output size " + std::to_string(ir_opdesc_outputs_num); + ErrorManager::GetInstance().ATCReportErrMessage( + "E19014", {"opname", "value", "reason"}, + {current_op_desc.GetName(), "output size " + std::to_string(current_opdesc_outputs_num), reason}); + GELOGE(PARAM_INVALID, "This op [%s] output size %zu is smaller than the ir needed output size %zu", + current_op_desc.GetName().c_str(), current_opdesc_outputs_num, ir_opdesc_outputs_num); + return PARAM_INVALID; + } + } + return SUCCESS; +} + +Status SingleOpParser::SetShapeRange(const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc) { + if (tensor_desc.dim_ranges.empty()) { + return SUCCESS; + } + + std::vector> shape_range; + size_t range_index = 0; + for (auto dim : tensor_desc.dims) { + if (dim >= 0) { + shape_range.emplace_back(dim, dim); + GELOGD("Adding shape range: [%ld, %ld]", dim, dim); + } else { + if (range_index >= tensor_desc.dim_ranges.size()) { + GELOGE(PARAM_INVALID, "The number of shape_range mismatches that of unknown dims."); + return PARAM_INVALID; + } + + auto &range = tensor_desc.dim_ranges[range_index]; + if (range.size() != kShapeRangePairSize) { + GELOGE(PARAM_INVALID, "Invalid shape range entry. index = %zu, size = %zu", range_index, range.size()); + return PARAM_INVALID; + } + + shape_range.emplace_back(range[kShapeRangeLow], range[kShapeRangeHigh]); + GELOGD("Adding shape range: [%ld, %ld]", range[kShapeRangeLow], range[kShapeRangeHigh]); + ++range_index; + } + } - build_param.op_desc.reset(op_desc); + ge_tensor_desc.SetShapeRange(shape_range); return SUCCESS; } @@ -316,17 +398,15 @@ Status SingleOpParser::ParseSingleOpList(const std::string &file, std::vector dims; + std::vector> dim_ranges; ge::Format format = ge::FORMAT_RESERVED; ge::DataType type = ge::DT_UNDEFINED; }; @@ -68,8 +69,10 @@ class SingleOpParser { private: static Status ReadJsonFile(const std::string &file, nlohmann::json &json_obj); static bool Validate(const SingleOpDesc &op_desc); - static OpDesc *CreateOpDesc(const std::string &op_type); + static std::unique_ptr CreateOpDesc(const std::string &op_type); static Status ConvertToBuildParam(int index, const SingleOpDesc &single_op_desc, SingleOpBuildParam &build_param); + static Status VerifyOpInputOutputSizeByIr(const OpDesc ¤t_op_desc); + static Status SetShapeRange(const SingleOpTensorDesc &tensor_desc, GeTensorDesc &ge_tensor_desc); }; } // namespace ge diff --git a/src/ge/opskernel_manager/ops_kernel_manager.cc b/src/ge/opskernel_manager/ops_kernel_manager.cc index a8a1be88..0d6f1e07 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.cc +++ b/src/ge/opskernel_manager/ops_kernel_manager.cc @@ -38,7 +38,7 @@ const char *const kFinalize = "Finalize"; namespace ge { OpsKernelManager::OpsKernelManager() - : plugin_manager_(), init_flag_(false), enable_fe_flag_(false), enable_aicpu_flag_(false) {} + : plugin_manager_(), op_tiling_manager_(), init_flag_(false), enable_fe_flag_(false), enable_aicpu_flag_(false) {} OpsKernelManager::~OpsKernelManager() { graph_optimizers_.clear(); @@ -76,6 +76,8 @@ Status OpsKernelManager::Initialize(const map &options_const) { GetExternalEnginePath(extern_engine_path); GELOGI("OPTION_EXEC_EXTERN_PLUGIN_PATH=%s.", extern_engine_path.c_str()); + op_tiling_manager_.LoadSo(); + ret = plugin_manager_.LoadSo(extern_engine_path, func_check_list); if (ret == SUCCESS) { initialize_ = options; @@ -134,7 +136,7 @@ void OpsKernelManager::GetExternalEnginePath(std::string &extern_engine_path) { std::string path = path_base + so_path; extern_engine_path = (path + "libfe.so" + ":") + (path + "libge_local_engine.so" + ":") + (path + "librts_engine.so" + ":") + (path + "libaicpu_engine.so" + ":") + - (path_base + "libhccl.so"); + (path_base + "libhcom_graph_adaptor.so"); } Status OpsKernelManager::InitPluginOptions(const map &options) { diff --git a/src/ge/opskernel_manager/ops_kernel_manager.h b/src/ge/opskernel_manager/ops_kernel_manager.h index 8d98ad3f..1d464201 100644 --- a/src/ge/opskernel_manager/ops_kernel_manager.h +++ b/src/ge/opskernel_manager/ops_kernel_manager.h @@ -24,6 +24,7 @@ #include "common/debug/log.h" #include "common/ge/plugin_manager.h" +#include "common/ge/op_tiling_manager.h" #include "common/ge_inner_error_codes.h" #include "common/opskernel/ops_kernel_info_store.h" #include "common/optimizer/graph_optimizer.h" @@ -105,6 +106,7 @@ class OpsKernelManager { Status InitGraphOptimizerPriority(); PluginManager plugin_manager_; + OpTilingManager op_tiling_manager_; // opsKernelInfoStore map ops_kernel_store_{}; // graph_optimizer diff --git a/src/ge/session/inner_session.cc b/src/ge/session/inner_session.cc index 74495e82..b97862e1 100644 --- a/src/ge/session/inner_session.cc +++ b/src/ge/session/inner_session.cc @@ -29,6 +29,34 @@ #include "runtime/mem.h" namespace ge { +namespace { +Status CheckReuseMemoryOption(const std::map &options) { + const int kDecimal = 10; + auto dump_op_env = std::getenv("DUMP_OP"); + int dump_op_flag = (dump_op_env != nullptr) ? std::strtol(dump_op_env, nullptr, kDecimal) : 0; + auto iter = options.find(OPTION_EXEC_DISABLE_REUSED_MEMORY); + if (iter != options.end()) { + if (iter->second == "0") { + GELOGD("%s=0, reuse memory is open", OPTION_EXEC_DISABLE_REUSED_MEMORY); + if (dump_op_flag) { + GELOGW("Will dump incorrect op data with ge option %s=0", OPTION_EXEC_DISABLE_REUSED_MEMORY); + } + } else if (iter->second == "1") { + GELOGD("%s=1, reuse memory is close", OPTION_EXEC_DISABLE_REUSED_MEMORY); + } else { + GELOGE(PARAM_INVALID, "option %s=%s is invalid", OPTION_EXEC_DISABLE_REUSED_MEMORY, iter->second.c_str()); + return FAILED; + } + } else { + if (dump_op_flag) { + GELOGW("Will dump incorrect op data with default reuse memory"); + } + } + + return SUCCESS; +} +} // namespace + static std::mutex mutex_; // BuildGraph and RunGraph use InnerSession::InnerSession(uint64_t session_id, const std::map &options) @@ -39,13 +67,36 @@ Status InnerSession::Initialize() { GELOGW("[InnerSession:%lu] session already initialize.", session_id_); return SUCCESS; } + + // If the global options and the session options are duplicated, the session options is preferred. + auto all_options = options_; + all_options.insert(GetMutableGlobalOptions().begin(), GetMutableGlobalOptions().end()); + + Status ret = CheckReuseMemoryOption(all_options); + if (ret != SUCCESS) { + GELOGE(ret, "[InnerSession:%lu] check reuse memory option failed.", session_id_); + return ret; + } + UpdateThreadContext(std::map{}); GE_CHK_RT_RET(rtSetDevice(GetContext().DeviceId())); - Status ret = graph_manager_.Initialize(options_); + PropertiesManager::Instance().GetDumpProperties(session_id_).InitByOptions(); + + ret = graph_manager_.Initialize(options_); if (ret != SUCCESS) { GELOGE(ret, "[InnerSession:%lu] initialize failed.", session_id_); + PropertiesManager::Instance().RemoveDumpProperties(session_id_); + return ret; + } + + ret = VarManager::Instance(session_id_)->SetMemoryMallocSize(all_options); + if (ret != SUCCESS) { + GELOGE(ret, "failed to set malloc size"); + (void)graph_manager_.Finalize(); + PropertiesManager::Instance().RemoveDumpProperties(session_id_); + GE_CHK_RT(rtDeviceReset(static_cast(GetContext().DeviceId()))); return ret; } @@ -55,6 +106,7 @@ Status InnerSession::Initialize() { ret = VarManager::Instance(session_id_)->Init(version, session_id_, DEFAULT_DEVICE_ID, DEFAULT_JOB_ID); if (ret != SUCCESS) { GELOGE(ret, "failed to init session instance"); + PropertiesManager::Instance().RemoveDumpProperties(session_id_); } init_flag_ = true; return SUCCESS; @@ -78,6 +130,9 @@ Status InnerSession::Finalize() { // release var memory GELOGI("VarManager free var memory."); (void)VarManager::Instance(session_id_)->FreeVarMemory(); + + PropertiesManager::Instance().RemoveDumpProperties(session_id_); + GE_CHK_RT(rtDeviceReset(static_cast(GetContext().DeviceId()))); return ret; @@ -223,6 +278,7 @@ void InnerSession::UpdateThreadContext(const std::map GetThreadLocalContext().SetGlobalOption(GetMutableGlobalOptions()); GetThreadLocalContext().SetSessionOption(options_); GetThreadLocalContext().SetGraphOption(options); + GetContext().SetSessionId(session_id_); } void InnerSession::UpdateThreadContext(uint32_t graph_id) { diff --git a/src/ge/session/omg.cc b/src/ge/session/omg.cc index 71dd631e..55075d6a 100644 --- a/src/ge/session/omg.cc +++ b/src/ge/session/omg.cc @@ -22,15 +22,16 @@ #include "common/convert/pb2json.h" #include "common/debug/log.h" #include "common/debug/memory_dumper.h" +#include "common/ge/ge_util.h" +#include "common/helper/model_helper.h" #include "common/model_parser/base.h" +#include "common/model_parser/graph_parser_util.h" #include "common/model_saver.h" #include "common/properties_manager.h" #include "common/string_util.h" #include "common/types.h" #include "common/util.h" #include "common/util/error_manager/error_manager.h" -#include "common/helper/model_helper.h" -#include "common/ge/ge_util.h" #include "framework/common/debug/ge_log.h" #include "framework/omg/parser/parser_inner_ctx.h" #include "google/protobuf/io/zero_copy_stream_impl.h" @@ -65,6 +66,9 @@ namespace ge { namespace { const std::string kGraphDefaultName = "domi_default"; const std::string kScopeIdAttr = "fusion_scope"; +const char *const kOutputTypeSample = "correct sample is \"opname:index:dtype\""; +const char *const kOutputTypeSupport = "only support FP32, FP16, UINT8"; +const char *const kOutputTypeError = "The multiple out nodes set in output_type must be found in out_nodes."; } // namespace // When the model is converted to a JSON file, the following operator attributes in the blacklist will be ignored @@ -78,7 +82,7 @@ static bool CheckInputTrueOrFalse(const std::string &s, const std::string &atc_p if ((s == "true") || (s == "false")) { return true; } else { - ErrorManager::GetInstance().ATCReportErrMessage("E10033", {"parameter", "value"}, {atc_param, s}); + ErrorManager::GetInstance().ATCReportErrMessage("E10005", {"parameter", "value"}, {atc_param, s}); GELOGE(PARAM_INVALID, "Input parameter[--%s]'s value[%s] must be true or false.", atc_param.c_str(), s.c_str()); return false; } @@ -97,12 +101,12 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph) { std::string node_name = it.first; ge::NodePtr node = graph->FindNode(node_name); if (node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"parameter", "opname"}, {"input_shape", node_name}); + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_shape", node_name}); GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not exist in model", node_name.c_str()); return PARAM_INVALID; } if (node->GetType() != DATA) { - ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"parameter", "opname"}, {"input_shape", node_name}); + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_shape", node_name}); GELOGE(PARAM_INVALID, "Input parameter[--input_shape]'s opname[%s] is not a input opname", node_name.c_str()); return PARAM_INVALID; } @@ -110,6 +114,22 @@ static Status CheckInputShapeNode(const ComputeGraphPtr &graph) { return SUCCESS; } +void AddAttrsForInputNodes(const vector &adjust_fp16_format_vec, const string &fp16_nodes_name, uint32_t index, + OpDescPtr &op_desc) { + if (AttrUtils::SetBool(op_desc, "input_fp16", true) && + AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_DATATYPE, TypeUtils::DataTypeToSerialString(DT_FLOAT16))) { + if ((index < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[index] == "true")) { + GELOGI("This node [%s] should be set NC1HWC0", fp16_nodes_name.c_str()); + if (!AttrUtils::SetBool(op_desc, "input_set_nc1hwc0", true)) { + GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str()); + } + if (!AttrUtils::SetStr(op_desc, ATTR_ATC_USER_DEFINE_FORMAT, TypeUtils::FormatToSerialString(FORMAT_NC1HWC0))) { + GELOGW("This node [%s] set NC1HWC0 failed", fp16_nodes_name.c_str()); + } + } + } +} + static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &input_fp16_nodes, const string &is_input_adjust_hw_layout) { GE_CHECK_NOTNULL(graph); @@ -133,28 +153,22 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in for (uint32_t i = 0; i < input_fp16_nodes_vec.size(); ++i) { ge::NodePtr node = graph->FindNode(input_fp16_nodes_vec[i]); if (node == nullptr) { - ErrorManager::GetInstance().ATCReportErrMessage("E10034", {"parameter", "opname"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10016", {"parameter", "opname"}, {"input_fp16_nodes", input_fp16_nodes_vec[i]}); - GELOGE(PARAM_INVALID, "Can not find node [%s] in graph, please check input_fp16_nodes param", + GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not exist in model", input_fp16_nodes_vec[i].c_str()); return PARAM_INVALID; } auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); if (op_desc->GetType() != DATA) { - ErrorManager::GetInstance().ATCReportErrMessage("E10035", {"parameter", "opname"}, + ErrorManager::GetInstance().ATCReportErrMessage("E10017", {"parameter", "opname"}, {"input_fp16_nodes", input_fp16_nodes_vec[i]}); - GELOGE(PARAM_INVALID, "input_fp16_nodes: %s is not a input node name", input_fp16_nodes_vec[i].c_str()); + GELOGE(PARAM_INVALID, "Input parameter[--input_fp16_nodes]'s opname[%s] is not a input opname", + input_fp16_nodes_vec[i].c_str()); return PARAM_INVALID; } - if (ge::AttrUtils::SetBool(op_desc, "input_fp16", true)) { - if ((i < adjust_fp16_format_vec.size()) && (adjust_fp16_format_vec[i] == "true")) { - GELOGI("This node [%s] should be set NC1HWC0", input_fp16_nodes_vec[i].c_str()); - if (!ge::AttrUtils::SetBool(op_desc, "input_set_nc1hwc0", true)) { - GELOGW("This node [%s] set NC1HWC0 failed", input_fp16_nodes_vec[i].c_str()); - } - } - } + AddAttrsForInputNodes(adjust_fp16_format_vec, input_fp16_nodes_vec[i], i, op_desc); } return SUCCESS; } @@ -197,30 +211,6 @@ static Status SetWeightCompressNodes(const ComputeGraphPtr &graph, const string return SUCCESS; } -static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { - if (is_output_fp16.empty()) { - return SUCCESS; - } - - vector &output_formats = domi::GetContext().output_formats; - output_formats.clear(); - vector node_format_vec = StringUtils::Split(is_output_fp16, ','); - for (auto &is_fp16 : node_format_vec) { - StringUtils::Trim(is_fp16); - if (!CheckInputTrueOrFalse(is_fp16, "is_output_adjust_hw_layout")) { - GELOGE(PARAM_INVALID, "Invalid Param, is_output_adjust_hw_layout only support true/false: but is [%s]", - is_output_fp16.c_str()); - return PARAM_INVALID; - } - if (is_fp16 == "false") { - output_formats.push_back(DOMI_TENSOR_ND); - } else if (is_fp16 == "true") { - output_formats.push_back(domi::DOMI_TENSOR_NC1HWC0); - } - } - return SUCCESS; -} - void FindParserSo(const string &path, vector &file_list, string &caffe_parser_path) { // path, Change to absolute path string real_path = RealPath(path.c_str()); @@ -302,160 +292,6 @@ Status SetOutFormatAndDataTypeAttr(ge::OpDescPtr op_desc, const ge::Format forma return domi::SUCCESS; } -Status StringToInt(std::string &str, int32_t &value) { - try { - value = stoi(str); - } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", str.c_str()); - return PARAM_INVALID; - } catch (std::out_of_range &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", str.c_str()); - return PARAM_INVALID; - } - return SUCCESS; -} - -Status VerifyOutputTypeAndOutNodes(std::vector &out_type_vec) { - std::vector> user_out_nodes = domi::GetContext().user_out_nodes; - std::set out_nodes_info; - for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { - // out_nodes set should include output_type and output_format - std::string tmp = user_out_nodes[i].first + ":" + to_string(user_out_nodes[i].second); - out_nodes_info.emplace(tmp); - } - for (uint32_t i = 0; i < out_type_vec.size(); ++i) { - if (out_nodes_info.find(out_type_vec[i]) == out_nodes_info.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10059", {"value"}, {out_type_vec[i]}); - GELOGE(domi::FAILED, "Can not find this node (%s) in out_nodes.", out_type_vec[i].c_str()); - return domi::FAILED; - } - } - return domi::SUCCESS; -} - -Status ParseOutputType(const std::string &output_type, std::map> &out_type_index_map, - std::map> &out_type_dt_map) { - if (output_type.find(':') == std::string::npos) { - GELOGI("output_type is not multiple nodes, means all out nodes"); - auto it = output_type_str_to_datatype.find(output_type); - if (it == output_type_str_to_datatype.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10042", {"value"}, {output_type}); - GELOGE(ge::PARAM_INVALID, "Invalid value for --output_type[%s], only support DT_FLOAT, DT_FLOAT16, DT_UINT8!!", - output_type.c_str()); - return domi::FAILED; - } - return domi::SUCCESS; - } - std::vector out_type_vec; - vector nodes_v = StringUtils::Split(output_type, ';'); - for (const string &node : nodes_v) { - vector node_index_type_v = StringUtils::Split(node, ':'); - if (node_index_type_v.size() != 3) { // The size must be 3. - ErrorManager::GetInstance().ATCReportErrMessage("E10058", {"value"}, {node}); - GELOGE(PARAM_INVALID, - "The param of output_type is invalid, the correct format is [opname:index:dtype]," - "while the actual input is %s.", - node.c_str()); - return domi::FAILED; - } - ge::DataType tmp_dt; - std::string node_name = StringUtils::Trim(node_index_type_v[0]); - std::string index_str = StringUtils::Trim(node_index_type_v[1]); - int32_t index; - if (StringToInt(index_str, index) != SUCCESS) { - return domi::FAILED; - } - std::string dt_value = StringUtils::Trim(node_index_type_v[2]); - auto it = output_type_str_to_datatype.find(dt_value); - if (it == output_type_str_to_datatype.end()) { - ErrorManager::GetInstance().ATCReportErrMessage("E10042", {"value"}, {dt_value}); - GELOGE(ge::PARAM_INVALID, "output_type [%s] is invalid.", dt_value.c_str()); - return domi::FAILED; - } else { - tmp_dt = it->second; - } - out_type_vec.push_back(node_name + ":" + index_str); - auto it_index = out_type_index_map.find(node_name); - if (it_index == out_type_index_map.end()) { - vector tmp_vec; - tmp_vec.push_back(index); - out_type_index_map.emplace(node_name, tmp_vec); - } else { - it_index->second.push_back(index); - } - - auto it_dt = out_type_dt_map.find(node_name); - if (it_dt == out_type_dt_map.end()) { - vector tmp_vec; - tmp_vec.push_back(tmp_dt); - out_type_dt_map.emplace(node_name, tmp_vec); - } else { - it_dt->second.push_back(tmp_dt); - } - } - return VerifyOutputTypeAndOutNodes(out_type_vec); -} - -Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output) { - ge::ComputeGraphPtr compute_graph = ge::GraphUtils::GetComputeGraph(graph); - GE_CHECK_NOTNULL(compute_graph); - - std::vector> user_out_nodes = domi::GetContext().user_out_nodes; - std::vector output_formats = domi::GetContext().output_formats; - std::vector> output_nodes_info; - std::vector output_nodes_name; - - std::map> out_type_index_map; - std::map> out_type_dt_map; - if (!output_type.empty()) { - if (ParseOutputType(output_type, out_type_index_map, out_type_dt_map) != SUCCESS) { - GELOGE(domi::FAILED, "Parse output_type failed."); - return domi::FAILED; - } - } - - // User declared outputs - for (uint32_t i = 0; i < user_out_nodes.size(); ++i) { - ge::NodePtr out_node = compute_graph->FindNode(user_out_nodes[i].first); - if (out_node == nullptr) { - GELOGE(domi::FAILED, "Can not find src node (%s) in graph.", user_out_nodes[i].first.c_str()); - return domi::FAILED; - } - auto op_desc = out_node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); - if (i < output_formats.size()) { - if (output_formats[i] == domi::DOMI_TENSOR_NC1HWC0) { - GELOGI("The output node [%s] should be set NC1HWC0", user_out_nodes[i].first.c_str()); - if (!ge::AttrUtils::SetBool(op_desc, "output_set_fp16_nc1hwc0", true)) { - GELOGW("The output node [%s] set NC1HWC0 failed", user_out_nodes[i].first.c_str()); - } - } - } - auto it_index = out_type_index_map.find(user_out_nodes[i].first); - auto it_dt = out_type_dt_map.find(user_out_nodes[i].first); - if ((it_index != out_type_index_map.end()) && (it_dt != out_type_dt_map.end())) { - GELOGI("The output node [%s] need to be set output_type", user_out_nodes[i].first.c_str()); - (void)ge::AttrUtils::SetListDataType(op_desc, "_output_dt_list", it_dt->second); - (void)ge::AttrUtils::SetListInt(op_desc, "_output_dt_index", it_index->second); - } - output_nodes_info.push_back(std::make_pair(out_node, user_out_nodes[i].second)); - output_nodes_name.push_back(out_node->GetName() + ":" + std::to_string(user_out_nodes[i].second)); - } - // default output node (leaf) - if (user_out_nodes.empty()) { - for (ge::NodePtr node : compute_graph->GetDirectNode()) { - if (!node->GetInDataNodes().empty() && node->GetOutDataNodes().empty()) { - Status ret = GetOutputLeaf(node, output_nodes_info); - GE_CHK_BOOL_RET_STATUS(ret == SUCCESS, ret, "find leaf fail."); - } - } - } - GetOutputNodesNameAndIndex(output_nodes_info, output_nodes_name); - compute_graph->SetGraphOutNodesInfo(output_nodes_info); - domi::GetContext().net_out_nodes = output_nodes_name; - return domi::SUCCESS; -} - void GetOutputNodesNameAndIndex(std::vector> &output_nodes_info, std::vector &output_nodes_name) { output_nodes_name.clear(); @@ -481,32 +317,6 @@ void GetOutputNodesNameAndIndex(std::vector> &ou } } -Status GetOutputLeaf(NodePtr node, std::vector> &output_nodes_info) { - ge::OpDescPtr tmpDescPtr = node->GetOpDesc(); - if (tmpDescPtr == nullptr) { - GELOGE(domi::FAILED, "Get outnode op desc fail."); - return domi::FAILED; - } - size_t size = tmpDescPtr->GetOutputsSize(); - if (node->GetType() != NETOUTPUT) { - for (size_t index = 0; index < size; ++index) { - output_nodes_info.push_back(std::make_pair(node, index)); - } - } else { - const auto in_anchors = node->GetAllInDataAnchors(); - for (auto in_anchor : in_anchors) { - auto out_anchor = in_anchor->GetPeerOutAnchor(); - if (out_anchor == nullptr) { - GELOGE(domi::FAILED, "Get leaf node op desc fail."); - return domi::FAILED; - } - auto out_node = out_anchor->GetOwnerNode(); - output_nodes_info.push_back(std::make_pair(out_node, out_anchor->GetIdx())); - } - } - return SUCCESS; -} - /// /// @ingroup domi_common /// @brief Initialize omgcontext based on command line input @@ -550,55 +360,12 @@ Status InitDomiOmgContext(const string &input_shape, const string &input_format, return SUCCESS; } -Status ParseOutNodes(const string &out_nodes) { - try { - // parse output node - if (!out_nodes.empty()) { - domi::GetContext().out_nodes_map.clear(); - domi::GetContext().user_out_nodes.clear(); - - vector nodes_v = StringUtils::Split(out_nodes, ';'); - for (const string &node : nodes_v) { - vector key_value_v = StringUtils::Split(node, ':'); - if (key_value_v.size() != 2) { // The size must be 2. - ErrorManager::GetInstance().ATCReportErrMessage("E10069", {"param", "value", "supports"}, - {"out_nodes", node, "opname:index"}); - GELOGE(PARAM_INVALID, - "The input format of --out_nodes is invalid, the correct format is " - "\"node_name1:0;node_name1:1;node_name2:0\", while the actual input is %s.", - node.c_str()); - return PARAM_INVALID; - } - auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); - // stoi: The method may throw an exception: invalid_argument/out_of_range - int32_t index = stoi(StringUtils::Trim(key_value_v[1])); - if (iter != domi::GetContext().out_nodes_map.end()) { - iter->second.emplace_back(index); - } else { - std::vector index_v; - index_v.emplace_back(index); - domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); - } - domi::GetContext().user_out_nodes.push_back(std::make_pair(key_value_v[0], index)); - } - } - } catch (std::invalid_argument &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); - return PARAM_INVALID; - } catch (std::out_of_range &) { - GELOGE(PARAM_INVALID, "Invalid of out_nodes: %s ", out_nodes.c_str()); - return PARAM_INVALID; - } - - return SUCCESS; -} - /// @ingroup domi_common /// @brief Judge whether the op_Name_Map parameter matches the network /// @param [in] graph Input network graph /// @return SUCCESS: Input parameters are correct; PARAM_INVALID: Input parameters are incorrect /// -static Status CheckOpNameMap(const ComputeGraphPtr &graph) { +static Status CheckOpNameMap(const ComputeGraphPtr &graph, const std::string &op_conf) { GE_CHECK_NOTNULL(graph); unordered_map graphNodeTypes; for (const NodePtr &node : graph->GetAllNodes()) { @@ -613,7 +380,9 @@ static Status CheckOpNameMap(const ComputeGraphPtr &graph) { GE_RT_PARAM_INVALID_WITH_LOG_IF_TRUE(propertiesMap.empty(), "op_name_map file is empty, please check file!"); for (auto iter = propertiesMap.begin(); iter != propertiesMap.end(); iter++) { GE_IF_BOOL_EXEC(graphNodeTypes.find(iter->second) == graphNodeTypes.end(), - ErrorManager::GetInstance().ATCReportErrMessage("E10060", {"parameter"}, {"op_name_map"}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10003", {"parameter", "value", "reason"}, + {"op_name_map", op_conf, "type[" + iter->second + "] is not found in model"}); GELOGE(PARAM_INVALID, "Invalid parameter for op_name_map."); return PARAM_INVALID;); } return SUCCESS; @@ -659,7 +428,7 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::mapCreateModelParser(framework); GE_CHK_BOOL_RET_STATUS(model_parser != nullptr, FAILED, "ATC create model parser ret fail, framework:%d.", framework); return model_parser->ToJson(model_file, json_file); } - ErrorManager::GetInstance().ATCReportErrMessage("E10045", {"parameter"}, {"model"}); + ErrorManager::GetInstance().ATCReportErrMessage( + "E10001", {"parameter", "value", "reason"}, + {"--framework", std::to_string(framework), "only support 0(Caffe) 3(TensorFlow)"}); GELOGE(PARAM_INVALID, "Input parameter[--framework] is mandatory and it's value must be: 0(Caffe) 3(TensorFlow)."); return PARAM_INVALID; } diff --git a/src/ge/session/session_manager.cc b/src/ge/session/session_manager.cc index c3439b0b..68a8aa70 100644 --- a/src/ge/session/session_manager.cc +++ b/src/ge/session/session_manager.cc @@ -51,11 +51,11 @@ Status SessionManager::Finalize() { return SUCCESS; } -Status SessionManager::SetrtContext(rtContext_t rt_context) { +Status SessionManager::SetRtContext(SessionId session_id, rtContext_t rt_context) { GELOGI("set rt_context RT_CTX_NORMAL_MODE, device id:%u.", GetContext().DeviceId()); GE_CHK_RT_RET(rtCtxCreate(&rt_context, RT_CTX_NORMAL_MODE, static_cast(GetContext().DeviceId()))); GE_CHK_RT_RET(rtCtxSetCurrent(rt_context)); - RtContextUtil::GetInstance().AddrtContext(rt_context); + RtContextUtil::GetInstance().AddRtContext(session_id, rt_context); return SUCCESS; } @@ -85,7 +85,7 @@ Status SessionManager::CreateSession(const std::map &o session_id = next_session_id; // create a context - ret = SetrtContext(rtContext_t()); + ret = SetRtContext(session_id, rtContext_t()); return ret; } @@ -106,7 +106,7 @@ Status SessionManager::DestroySession(SessionId session_id) { } // Unified destruct rt_context - RtContextUtil::GetInstance().DestroyrtContexts(); + RtContextUtil::GetInstance().DestroyRtContexts(session_id); SessionPtr innerSession = it->second; Status ret = innerSession->Finalize(); @@ -300,4 +300,4 @@ bool SessionManager::IsGraphNeedRebuild(SessionId session_id, uint32_t graph_id) } return innerSession->IsGraphNeedRebuild(graph_id); } -}; // namespace ge +} // namespace ge diff --git a/src/ge/session/session_manager.h b/src/ge/session/session_manager.h index 5cce5214..5cdb849f 100644 --- a/src/ge/session/session_manager.h +++ b/src/ge/session/session_manager.h @@ -33,7 +33,6 @@ class SessionManager { friend class GELib; public: - Status SetrtContext(rtContext_t rtContext); /// /// @ingroup ge_session /// @brief create session @@ -163,10 +162,12 @@ class SessionManager { Status GetNextSessionId(SessionId &next_session_id); + Status SetRtContext(SessionId session_id, rtContext_t rtContext); + std::map session_manager_map_; std::mutex mutex_; bool init_flag_ = false; }; -}; // namespace ge +} // namespace ge #endif // GE_SESSION_SESSION_MANAGER_H_ diff --git a/src/ge/single_op/single_op.cc b/src/ge/single_op/single_op.cc index 9578471a..1a63c964 100644 --- a/src/ge/single_op/single_op.cc +++ b/src/ge/single_op/single_op.cc @@ -17,11 +17,13 @@ #include "single_op/single_op.h" #include "common/fmk_types.h" +#include "common/math/math_util.h" #include "common/profiling/profiling_manager.h" #include "framework/common/debug/ge_log.h" #include "framework/common/util.h" #include "graph/load/new_model_manager/model_utils.h" #include "runtime/mem.h" +#include "single_op/single_op_manager.h" namespace ge { namespace { @@ -50,9 +52,13 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_inputs; ++i) { // preventing from read out of bound size_t aligned_size = GetAlignedSize(inputs[i].length); + GELOGI("Input [%zu], aligned_size:%zu, inputs.length:%u, input_sizes_:%u", i, aligned_size, inputs[i].length, + input_sizes_[i]); if (aligned_size < input_sizes_[i]) { - GELOGE(PARAM_INVALID, "Input size mismatch. index = %zu, model expect %zu, but given %zu(after align)", i, - input_sizes_[i], aligned_size); + GELOGE(PARAM_INVALID, + "Input size mismatch. index = %zu, model expect %zu," + " but given %zu(after align)", + i, input_sizes_[i], aligned_size); return PARAM_INVALID; } } @@ -66,9 +72,13 @@ Status SingleOp::ValidateArgs(const std::vector &inputs, const std:: for (size_t i = 0; i < num_outputs; ++i) { // preventing from write out of bound size_t aligned_size = GetAlignedSize(outputs[i].length); + GELOGI("Output [%zu], aligned_size:%zu, outputs.length:%u, output_sizes_:%u", i, aligned_size, outputs[i].length, + output_sizes_[i]); if (aligned_size < output_sizes_[i]) { - GELOGE(PARAM_INVALID, "Output size mismatch. index = %zu, model expect %zu, but given %zu(after align)", i, - output_sizes_[i], aligned_size); + GELOGE(PARAM_INVALID, + "Output size mismatch. index = %zu, model expect %zu," + "but given %zu(after align)", + i, output_sizes_[i], aligned_size); return PARAM_INVALID; } } @@ -81,23 +91,11 @@ Status SingleOp::GetArgs(const std::vector &inputs, const std::vecto if (use_physical_addr_) { for (auto &input : inputs) { auto *addr = reinterpret_cast(input.data); - size_t aligned_size = GetAlignedSize(input.length); - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, aligned_size, addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Arg index = %zu", arg_index); - return ret; - } args_[arg_index++] = reinterpret_cast(addr); } for (auto &output : outputs) { auto *addr = reinterpret_cast(output.data); - size_t aligned_size = GetAlignedSize(output.length); - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, aligned_size, addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Arg index = %zu", arg_index); - return ret; - } args_[arg_index++] = reinterpret_cast(addr); } } else { @@ -117,6 +115,7 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve if (ret != SUCCESS) { return ret; } + // update tbe task args size_t num_args = arg_table_.size(); for (size_t i = 0; i < num_args; ++i) { std::vector &ptr_to_arg_in_tasks = arg_table_[i]; @@ -129,18 +128,34 @@ Status SingleOp::UpdateArgs(const std::vector &inputs, const std::ve *arg_addr = args_[i]; } } + // update aicpu_TF or aicpu_CC args for (auto &task : tasks_) { + size_t io_addr_num = args_.size(); if (task->GetOpTaskType() == OP_TASK_AICPU) { - GELOGD("Update aicpu task args"); + GELOGD("Update aicpu_TF task args"); AiCpuTask *task_aicpu = dynamic_cast(task); GE_CHECK_NOTNULL(task_aicpu); - auto *dstIOAddr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); - auto rt_ret = rtMemcpyAsync(dstIOAddr, sizeof(uint64_t) * args_.size(), &args_[0], + auto *dst_io_addr = const_cast(reinterpret_cast(task_aicpu->GetIOAddr())); + GE_CHECK_NOTNULL(dst_io_addr); + auto rt_ret = rtMemcpyAsync(dst_io_addr, sizeof(uint64_t) * args_.size(), &args_[0], sizeof(uint64_t) * args_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, stream_); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpyAsync addresses failed, ret = %d", rt_ret); return RT_FAILED; } + } else if (task->GetOpTaskType() == OP_TASK_AICPUCC) { + GELOGD("Update aicpu_CC task args"); + AiCpuCCTask *task_aicpu_cc = dynamic_cast(task); + GE_CHECK_NOTNULL(task_aicpu_cc); + const uintptr_t *task_io_addr = reinterpret_cast(task_aicpu_cc->GetIOAddr()); + GE_CHECK_NOTNULL(task_io_addr); + auto io_addr = reinterpret_cast(const_cast(task_io_addr)); + for (size_t i = 0; i < io_addr_num; ++i) { + io_addr[i] = reinterpret_cast(args_[i]); + } + } else { + GELOGW("Only TF_kernel aicpu and aicpu_CC are supported, but got %u", task->GetOpTaskType()); + continue; } } return SUCCESS; @@ -164,8 +179,90 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOp::ExecuteAsync(c return ret; } } + return ret; } void SingleOp::SetStream(rtStream_t stream) { stream_ = stream; } + +DynamicSingleOp::DynamicSingleOp(uintptr_t resource_id, rtStream_t stream) + : resource_id_(resource_id), stream_(stream) {} + +Status DynamicSingleOp::ValidateParams(const vector &input_desc, const std::vector &inputs, + std::vector &output_desc, std::vector &outputs) const { + if (inputs.size() != input_desc.size()) { + GELOGE(PARAM_INVALID, "Input number mismatches input desc number. Input num = %zu, input desc num = %zu", + inputs.size(), input_desc.size()); + return PARAM_INVALID; + } + + if (outputs.size() != output_desc.size()) { + GELOGE(PARAM_INVALID, "Output number mismatches output desc number. Output num = %zu, output desc num = %zu", + outputs.size(), output_desc.size()); + return PARAM_INVALID; + } + + if (input_desc.size() != num_inputs_) { + GELOGE(PARAM_INVALID, "Input number mismatches. expect %zu, but given %zu", num_inputs_, input_desc.size()); + return PARAM_INVALID; + } + + if (output_desc.size() != num_outputs_) { + GELOGE(PARAM_INVALID, "Output number mismatches. expect %zu, but given %zu", num_outputs_, output_desc.size()); + return PARAM_INVALID; + } + + return SUCCESS; +} + +Status DynamicSingleOp::AllocateWorkspaces(const std::vector &workspace_sizes, + std::vector &workspaces) { + static const std::string kPurpose("malloc workspace memory for dynamic op."); + if (workspace_sizes.empty()) { + GELOGD("No need to allocate workspace."); + return SUCCESS; + } + int64_t total_size = 0; + std::vector ws_offsets; + for (auto ws_size : workspace_sizes) { + // alignment and padding should be done in OpParaCalculate + GE_CHK_STATUS_RET_NOLOG(CheckInt64AddOverflow(total_size, ws_size)); + ws_offsets.emplace_back(total_size); + total_size += ws_size; + } + + GELOGD("Total workspace size is %ld", total_size); + StreamResource *stream_resource = SingleOpManager::GetInstance().GetResource(resource_id_, stream_); + GE_CHECK_NOTNULL(stream_resource); + auto ws_base = stream_resource->MallocMemory(kPurpose, static_cast(total_size)); + if (ws_base == nullptr) { + GELOGE(MEMALLOC_FAILED, "Failed to allocate memory of size: %ld", total_size); + return MEMALLOC_FAILED; + } + GELOGD("Done allocating workspace memory successfully."); + + for (auto ws_offset : ws_offsets) { + workspaces.emplace_back(ws_base + ws_offset); + } + + return SUCCESS; +} + +Status DynamicSingleOp::ExecuteAsync(const vector &input_desc, const vector &input_buffers, + vector &output_desc, vector &output_buffers) { + GE_CHECK_NOTNULL(op_task_); + GE_CHK_STATUS_RET_NOLOG(ValidateParams(input_desc, input_buffers, output_desc, output_buffers)); + GE_CHK_STATUS_RET_NOLOG(op_task_->UpdateRunInfo(input_desc, output_desc)); + std::vector workspace_buffers; + GE_CHK_STATUS_RET_NOLOG(AllocateWorkspaces(op_task_->GetWorkspaceSizes(), workspace_buffers)); + std::vector inputs; + std::vector outputs; + for (auto &buffer : input_buffers) { + inputs.emplace_back(buffer.data); + } + for (auto &buffer : output_buffers) { + outputs.emplace_back(buffer.data); + } + return op_task_->LaunchKernel(inputs, outputs, workspace_buffers, stream_); +} } // namespace ge diff --git a/src/ge/single_op/single_op.h b/src/ge/single_op/single_op.h index 08782b3b..d86c79ee 100644 --- a/src/ge/single_op/single_op.h +++ b/src/ge/single_op/single_op.h @@ -53,5 +53,26 @@ class SingleOp { std::vector> arg_table_; bool use_physical_addr_ = false; }; + +class DynamicSingleOp { + public: + DynamicSingleOp(uintptr_t resource_id, rtStream_t stream); + ~DynamicSingleOp() = default; + Status ExecuteAsync(const vector &input_desc, const std::vector &inputs, + std::vector &output_desc, std::vector &outputs); + + private: + friend class SingleOpModel; + Status ValidateParams(const vector &input_desc, const std::vector &inputs, + std::vector &output_desc, std::vector &outputs) const; + + Status AllocateWorkspaces(const std::vector &workspace_sizes, std::vector &workspaces); + + std::unique_ptr op_task_; + uintptr_t resource_id_ = 0; + rtStream_t stream_ = nullptr; + size_t num_inputs_ = 0; + size_t num_outputs_ = 0; +}; } // namespace ge #endif // GE_SINGLE_OP_SINGLE_OP_H_ diff --git a/src/ge/single_op/single_op_manager.cc b/src/ge/single_op/single_op_manager.cc index 990ca9cc..aa6f6d2b 100644 --- a/src/ge/single_op/single_op_manager.cc +++ b/src/ge/single_op/single_op_manager.cc @@ -19,9 +19,6 @@ #include #include -#include "runtime/dev.h" -#include "framework/common/debug/ge_log.h" - namespace ge { FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY SingleOpManager::~SingleOpManager() { for (auto &it : stream_resources_) { @@ -34,32 +31,15 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFr const ModelData &model_data, void *stream, SingleOp **single_op) { + GELOGI("GetOpFromModel in. model name = %s", model_name.c_str()); if (single_op == nullptr) { GELOGE(PARAM_INVALID, "single op is null"); return PARAM_INVALID; } - uintptr_t resource_id; - // runtime uses NULL to denote a default stream for each device - if (stream == nullptr) { - // get current context - rtContext_t rt_cur_ctx = nullptr; - auto rt_err = rtCtxGetCurrent(&rt_cur_ctx); - if (rt_err != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "get current context failed, runtime result is %d", static_cast(rt_err)); - return RT_FAILED; - } - // use current context as resource key instead - GELOGI("use context as resource key instead when default stream"); - resource_id = reinterpret_cast(rt_cur_ctx); - } else { - GELOGI("use stream as resource key instead when create stream"); - resource_id = reinterpret_cast(stream); - } - - GELOGI("GetOpFromModel in. model name = %s, resource id = 0x%lx", model_name.c_str(), - static_cast(resource_id)); - StreamResource *res = GetResource(resource_id); + uintptr_t resource_id = 0; + GE_CHK_STATUS_RET(GetResourceId(stream, resource_id)); + StreamResource *res = GetResource(resource_id, stream); if (res == nullptr) { GELOGE(MEMALLOC_FAILED, "GetResource failed"); return MEMALLOC_FAILED; @@ -79,26 +59,19 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::GetOpFr return ret; } - auto *new_op = new (std::nothrow) SingleOp(); + auto new_op = std::unique_ptr(new (std::nothrow) SingleOp()); if (new_op == nullptr) { GELOGE(MEMALLOC_FAILED, "new SingleOp failed"); return MEMALLOC_FAILED; } GELOGI("To build operator: %s", model_name.c_str()); - ret = model.BuildOp(*res, *new_op); - if (ret != SUCCESS) { - GELOGE(ret, "Build op failed. op = %s, resource id = 0x%lx, ret = %u", model_name.c_str(), - static_cast(resource_id), ret); - delete new_op; - new_op = nullptr; - return ret; - } + GE_CHK_STATUS_RET(model.BuildOp(*res, *new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); // stream is nullable new_op->SetStream(stream); - res->CacheOperator(model_data.model_data, new_op); - *single_op = new_op; + *single_op = new_op.get(); + res->CacheOperator(model_data.model_data, std::move(new_op)); return SUCCESS; } @@ -116,13 +89,14 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status SingleOpManager::Release return SUCCESS; } -StreamResource *SingleOpManager::GetResource(uintptr_t resource_id) { +StreamResource *SingleOpManager::GetResource(uintptr_t resource_id, rtStream_t stream) { std::lock_guard lock(mutex_); auto it = stream_resources_.find(resource_id); StreamResource *res = nullptr; if (it == stream_resources_.end()) { res = new (std::nothrow) StreamResource(); if (res != nullptr) { + res->SetStream(stream); stream_resources_.emplace(resource_id, res); } } else { @@ -141,4 +115,74 @@ StreamResource *SingleOpManager::TryGetResource(uintptr_t resource_id) { return it->second; } + +Status SingleOpManager::GetDynamicOpFromModel(const string &model_name, const ModelData &model_data, void *stream, + DynamicSingleOp **single_op) { + GE_CHECK_NOTNULL(single_op); + uintptr_t resource_id = 0; + GE_CHK_STATUS_RET(GetResourceId(stream, resource_id)); + StreamResource *res = GetResource(resource_id, stream); + if (res == nullptr) { + GELOGE(MEMALLOC_FAILED, "GetResource failed"); + return MEMALLOC_FAILED; + } + + DynamicSingleOp *op = res->GetDynamicOperator(model_data.model_data); + if (op != nullptr) { + GELOGD("Got operator from stream cache"); + *single_op = op; + return SUCCESS; + } + + if (!tiling_func_registered_) { + RegisterTilingFunc(); + } + + SingleOpModel model(model_name, model_data.model_data, model_data.model_len); + auto ret = model.Init(); + if (ret != SUCCESS) { + GELOGE(ret, "Init model failed. model = %s, ret = %u", model_name.c_str(), ret); + return ret; + } + + auto new_op = std::unique_ptr(new (std::nothrow) DynamicSingleOp(resource_id, stream)); + GE_CHECK_NOTNULL(new_op); + + GELOGI("To build operator: %s", model_name.c_str()); + GE_CHK_STATUS_RET(model.BuildDynamicOp(*new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); + *single_op = new_op.get(); + res->CacheDynamicOperator(model_data.model_data, std::move(new_op)); + return SUCCESS; +} + +void SingleOpManager::RegisterTilingFunc() { + std::lock_guard lk(mutex_); + if (tiling_func_registered_) { + return; + } + + op_tiling_manager_.LoadSo(); + tiling_func_registered_ = true; +} + +Status SingleOpManager::GetResourceId(rtStream_t stream, uintptr_t &resource_id) { + // runtime uses NULL to denote a default stream for each device + if (stream == nullptr) { + // get current context + rtContext_t rt_cur_ctx = nullptr; + auto rt_err = rtCtxGetCurrent(&rt_cur_ctx); + if (rt_err != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "get current context failed, runtime result is %d", static_cast(rt_err)); + return RT_FAILED; + } + // use current context as resource key instead + GELOGI("use context as resource key instead when default stream"); + resource_id = reinterpret_cast(rt_cur_ctx); + } else { + GELOGI("use stream as resource key instead when create stream"); + resource_id = reinterpret_cast(stream); + } + + return SUCCESS; +} } // namespace ge diff --git a/src/ge/single_op/single_op_manager.h b/src/ge/single_op/single_op_manager.h index 15d32316..09ae0e4e 100644 --- a/src/ge/single_op/single_op_manager.h +++ b/src/ge/single_op/single_op_manager.h @@ -20,7 +20,7 @@ #include #include #include - +#include "common/ge/op_tiling_manager.h" #include "single_op/single_op_model.h" #include "single_op/stream_resource.h" @@ -34,16 +34,27 @@ class SingleOpManager { return instance; } - Status GetOpFromModel(const std::string &key, const ge::ModelData &model_data, void *stream, SingleOp **single_op); + Status GetOpFromModel(const std::string &model_name, const ge::ModelData &model_data, void *stream, + SingleOp **single_op); + + Status GetDynamicOpFromModel(const std::string &model_name, const ge::ModelData &model_data, void *stream, + DynamicSingleOp **dynamic_single_op); + + StreamResource *GetResource(uintptr_t resource_id, rtStream_t stream); Status ReleaseResource(void *stream); + void RegisterTilingFunc(); + private: - StreamResource *GetResource(uintptr_t resource_id); + static Status GetResourceId(rtStream_t stream, uintptr_t &resource_id); + StreamResource *TryGetResource(uintptr_t resource_id); std::mutex mutex_; + bool tiling_func_registered_ = false; std::unordered_map stream_resources_; + OpTilingManager op_tiling_manager_; }; } // namespace ge diff --git a/src/ge/single_op/single_op_model.cc b/src/ge/single_op/single_op_model.cc index 9decdf75..27958e7c 100644 --- a/src/ge/single_op/single_op_model.cc +++ b/src/ge/single_op/single_op_model.cc @@ -28,6 +28,7 @@ #include "graph/utils/tensor_utils.h" #include "runtime/rt.h" #include "task/aicpu_task_builder.h" +#include "task/aicpu_kernel_task_builder.h" #include "task/tbe_task_builder.h" using domi::TaskDef; @@ -42,12 +43,8 @@ SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_da : model_name_(model_name), ori_model_data_(model_data), ori_model_size_(model_size) {} Status SingleOpModel::Init() { - auto ret = InitModel(); - if (ret != SUCCESS) { - return ret; - } - - return ParseInputsAndOutputs(); + GE_CHK_STATUS_RET_NOLOG(InitModel()); + return LoadAllNodes(); } Status SingleOpModel::InitModel() { @@ -149,7 +146,7 @@ void SingleOpModel::ParseOutputNode(const OpDescPtr &op_desc) { } } -Status SingleOpModel::ParseInputsAndOutputs() { +Status SingleOpModel::LoadAllNodes() { auto ge_model = model_helper_.GetGeModel(); GE_CHECK_NOTNULL(ge_model); Graph graph = ge_model->GetGraph(); @@ -167,19 +164,18 @@ Status SingleOpModel::ParseInputsAndOutputs() { auto node = nodes.at(i); auto op_desc = node->GetOpDesc(); GE_CHECK_NOTNULL(op_desc); - op_list_[i] = op_desc; + op_list_[i] = node; auto op_type = op_desc->GetType(); GELOGI("[%s] node[%zu] = %s, type = %s", model_name_.c_str(), i, node->GetName().c_str(), op_type.c_str()); if (op_type == DATA_TYPE || op_type == AIPP_DATA_TYPE) { - auto ret = ParseInputNode(op_desc); - if (ret != SUCCESS) { - return ret; - } + data_ops_.emplace_back(op_desc); + continue; } if (op_type == NETOUTPUT) { - ParseOutputNode(op_desc); + netoutput_op_ = op_desc; + continue; } ge_model->GetTBEKernelStore().LoadTBEKernelBinToOpDesc(op_desc); @@ -188,6 +184,14 @@ Status SingleOpModel::ParseInputsAndOutputs() { return SUCCESS; } +Status SingleOpModel::ParseInputsAndOutputs() { + for (auto &op_desc : data_ops_) { + GE_CHK_STATUS_RET_NOLOG(ParseInputNode(op_desc)); + } + ParseOutputNode(netoutput_op_); + return SUCCESS; +} + Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { // for lhisi const char *use_physical_address = std::getenv("GE_USE_PHYSICAL_ADDRESS"); @@ -198,11 +202,6 @@ Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { int arg_index = 0; for (size_t i = 0; i < input_offset_list_.size(); ++i) { auto *addr = model_params_.mem_base + input_offset_list_[i]; - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, input_sizes_[i], addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Input index = %zu", i); - return ret; - } model_params_.addr_mapping_.emplace(reinterpret_cast(addr), arg_index++); single_op.input_sizes_.emplace_back(input_sizes_[i]); single_op.input_addr_list_.emplace_back(addr); @@ -210,11 +209,6 @@ Status SingleOpModel::SetInputsAndOutputs(SingleOp &single_op) { for (size_t i = 0; i < output_offset_list_.size(); ++i) { auto *addr = model_params_.mem_base + output_offset_list_[i]; - auto ret = ModelUtils::ConvertVirtualAddressToPhysical(addr, output_sizes_[i], addr); - if (ret != SUCCESS) { - GELOGE(ret, "ConvertVirtualAddressToPhysical failed. Output index = %zu", i); - return ret; - } model_params_.addr_mapping_.emplace(reinterpret_cast(addr), arg_index++); single_op.output_sizes_.emplace_back(output_sizes_[i]); single_op.output_addr_list_.emplace_back(addr); @@ -234,16 +228,34 @@ Status SingleOpModel::BuildTaskList(SingleOp &single_op) { task_def.DebugString().c_str()); auto task_type = static_cast(task_def.type()); if (task_type == RT_MODEL_TASK_KERNEL) { - GELOGD("Building TBE task"); - OpTask *task = nullptr; - auto ret = BuildKernelTask(task_def.kernel(), single_op, &task); - if (ret != SUCCESS) { - return ret; + const domi::KernelDef &kernel_def = task_def.kernel(); + const auto &context = kernel_def.context(); + auto kernel_type = static_cast(context.kernel_type()); + if (kernel_type == cce::ccKernelType::TE) { + GELOGD("Building TBE task"); + TbeOpTask *tbe_task = nullptr; + auto ret = BuildKernelTask(task_def.kernel(), &tbe_task); + if (ret != SUCCESS) { + return ret; + } + + single_op.arg_table_.resize(single_op.input_sizes_.size() + single_op.output_sizes_.size()); + ParseArgTable(tbe_task, single_op); + single_op.tasks_.emplace_back(tbe_task); + } else if (kernel_type == cce::ccKernelType::AI_CPU) { + GELOGD("Building AICPU_CC task"); + OpTask *task = nullptr; + auto ret = BuildCpuKernelTask(task_def.kernel(), &task); + if (ret != SUCCESS) { + return ret; + } + single_op.tasks_.emplace_back(task); + } else { + GELOGE(UNSUPPORTED, "Only TBE kernel and AI_CPU kernek are supported, but got %u", context.kernel_type()); + return UNSUPPORTED; } - - single_op.tasks_.emplace_back(task); } else if (task_type == RT_MODEL_TASK_KERNEL_EX) { - GELOGD("Building AICPU task"); + GELOGD("Building AICPU_TF task"); OpTask *task = nullptr; auto ret = BuildKernelExTask(task_def.kernel_ex(), single_op, &task); if (ret != SUCCESS) { @@ -278,15 +290,9 @@ void SingleOpModel::ParseArgTable(TbeOpTask *task, SingleOp &op) { } } -Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task) { +Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, TbeOpTask **task) { GE_CHECK_NOTNULL(task); const auto &context = kernel_def.context(); - auto kernel_type = static_cast(context.kernel_type()); - if (kernel_type != cce::ccKernelType::TE) { - GELOGE(UNSUPPORTED, "Only TBE kernel is supported, but got %u", context.kernel_type()); - return UNSUPPORTED; - } - auto iter = op_list_.find(context.op_index()); if (iter == op_list_.end()) { GELOGE(INTERNAL_ERROR, "op desc not found. op index = %u", context.op_index()); @@ -307,9 +313,6 @@ Status SingleOpModel::BuildKernelTask(const domi::KernelDef &kernel_def, SingleO return ret; } - single_op.arg_table_.resize(single_op.input_sizes_.size() + single_op.output_sizes_.size()); - ParseArgTable(tbe_task, single_op); - *task = tbe_task; return SUCCESS; } @@ -323,13 +326,13 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin std::unique_ptr aicpu_task(new (std::nothrow) AiCpuTask()); if (aicpu_task == nullptr) { - GELOGE(MEMALLOC_FAILED, "create aicpu op task failed"); + GELOGE(MEMALLOC_FAILED, "create aicpu_TF op task failed"); return MEMALLOC_FAILED; } - auto builder = AiCpuTaskBuilder(iter->second, kernel_def); + auto builder = AiCpuTaskBuilder(iter->second->GetOpDesc(), kernel_def); auto ret = builder.BuildTask(*aicpu_task, model_params_); if (ret != SUCCESS) { - GELOGE(ret, "build aicpu op task failed"); + GELOGE(ret, "build aicpu_TF op task failed"); return ret; } @@ -337,16 +340,63 @@ Status SingleOpModel::BuildKernelExTask(const domi::KernelExDef &kernel_def, Sin return SUCCESS; } -Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { - auto ret = InitModelMem(resource); - if (ret != SUCCESS) { - return ret; +Status SingleOpModel::BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task) { + std::unique_ptr aicpucc_task(new (std::nothrow) AiCpuCCTask()); + if (aicpucc_task == nullptr) { + GELOGE(MEMALLOC_FAILED, "create aicpu_CC op task failed"); + return MEMALLOC_FAILED; } - ret = SetInputsAndOutputs(single_op); + auto builder = AiCpuCCTaskBuilder(kernel_def); + auto ret = builder.BuildTask(*aicpucc_task); if (ret != SUCCESS) { + GELOGE(ret, "build aicpu_CC op task failed"); return ret; } + + *task = aicpucc_task.release(); + return SUCCESS; +} + +Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { + GE_CHK_STATUS_RET_NOLOG(ParseInputsAndOutputs()); + GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); + GE_CHK_STATUS_RET_NOLOG(SetInputsAndOutputs(single_op)); return BuildTaskList(single_op); } + +Status SingleOpModel::BuildTaskListForDynamicOp(DynamicSingleOp &single_op) { + auto ge_model = model_helper_.GetGeModel(); + GE_CHECK_NOTNULL(ge_model); + + auto tasks = ge_model->GetModelTaskDefPtr()->task(); + for (int i = 0; i < tasks.size(); ++i) { + const TaskDef &task_def = tasks[i]; + GELOGI("[%s] Task[%d], type = %u, DebugString = %s", model_name_.c_str(), i, task_def.type(), + task_def.DebugString().c_str()); + auto task_type = static_cast(task_def.type()); + if (task_type == RT_MODEL_TASK_KERNEL) { + if (single_op.op_task_ != nullptr) { + GELOGE(UNSUPPORTED, "Do not support dynamic op with multiple tasks."); + return UNSUPPORTED; + } + + TbeOpTask *task = nullptr; + GE_CHK_STATUS_RET_NOLOG(BuildKernelTask(task_def.kernel(), &task)); + single_op.op_task_.reset(task); + } else { + // skip + GELOGD("Skip task type: %d", static_cast(task_type)); + } + } + + return SUCCESS; +} + +Status SingleOpModel::BuildDynamicOp(DynamicSingleOp &single_op) { + single_op.num_inputs_ = data_ops_.size(); + single_op.num_outputs_ = netoutput_op_->GetAllInputsSize(); + ParseOpModelParams(model_helper_, model_params_); + return BuildTaskListForDynamicOp(single_op); +} } // namespace ge diff --git a/src/ge/single_op/single_op_model.h b/src/ge/single_op/single_op_model.h index 4d8aae30..caa958e5 100644 --- a/src/ge/single_op/single_op_model.h +++ b/src/ge/single_op/single_op_model.h @@ -50,9 +50,11 @@ class SingleOpModel { Status Init(); Status BuildOp(StreamResource &resource, SingleOp &single_op); + Status BuildDynamicOp(DynamicSingleOp &single_op); private: Status InitModel(); + Status LoadAllNodes(); Status ParseInputsAndOutputs(); Status SetInputsAndOutputs(SingleOp &single_op); @@ -62,8 +64,10 @@ class SingleOpModel { void ParseOutputNode(const OpDescPtr &op_desc); Status BuildTaskList(SingleOp &single_op); - Status BuildKernelTask(const domi::KernelDef &kernel_def, SingleOp &single_op, OpTask **task); + Status BuildTaskListForDynamicOp(DynamicSingleOp &dynamic_single_op); + Status BuildKernelTask(const domi::KernelDef &kernel_def, TbeOpTask **task); Status BuildKernelExTask(const domi::KernelExDef &kernel_def, SingleOp &single_op, OpTask **task); + Status BuildCpuKernelTask(const domi::KernelDef &kernel_def, OpTask **task); static void ParseOpModelParams(ModelHelper &model_helper, SingleOpModelParam ¶m); void ParseArgTable(TbeOpTask *task, SingleOp &op); @@ -74,13 +78,15 @@ class SingleOpModel { ModelHelper model_helper_; - map op_list_; + map op_list_; SingleOpModelParam model_params_; std::vector input_offset_list_; std::vector input_sizes_; std::vector output_offset_list_; std::vector output_sizes_; + std::vector data_ops_; + OpDescPtr netoutput_op_; }; } // namespace ge diff --git a/src/ge/single_op/stream_resource.cc b/src/ge/single_op/stream_resource.cc index e48afb96..703b22b2 100644 --- a/src/ge/single_op/stream_resource.cc +++ b/src/ge/single_op/stream_resource.cc @@ -23,12 +23,6 @@ namespace ge { StreamResource::~StreamResource() { - for (auto it : op_map_) { - // it's safe to delete a nullptr - delete it.second; - it.second = nullptr; - } - for (auto mem : memory_list_) { if (mem != nullptr) { auto rt_ret = rtFree(mem); @@ -44,7 +38,13 @@ StreamResource::~StreamResource() { } } -void StreamResource::CacheOperator(const void *key, SingleOp *single_op) { op_map_[key] = single_op; } +void StreamResource::CacheOperator(const void *key, std::unique_ptr &&single_op) { + op_map_[key] = std::move(single_op); +} + +void StreamResource::CacheDynamicOperator(const void *key, std::unique_ptr &&single_op) { + dynamic_op_map_[key] = std::move(single_op); +} SingleOp *StreamResource::GetOperator(const void *key) { auto it = op_map_.find(key); @@ -52,9 +52,20 @@ SingleOp *StreamResource::GetOperator(const void *key) { return nullptr; } - return it->second; + return it->second.get(); } +DynamicSingleOp *StreamResource::GetDynamicOperator(const void *key) { + auto it = dynamic_op_map_.find(key); + if (it == dynamic_op_map_.end()) { + return nullptr; + } + + return it->second.get(); +} + +void StreamResource::SetStream(rtStream_t stream) { stream_ = stream; } + uint8_t *StreamResource::DoMallocMemory(const std::string &purpose, size_t size, size_t &max_allocated, std::vector &allocated) { if (size <= max_allocated && !allocated.empty()) { @@ -62,6 +73,20 @@ uint8_t *StreamResource::DoMallocMemory(const std::string &purpose, size_t size, return allocated.back(); } + if (!allocated.empty()) { + GELOGD("Expand workspace memory size from %zu to %zu", max_allocated, size); + auto ret = rtStreamSynchronize(stream_); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtStreamSynchronize failed, ret = %d", ret); + return nullptr; + } + + auto addr = allocated.back(); + allocated.pop_back(); + (void)rtFree(addr); + max_allocated = 0; + } + uint8_t *buffer = nullptr; auto ret = rtMalloc(reinterpret_cast(&buffer), size, RT_MEMORY_HBM); if (ret != RT_ERROR_NONE) { diff --git a/src/ge/single_op/stream_resource.h b/src/ge/single_op/stream_resource.h index fc114c08..6f26c497 100644 --- a/src/ge/single_op/stream_resource.h +++ b/src/ge/single_op/stream_resource.h @@ -37,22 +37,27 @@ class StreamResource { StreamResource &operator=(const StreamResource &) = delete; StreamResource &operator=(StreamResource &&) = delete; - void CacheOperator(const void *key, SingleOp *single_op); + void CacheOperator(const void *key, std::unique_ptr &&single_op); + void CacheDynamicOperator(const void *key, std::unique_ptr &&single_op); + void SetStream(rtStream_t stream); SingleOp *GetOperator(const void *key); + DynamicSingleOp *GetDynamicOperator(const void *key); uint8_t *MallocMemory(const std::string &purpose, size_t size); uint8_t *MallocWeight(const std::string &purpose, size_t size); private: - static uint8_t *DoMallocMemory(const std::string &purpose, size_t size, size_t &max_allocated, - std::vector &allocated); + uint8_t *DoMallocMemory(const std::string &purpose, size_t size, size_t &max_allocated, + std::vector &allocated); size_t max_memory_size_ = 0; size_t max_weight_size_ = 0; std::vector memory_list_; std::vector weight_list_; - std::unordered_map op_map_; + std::unordered_map> op_map_; + std::unordered_map> dynamic_op_map_; + rtStream_t stream_ = nullptr; }; } // namespace ge diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.cc b/src/ge/single_op/task/aicpu_kernel_task_builder.cc new file mode 100644 index 00000000..936c7b67 --- /dev/null +++ b/src/ge/single_op/task/aicpu_kernel_task_builder.cc @@ -0,0 +1,56 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "single_op/task/aicpu_kernel_task_builder.h" + +namespace ge { +AiCpuCCTaskBuilder::AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def) : kernel_def_(kernel_def) {} + +Status AiCpuCCTaskBuilder::SetKernelArgs(AiCpuCCTask &task) { + size_t aicpu_arg_size = kernel_def_.args_size(); + if (aicpu_arg_size <= 0) { + GELOGE(RT_FAILED, "aicpu_arg_size is invalid, value = %zu", aicpu_arg_size); + return RT_FAILED; + } + void *aicpu_args = malloc(aicpu_arg_size); + if (aicpu_args == nullptr) { + GELOGE(RT_FAILED, "malloc failed, size = %zu", aicpu_arg_size); + return RT_FAILED; + } + + task.SetKernelArgs(aicpu_args, aicpu_arg_size); + auto err = memcpy_s(aicpu_args, aicpu_arg_size, kernel_def_.args().data(), aicpu_arg_size); + if (err != EOK) { + GELOGE(RT_FAILED, "memcpy_s args failed, size = %zu, err = %d", aicpu_arg_size, err); + return RT_FAILED; + } + + task.SetIoAddr(static_cast(aicpu_args) + sizeof(aicpu::AicpuParamHead)); + return SUCCESS; +} + +Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task) { + auto ret = SetKernelArgs(task); + if (ret != SUCCESS) { + return ret; + } + const std::string &so_name = kernel_def_.so_name(); + const std::string &kernel_name = kernel_def_.kernel_name(); + task.SetSoName(so_name); + task.SetkernelName(kernel_name); + return SUCCESS; +} +} // namespace ge \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_kernel_task_builder.h b/src/ge/single_op/task/aicpu_kernel_task_builder.h new file mode 100644 index 00000000..c445132e --- /dev/null +++ b/src/ge/single_op/task/aicpu_kernel_task_builder.h @@ -0,0 +1,40 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ +#define GE_SINGLE_OP_TASK_AICPU_KERNEL_TASK_BUILDER_H_ + +#include +#include "aicpu/common/aicpu_task_struct.h" +#include "single_op/single_op.h" +#include "single_op/single_op_model.h" +#include "runtime/mem.h" + +namespace ge { +class AiCpuCCTaskBuilder { + public: + explicit AiCpuCCTaskBuilder(const domi::KernelDef &kernel_def); + ~AiCpuCCTaskBuilder() = default; + + Status BuildTask(AiCpuCCTask &task); + + private: + Status SetKernelArgs(AiCpuCCTask &task); + const domi::KernelDef &kernel_def_; +}; +} // namespace ge + +#endif // GE_SINGLE_OP_TASK_AICPUCC_TASK_BUILDER_H_ \ No newline at end of file diff --git a/src/ge/single_op/task/aicpu_task_builder.cc b/src/ge/single_op/task/aicpu_task_builder.cc index 1a4c37ca..bc2c76f6 100644 --- a/src/ge/single_op/task/aicpu_task_builder.cc +++ b/src/ge/single_op/task/aicpu_task_builder.cc @@ -129,7 +129,8 @@ Status AiCpuTaskBuilder::BuildTask(ge::AiCpuTask &task, const SingleOpModelParam task.task_info_ = kernel_def_.task_info(); task.workspace_addr_ = ws_addr_vec[0]; + auto debug_info = BuildTaskUtils::GetTaskInfo(op_desc_); + GELOGI("[TASK_INFO] %s %s", task.task_info_.c_str(), debug_info.c_str()); return SUCCESS; } - } // namespace ge diff --git a/src/ge/single_op/task/aicpu_task_builder.h b/src/ge/single_op/task/aicpu_task_builder.h index 0253ebd0..bd582a4f 100644 --- a/src/ge/single_op/task/aicpu_task_builder.h +++ b/src/ge/single_op/task/aicpu_task_builder.h @@ -36,7 +36,7 @@ class AiCpuTaskBuilder { Status SetInputOutputAddr(void **io_addr, const std::vector &addresses); Status SetFmkOpKernel(void *io_addr, void *ws_addr, STR_FWK_OP_KERNEL &kernel); - const OpDescPtr &op_desc_; + const OpDescPtr op_desc_; const domi::KernelExDef &kernel_def_; }; } // namespace ge diff --git a/src/ge/single_op/task/build_task_utils.cc b/src/ge/single_op/task/build_task_utils.cc index 883679be..9e97ee57 100644 --- a/src/ge/single_op/task/build_task_utils.cc +++ b/src/ge/single_op/task/build_task_utils.cc @@ -19,7 +19,9 @@ #include "runtime/rt.h" #include "graph/load/new_model_manager/model_utils.h" #include "graph/manager/graph_var_manager.h" +#include "graph/utils/type_utils.h" #include "framework/common/debug/ge_log.h" +#include "framework/common/types.h" namespace ge { namespace { @@ -62,4 +64,42 @@ std::vector BuildTaskUtils::GetKernelArgs(const OpDescPtr &op_desc, cons auto addresses = GetAddresses(op_desc, param); return JoinAddresses(addresses); } + +std::string BuildTaskUtils::GetTaskInfo(const OpDescPtr &op_desc) { + std::stringstream ss; + if (op_desc != nullptr) { + auto op_type = op_desc->GetType(); + if (op_type == ge::NETOUTPUT || op_type == ge::DATA) { + return ss.str(); + } + // Conv2D IN[DT_FLOAT16 NC1HWC0[256, 128, 7, 7, 16],DT_FLOAT16 FRACTAL_Z[128, 32, 16, 16]] + // OUT[DT_FLOAT16 NC1HWC0[256, 32, 7, 7, 16]] + ss << op_type << " IN["; + for (uint32_t idx = 0; idx < op_desc->GetInputsSize(); idx++) { + const GeTensorDescPtr &input = op_desc->MutableInputDesc(idx); + ss << TypeUtils::DataTypeToSerialString(input->GetDataType()) << " "; + ss << TypeUtils::FormatToSerialString(input->GetFormat()); + ss << VectorToString(input->GetShape().GetDims()); + if (idx < op_desc->GetInputsSize() - 1) { + ss << ","; + } + } + ss << "] OUT["; + + for (uint32_t idx = 0; idx < op_desc->GetOutputsSize(); idx++) { + const GeTensorDescPtr &output = op_desc->MutableOutputDesc(idx); + ss << TypeUtils::DataTypeToSerialString(output->GetDataType()) << " "; + Format out_format = output->GetFormat(); + const GeShape &out_shape = output->GetShape(); + const auto &dims = out_shape.GetDims(); + ss << TypeUtils::FormatToSerialString(out_format); + ss << VectorToString(dims); + if (idx < op_desc->GetOutputsSize() - 1) { + ss << ","; + } + } + ss << "]\n"; + } + return ss.str(); +} } // namespace ge diff --git a/src/ge/single_op/task/build_task_utils.h b/src/ge/single_op/task/build_task_utils.h index a5030e69..f5885fd2 100644 --- a/src/ge/single_op/task/build_task_utils.h +++ b/src/ge/single_op/task/build_task_utils.h @@ -18,6 +18,7 @@ #define GE_SINGLE_OP_TASK_BUILD_TASK_UTILS_H_ #include +#include #include "graph/op_desc.h" #include "single_op/single_op.h" @@ -31,6 +32,21 @@ class BuildTaskUtils { static std::vector> GetAddresses(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); static std::vector JoinAddresses(const std::vector> &addresses); static std::vector GetKernelArgs(const OpDescPtr &op_desc, const SingleOpModelParam ¶m); + static std::string GetTaskInfo(const OpDescPtr &op_desc); + template + static std::string VectorToString(const std::vector &values) { + std::stringstream ss; + ss << '['; + auto size = values.size(); + for (size_t i = 0; i < size; ++i) { + ss << values[i]; + if (i != size - 1) { + ss << ", "; + } + } + ss << ']'; + return ss.str(); + } }; } // namespace ge #endif // GE_SINGLE_OP_TASK_BUILD_TASK_UTILS_H_ diff --git a/src/ge/single_op/task/op_task.cc b/src/ge/single_op/task/op_task.cc index e93fad71..ddc4992c 100644 --- a/src/ge/single_op/task/op_task.cc +++ b/src/ge/single_op/task/op_task.cc @@ -16,34 +16,48 @@ #include "single_op/task/op_task.h" +#include +#include +#include + #include "runtime/rt.h" -#include "framework/common/debug/ge_log.h" +#include "register/op_tiling.h" +#include "framework/common/debug/log.h" namespace ge { +namespace { +constexpr int kLaunchRetryTimes = 1000; +constexpr int kSleepTime = 10; +} // namespace + void TbeOpTask::SetStubFunc(const std::string &name, const void *stub_func) { this->stub_name_ = name; this->stub_func_ = stub_func; } -void TbeOpTask::SetKernelArgs(void *args, size_t arg_size, uint32_t block_dim) { - args_ = args; +void TbeOpTask::SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim) { + args_ = std::move(args); arg_size_ = arg_size; block_dim_ = block_dim; } void TbeOpTask::SetSmDesc(void *sm_desc) { sm_desc_ = sm_desc; } -TbeOpTask::~TbeOpTask() { - if (args_ != nullptr) { - (void)rtFreeHost(args_); - } +const vector &OpTask::GetWorkspaceSizes() const { return workspace_sizes_; } +void OpTask::SetWorkspaceSizes(const vector &workspace_sizes) { workspace_sizes_ = workspace_sizes; } + +TbeOpTask::~TbeOpTask() { if (sm_desc_ != nullptr) { (void)rtMemFreeManaged(sm_desc_); } + + if (tiling_buffer_ != nullptr) { + (void)rtFree(tiling_buffer_); + } } -const void *TbeOpTask::GetArgs() const { return args_; } +const void *TbeOpTask::GetArgs() const { return args_.get(); } size_t TbeOpTask::GetArgSize() const { return arg_size_; } @@ -52,13 +66,118 @@ const std::string &TbeOpTask::GetStubName() const { return stub_name_; } Status TbeOpTask::LaunchKernel(rtStream_t stream) { GELOGD("To invoke rtKernelLaunch. task = %s, block_dim = %u", this->stub_name_.c_str(), block_dim_); auto *sm_desc = reinterpret_cast(sm_desc_); - auto ret = rtKernelLaunch(stub_func_, block_dim_, args_, static_cast(arg_size_), sm_desc, stream); + auto ret = rtKernelLaunch(stub_func_, block_dim_, args_.get(), static_cast(arg_size_), sm_desc, stream); + int retry_times = 0; + while (ret != RT_ERROR_NONE && retry_times < kLaunchRetryTimes) { + retry_times++; + GELOGW("Retry after %d ms, retry_times: %d", kSleepTime, retry_times); + std::this_thread::sleep_for(std::chrono::milliseconds(kSleepTime)); + ret = rtKernelLaunch(stub_func_, block_dim_, args_.get(), arg_size_, sm_desc, stream); + } + if (ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->stub_name_.c_str()); return RT_FAILED; } - GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->stub_name_.c_str()); + GELOGI("[TASK_INFO] %s", this->stub_name_.c_str()); + return SUCCESS; +} + +Status TbeOpTask::UpdateRunInfo(const vector &input_desc, const vector &output_desc) { + GE_CHK_STATUS_RET_NOLOG(UpdateNodeByShape(input_desc, output_desc)); + // invoke OpParaCalculate + GELOGD("Start to invoke OpParaCalculate."); + optiling::OpRunInfo run_info; + auto ret = optiling::OpParaCalculate(*node_, run_info); + if (ret != GRAPH_SUCCESS) { + GELOGE(FAILED, "Failed to invoke OpParaCalculate. ret = %u", ret); + return FAILED; + } + SetWorkspaceSizes(run_info.workspaces); + block_dim_ = run_info.block_dim; + tiling_data_ = run_info.tiling_data.str(); + GELOGD("Done invoking OpParaCalculate successfully. block_dim = %u, tiling size = %zu", block_dim_, + tiling_data_.size()); + return SUCCESS; +} + +Status TbeOpTask::UpdateTensorDesc(const GeTensorDesc &src_tensor, GeTensorDesc &dst_tensor) { + int64_t storage_format_val = static_cast(FORMAT_RESERVED); + (void)AttrUtils::GetInt(src_tensor, ge::ATTR_NAME_STORAGE_FORMAT, storage_format_val); + auto storage_format = static_cast(storage_format_val); + if (storage_format == FORMAT_RESERVED) { + GELOGD("Storage format not set. update shape to [%s], and original shape to [%s]", + src_tensor.GetShape().ToString().c_str(), src_tensor.GetOriginShape().ToString().c_str()); + dst_tensor.SetShape(src_tensor.GetShape()); + dst_tensor.SetOriginShape(src_tensor.GetOriginShape()); + } else { + std::vector storage_shape; + if (!AttrUtils::GetListInt(src_tensor, ge::ATTR_NAME_STORAGE_SHAPE, storage_shape)) { + GELOGE(PARAM_INVALID, "Failed to get storage_shape while storage_format was set"); + return PARAM_INVALID; + } + + GELOGD("Storage format set. update shape to [%s], and original shape to [%s]", + GeShape(storage_shape).ToString().c_str(), src_tensor.GetShape().ToString().c_str()); + dst_tensor.SetShape(GeShape(std::move(storage_shape))); + dst_tensor.SetOriginShape(src_tensor.GetShape()); + } + + return SUCCESS; +} + +Status TbeOpTask::UpdateNodeByShape(const vector &input_desc, const vector &output_desc) { + auto op_desc = node_->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // Set runtime shape to node + for (size_t i = 0; i < input_desc.size(); ++i) { + auto tensor_desc = op_desc->MutableInputDesc(i); + auto &runtime_tensor_desc = input_desc[i]; + GE_CHECK_NOTNULL(tensor_desc); + GE_CHK_STATUS_RET(UpdateTensorDesc(runtime_tensor_desc, *tensor_desc)); + } + + for (size_t i = 0; i < output_desc.size(); ++i) { + auto tensor_desc = op_desc->MutableOutputDesc(i); + auto &runtime_tensor_desc = output_desc[i]; + GE_CHECK_NOTNULL(tensor_desc); + GE_CHK_STATUS_RET(UpdateTensorDesc(runtime_tensor_desc, *tensor_desc)); + } + + return SUCCESS; +} + +void TbeOpTask::EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, size_t max_tiling_size) { + node_ = node; + tiling_buffer_ = tiling_buffer; + max_tiling_size_ = max_tiling_size; +} + +Status TbeOpTask::LaunchKernel(const vector &inputs, const vector &outputs, + const vector &workspaces, rtStream_t stream) { + GELOGD("[%s] Start to launch kernel", node_->GetName().c_str()); + std::vector args; + args.insert(args.end(), inputs.begin(), inputs.end()); + args.insert(args.end(), outputs.begin(), outputs.end()); + args.insert(args.end(), workspaces.begin(), workspaces.end()); + + if (tiling_buffer_ != nullptr) { + GELOGD("[%s] Start to copy tiling info. size = %zu", node_->GetName().c_str(), tiling_data_.size()); + GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_, max_tiling_size_, tiling_data_.data(), tiling_data_.size(), + RT_MEMCPY_HOST_TO_DEVICE_EX, stream)); + + args.emplace_back(tiling_buffer_); + } + + if (memcpy_s(args_.get(), arg_size_, args.data(), args.size() * sizeof(void *)) != EOK) { + GELOGE(INTERNAL_ERROR, "[%s] Failed to update kernel args.", node_->GetName().c_str()); + return INTERNAL_ERROR; + } + + GELOGD("[%s] Start to invoke rtKernelLaunch", node_->GetName().c_str()); + GE_CHK_RT_RET(rtKernelLaunch(stub_func_, block_dim_, args_.get(), arg_size_, nullptr, stream)); + GELOGD("[%s] Done invoking rtKernelLaunch successfully", node_->GetName().c_str()); return SUCCESS; } @@ -88,8 +207,49 @@ Status AiCpuTask::LaunchKernel(rtStream_t stream) { GELOGE(RT_FAILED, "Invoke rtKernelLaunch failed. ret = %d, task = %s", ret, this->op_type_.c_str()); return RT_FAILED; } + GELOGI("[TASK_INFO] %s", this->task_info_.c_str()); + return SUCCESS; +} + +void AiCpuCCTask::SetKernelArgs(void *args, size_t arg_size) { + args_ = args; + arg_size_ = arg_size; + // the blockdim value is defult "1" for rtCpuKernelLaunch + block_dim_ = 1; +} - GELOGD("Invoke rtKernelLaunch succeeded. task = %s", this->op_type_.c_str()); +void AiCpuCCTask::SetSoName(const std::string &so_name) { so_name_ = so_name; } + +void AiCpuCCTask::SetkernelName(const std::string &kernel_Name) { kernel_name_ = kernel_Name; } + +void AiCpuCCTask::SetIoAddr(void *io_addr) { io_addr_ = io_addr; } + +const void *AiCpuCCTask::GetIOAddr() const { return io_addr_; } + +const void *AiCpuCCTask::GetArgs() const { return args_; } + +size_t AiCpuCCTask::GetArgSize() const { return arg_size_; } + +AiCpuCCTask::~AiCpuCCTask() { + if (args_ != nullptr) { + free(args_); + args_ = nullptr; + } +} + +Status AiCpuCCTask::LaunchKernel(rtStream_t stream) { + GELOGI("To invoke rtCpuKernelLaunch. block_dim = %u, so_name is %s, kernel_name is %s", block_dim_, so_name_.data(), + kernel_name_.data()); + // sm_desc is nullptr, because l2 buffer does not support + auto *sm_desc = reinterpret_cast(sm_desc_); + auto ret = + rtCpuKernelLaunch(static_cast(so_name_.data()), static_cast(kernel_name_.data()), + block_dim_, args_, static_cast(arg_size_), sm_desc, stream); + if (ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Invoke rtCpuKernelLaunch failed. ret = %d", ret); + return RT_FAILED; + } + GELOGD("Invoke rtCpuKernelLaunch succeeded"); return SUCCESS; } } // namespace ge diff --git a/src/ge/single_op/task/op_task.h b/src/ge/single_op/task/op_task.h index 168a71b3..3e261b3f 100644 --- a/src/ge/single_op/task/op_task.h +++ b/src/ge/single_op/task/op_task.h @@ -19,15 +19,18 @@ #include #include +#include #include "runtime/stream.h" #include "common/ge_inner_error_codes.h" #include "graph/op_kernel_bin.h" +#include "graph/node.h" namespace ge { enum OpTaskType { OP_TASK_TBE = 0, OP_TASK_AICPU, + OP_TASK_AICPUCC, OP_TASK_INVALID, }; @@ -36,7 +39,20 @@ class OpTask { OpTask() = default; virtual ~OpTask() = default; virtual Status LaunchKernel(rtStream_t stream) = 0; + virtual Status UpdateRunInfo(const vector &input_desc, const vector &output_desc) { + return UNSUPPORTED; + } + virtual Status LaunchKernel(const std::vector &inputs, const std::vector &outputs, + const std::vector &workspaces, rtStream_t stream) { + return UNSUPPORTED; + } virtual OpTaskType GetOpTaskType() = 0; + + const vector &GetWorkspaceSizes() const; + void SetWorkspaceSizes(const vector &workspace_sizes); + + private: + std::vector workspace_sizes_; }; class TbeOpTask : public OpTask { @@ -47,18 +63,33 @@ class TbeOpTask : public OpTask { void SetSmDesc(void *sm_desc); void SetStubFunc(const std::string &name, const void *stub_func); - void SetKernelArgs(void *args, size_t arg_size, uint32_t block_dim); + void SetKernelArgs(std::unique_ptr &&args, size_t arg_size, uint32_t block_dim); + + Status UpdateRunInfo(const vector &input_desc, const vector &output_desc) override; + + Status LaunchKernel(const vector &inputs, const vector &outputs, const vector &workspaces, + rtStream_t stream) override; + const void *GetArgs() const; size_t GetArgSize() const; const std::string &GetStubName() const; + void EnableDynamicSupport(const NodePtr &node, void *tiling_buffer, size_t max_tiling_size); private: + static Status UpdateTensorDesc(const GeTensorDesc &src_tensor, GeTensorDesc &dst_tensor); + Status UpdateNodeByShape(const vector &input_desc, const vector &output_desc); + const void *stub_func_ = nullptr; - void *args_ = nullptr; + std::unique_ptr args_; size_t arg_size_ = 0; uint32_t block_dim_ = 1; void *sm_desc_ = nullptr; std::string stub_name_; + + void *tiling_buffer_ = nullptr; + uint32_t max_tiling_size_ = 0; + std::string tiling_data_; + NodePtr node_; }; class AiCpuTask : public OpTask { @@ -79,6 +110,34 @@ class AiCpuTask : public OpTask { std::string op_type_; void *io_addr_ = nullptr; }; + +class AiCpuCCTask : public OpTask { + public: + AiCpuCCTask() = default; + ~AiCpuCCTask() override; + AiCpuCCTask(const AiCpuCCTask &) = delete; + AiCpuCCTask &operator=(const AiCpuCCTask &) = delete; + + Status LaunchKernel(rtStream_t stream) override; + OpTaskType GetOpTaskType() override { return OP_TASK_AICPUCC; } + const void *GetIOAddr() const; + const void *GetArgs() const; + void SetKernelArgs(void *args, size_t arg_size); + void SetSoName(const std::string &so_name); + void SetkernelName(const std::string &kernel_Name); + void SetIoAddr(void *io_addr); + size_t GetArgSize() const; + + private: + friend class AiCpuCCTaskBuilder; + std::string so_name_; + std::string kernel_name_; + void *args_ = nullptr; + size_t arg_size_ = 0; + uint32_t block_dim_ = 1; + void *sm_desc_ = nullptr; + void *io_addr_ = nullptr; +}; } // namespace ge #endif // GE_SINGLE_OP_TASK_OP_TASK_H_ diff --git a/src/ge/single_op/task/tbe_task_builder.cc b/src/ge/single_op/task/tbe_task_builder.cc index c0f6877f..23c023fd 100644 --- a/src/ge/single_op/task/tbe_task_builder.cc +++ b/src/ge/single_op/task/tbe_task_builder.cc @@ -17,20 +17,18 @@ #include "single_op/task/tbe_task_builder.h" #include -#include #include -#include "common/helper/model_helper.h" -#include "framework/common/debug/ge_log.h" #include "graph/load/new_model_manager/model_utils.h" #include "graph/debug/ge_attr_define.h" -#include "graph/load/new_model_manager/task_info/task_info.h" #include "graph/manager/graph_var_manager.h" #include "runtime/rt.h" #include "single_op/task/build_task_utils.h" namespace ge { namespace { +constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; +constexpr char const *kAttrOpParamSize = "op_para_size"; std::mutex g_reg_mutex; inline void GetKernelName(const OpDescPtr &op_desc, std::string &kernel_name) { @@ -85,9 +83,11 @@ bool KernelBinRegistry::AddKernel(const std::string &stub_name, const KernelHold return ret.second; } -TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const OpDescPtr &op_desc, - const domi::KernelDef &kernel_def) - : op_desc_(op_desc), kernel_def_(kernel_def), stub_name_(model_name + "/" + op_desc->GetName() + "_tvmbin") {} +TbeTaskBuilder::TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::KernelDef &kernel_def) + : node_(node), + op_desc_(node->GetOpDesc()), + kernel_def_(kernel_def), + stub_name_(model_name + "/" + node->GetName() + "_tvmbin") {} Status TbeTaskBuilder::DoRegisterBinary(const OpKernelBin &kernel_bin, void **bin_handle, const SingleOpModelParam ¶m) const { @@ -246,17 +246,11 @@ Status TbeTaskBuilder::GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m } Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m) { - uint8_t *args = nullptr; size_t arg_size = kernel_def_.args_size(); - auto rtRet = rtMallocHost(reinterpret_cast(&args), arg_size); - if (rtRet != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtMallocHost failed, size = %zu, ret = %d", arg_size, static_cast(rtRet)); - return RT_FAILED; - } - - task.SetKernelArgs(args, arg_size, kernel_def_.block_dim()); + auto args = std::unique_ptr(new (std::nothrow) uint8_t[arg_size]); + GE_CHECK_NOTNULL(args); - rtRet = rtMemcpy(args, arg_size, kernel_def_.args().data(), arg_size, RT_MEMCPY_HOST_TO_HOST); + auto rtRet = rtMemcpy(args.get(), arg_size, kernel_def_.args().data(), arg_size, RT_MEMCPY_HOST_TO_HOST); if (rtRet != RT_ERROR_NONE) { GELOGE(RT_FAILED, "rtMemcpy args failed, size = %zu, ret = %d", arg_size, static_cast(rtRet)); return RT_FAILED; @@ -266,16 +260,23 @@ Status TbeTaskBuilder::SetKernelArgs(TbeOpTask &task, const SingleOpModelParam & const auto *args_offset_tmp = reinterpret_cast(context.args_offset().data()); uint16_t offset = *args_offset_tmp; - // copy args - std::vector tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param); - void *src_addr = reinterpret_cast(tensor_device_addr_vec.data()); - uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size(); - rtRet = rtMemcpy(args + offset, arg_size - offset, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST); - if (rtRet != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "rtMemcpy addresses failed, ret = %d", static_cast(rtRet)); - return RT_FAILED; + bool is_dynamic = false; + (void)AttrUtils::GetBool(op_desc_, kAttrSupportDynamicShape, is_dynamic); + if (is_dynamic) { + GE_CHK_STATUS_RET_NOLOG(InitTilingInfo(task)); + } else { + // copy args + std::vector tensor_device_addr_vec = BuildTaskUtils::GetKernelArgs(op_desc_, param); + void *src_addr = reinterpret_cast(tensor_device_addr_vec.data()); + uint64_t src_len = sizeof(void *) * tensor_device_addr_vec.size(); + rtRet = rtMemcpy(args.get() + offset, arg_size - offset, src_addr, src_len, RT_MEMCPY_HOST_TO_HOST); + if (rtRet != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "rtMemcpy addresses failed, ret = %d", static_cast(rtRet)); + return RT_FAILED; + } } + task.SetKernelArgs(std::move(args), arg_size, kernel_def_.block_dim()); return SUCCESS; } @@ -290,6 +291,8 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ if (ret != SUCCESS) { return ret; } + auto task_info = BuildTaskUtils::GetTaskInfo(op_desc_); + GELOGI("[TASK_INFO] %s %s", stub_name_.c_str(), task_info.c_str()); void *stub_func = nullptr; auto rtRet = rtGetFunctionByName(stub_name_.c_str(), &stub_func); @@ -301,4 +304,23 @@ Status TbeTaskBuilder::BuildTask(TbeOpTask &task, const SingleOpModelParam ¶ task.SetStubFunc(stub_name_, stub_func); return SUCCESS; } + +Status TbeTaskBuilder::InitTilingInfo(TbeOpTask &task) { + GELOGD("Start alloc tiling data of node %s.", op_desc_->GetName().c_str()); + int64_t max_size = -1; + (void)AttrUtils::GetInt(op_desc_, kAttrOpParamSize, max_size); + GELOGD("Got op param size by key: %s, ret = %ld", kAttrOpParamSize, max_size); + if (max_size <= 0) { + GELOGE(PARAM_INVALID, "[%s] Invalid op_param_size: %ld.", op_desc_->GetName().c_str(), max_size); + return PARAM_INVALID; + } + + void *tiling_buffer = nullptr; + GE_CHK_RT_RET(rtMalloc(&tiling_buffer, static_cast(max_size), RT_MEMORY_HBM)); + GE_CHECK_NOTNULL(tiling_buffer); + GELOGD("[%s] Done allocating tiling buffer, size=%ld.", op_desc_->GetName().c_str(), max_size); + + task.EnableDynamicSupport(node_, tiling_buffer, static_cast(max_size)); + return SUCCESS; +} } // namespace ge diff --git a/src/ge/single_op/task/tbe_task_builder.h b/src/ge/single_op/task/tbe_task_builder.h index 5e0965bf..7c5f8054 100644 --- a/src/ge/single_op/task/tbe_task_builder.h +++ b/src/ge/single_op/task/tbe_task_builder.h @@ -65,12 +65,13 @@ class KernelBinRegistry { class TbeTaskBuilder { public: - TbeTaskBuilder(const std::string &model_name, const OpDescPtr &op_desc, const domi::KernelDef &kernel_def); + TbeTaskBuilder(const std::string &model_name, const NodePtr &node, const domi::KernelDef &kernel_def); ~TbeTaskBuilder() = default; Status BuildTask(TbeOpTask &task, const SingleOpModelParam ¶m); private: + Status InitTilingInfo(TbeOpTask &task); Status SetKernelArgs(TbeOpTask &task, const SingleOpModelParam ¶m); Status GetSmDesc(void **sm_desc, const SingleOpModelParam ¶m) const; @@ -82,7 +83,8 @@ class TbeTaskBuilder { static Status DoRegisterFunction(void *bin_handle, const char *stub_name, const char *kernel_name); - const OpDescPtr &op_desc_; + const NodePtr node_; + const OpDescPtr op_desc_; const domi::KernelDef &kernel_def_; const std::string stub_name_; }; diff --git a/src/ge/stub/Makefile b/src/ge/stub/Makefile deleted file mode 100644 index a0b35b42..00000000 --- a/src/ge/stub/Makefile +++ /dev/null @@ -1,6 +0,0 @@ -inc_path := $(shell pwd)/inc/external/ -out_path := $(shell pwd)/out/atc/lib64/stub/ -stub_path := $(shell pwd)/framework/domi/stub/ - -mkdir_stub := $(shell mkdir -p $(out_path)) -local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) diff --git a/src/ge/stub/README b/src/ge/stub/README deleted file mode 100644 index ca98ce85..00000000 --- a/src/ge/stub/README +++ /dev/null @@ -1,4 +0,0 @@ -################################################################################### -the directory (stub) saves the stub file -gen_stubapi.py is using for retrieving API and generating stub functions -################################################################################### diff --git a/src/ge/stub/gen_stubapi.py b/src/ge/stub/gen_stubapi.py deleted file mode 100644 index 6185c479..00000000 --- a/src/ge/stub/gen_stubapi.py +++ /dev/null @@ -1,573 +0,0 @@ -import os -import re -import sys -import logging - -logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', - level=logging.INFO) - -""" - this attr is used for symbol table visible -""" -GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' - -""" - generate stub func body by return type -""" -RETURN_STATEMENTS = { - 'graphStatus': ' return GRAPH_SUCCESS;', - 'Status': ' return SUCCESS;', - 'Graph': ' return Graph();', - 'Graph&': ' return *this;', - 'Format': ' return Format();', - 'Format&': ' return *this;', - 'Shape': ' return Shape();', - 'Shape&': ' return *this;', - 'TensorDesc': ' return TensorDesc();', - 'TensorDesc&': ' return *this;', - 'Tensor': ' return Tensor();', - 'Tensor&': ' return *this;', - 'Operator': ' return Operator();', - 'Operator&': ' return *this;', - 'Ptr': ' return nullptr;', - 'std::string': ' return "";', - 'std::string&': ' return "";', - 'string': ' return "";', - 'int': ' return 0;', - 'DataType': ' return DT_FLOAT;', - 'InferenceContextPtr': ' return nullptr;', - 'SubgraphBuilder': ' return nullptr;', - 'OperatorImplPtr': ' return nullptr;', - 'OutHandler': ' return nullptr;', - 'std::vector': ' return {};', - 'std::vector': ' return {};', - 'std::map': ' return {};', - 'uint32_t': ' return 0;', - 'int64_t': ' return 0;', - 'uint64_t': ' return 0;', - 'size_t': ' return 0;', - 'float': ' return 0.0f;', - 'bool': ' return false;', -} - -""" - max code len per line in hua_wei software programming specifications -""" -max_code_len_per_line = 100 - -""" - white_list_for_debug, include_dir_key_words is to - determines which header files to generate cc files from - when DEBUG on -""" -white_list_for_debug = ["operator.h", "tensor.h", - "graph.h", "operator_factory.h", - "ge_ir_build.h"] -include_dir_key_words = ["ge", "graph"] -DEBUG = True - - -def need_generate_func(func_line): - """ - :param func_line: - :return: - """ - if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ - or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): - return False - return True - - -def file_endswith_white_list_suffix(file): - """ - :param file: - :return: - """ - if DEBUG: - for suffix in white_list_for_debug: - if file.endswith(suffix): - return True - return False - else: - return True - - -""" - belows are patterns used for analyse .h file -""" -# pattern function -pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after -([a-zA-Z~_] # void int likely -.* -[)] #we find ) -(?!.*{) # we do not want the case int abc() const { return 1;} -.*) -(;.*) #we want to find ; and after for we will replace these later -\n$ -""", re.VERBOSE | re.MULTILINE | re.DOTALL) - -# pattern comment -pattern_comment = re.compile(r'^\s*//') -pattern_comment_2_start = re.compile(r'^\s*/[*]') -pattern_comment_2_end = re.compile(r'[*]/\s*$') -# pattern define -pattern_define = re.compile(r'^\s*#define') -pattern_define_return = re.compile(r'\\\s*$') -# blank line -pattern_blank_line = re.compile(r'^\s*$') -# virtual,explicit,friend,static -pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') -# lead space -pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') -# functions will have patterns such as func ( or func( -# but operator is an exception; the class name is preceded by an operator, and the above mode does not exist -# format like :"operator = ()" -pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') -# template -pattern_template = re.compile(r'^\s*template') -pattern_template_end = re.compile(r'>\s*$') -# namespace -pattern_namespace = re.compile(r'namespace.*{') -# class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with -pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+ 0 and not friend_match: - line, func_name = self.handle_class_member_func(line, template_string) - # Normal functions - else: - line, func_name = self.handle_normal_func(line, template_string) - - need_generate = need_generate_func(line) - # func body - line += self.implement_function(line) - # comment - line = self.gen_comment(start_i) + line - # write to out file - self.write_func_content(line, func_name, need_generate) - # next loop - self.line_index += 1 - - logging.info('Added %s functions', len(self.func_list_exist)) - logging.info('Successfully converted,please see ' + self.output_file) - - def handle_func1(self, line): - """ - :param line: - :return: - """ - find1 = re.search('[(]', line) - if not find1: - self.line_index += 1 - return "continue", line, None - find2 = re.search('[)]', line) - start_i = self.line_index - space_match = pattern_leading_space.search(line) - # deal with - # int abc(int a, - # int b) - if find1 and (not find2): - self.line_index += 1 - line2 = self.input_content[self.line_index] - if space_match: - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): - self.line_index += 1 - line2 = self.input_content[self.line_index] - line2 = re.sub('^' + space_match.group(1), '', line2) - line += line2 - - match_start = pattern_start.search(self.input_content[self.line_index]) - match_end = pattern_end.search(self.input_content[self.line_index]) - if match_start: # like ) { or ) {} int the last line - if not match_end: - self.stack.append('normal_now') - ii = start_i - while ii <= self.line_index: - ii += 1 - self.line_index += 1 - return "continue", line, start_i - logging.info("line[%s]", line) - # ' int abc();'->'int abc()' - (line, match) = pattern_func.subn(r'\2\n', line) - logging.info("line[%s]", line) - # deal with case: - # 'int \n abc(int a, int b)' - if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): - line = self.input_content[start_i - 1] + line - line = line.lstrip() - if not match: - self.line_index += 1 - return "continue", line, start_i - return "pass", line, start_i - - def handle_stack(self, match_start): - """ - :param match_start: - :return: - """ - line = self.input_content[self.line_index] - match_end = pattern_end.search(line) - if match_start: - self.stack.append('normal_now') - if match_end: - top_status = self.stack.pop() - if top_status == 'namespace_now': - self.output_fd.write(line + '\n') - elif top_status == 'class_now': - self.stack_class.pop() - self.stack_template.pop() - if match_start or match_end: - self.line_index += 1 - return "continue" - - if len(self.stack) > 0 and self.stack[-1] == 'normal_now': - self.line_index += 1 - return "continue" - return "pass" - - def handle_class(self, template_string, line, match_start, match_class): - """ - :param template_string: - :param line: - :param match_start: - :param match_class: - :return: - """ - if match_class: # we face a class - self.stack_template.append(template_string) - self.stack.append('class_now') - class_name = match_class.group(3) - - # class template specializations: class A > - if '<' in class_name: - k = line.index('<') - fit = 1 - for ii in range(k + 1, len(line)): - if line[ii] == '<': - fit += 1 - if line[ii] == '>': - fit -= 1 - if fit == 0: - break - class_name += line[k + 1:ii + 1] - logging.info('class_name[%s]', class_name) - self.stack_class.append(class_name) - while not match_start: - self.line_index += 1 - line = self.input_content[self.line_index] - match_start = pattern_start.search(line) - self.line_index += 1 - return "continue" - return "pass" - - def handle_template(self): - line = self.input_content[self.line_index] - match_template = pattern_template.search(line) - template_string = '' - if match_template: - match_template_end = pattern_template_end.search(line) - template_string = line - while not match_template_end: - self.line_index += 1 - line = self.input_content[self.line_index] - template_string += line - match_template_end = pattern_template_end.search(line) - self.line_index += 1 - return template_string - - def handle_namespace(self): - line = self.input_content[self.line_index] - match_namespace = pattern_namespace.search(line) - if match_namespace: # we face namespace - self.output_fd.write(line + '\n') - self.stack.append('namespace_now') - self.line_index += 1 - - def handle_normal_func(self, line, template_string): - template_line = '' - self.stack_template.append(template_string) - if self.stack_template[-1] != '': - template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) - # change '< class T = a, class U = A(3)>' to '' - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = re.sub(r'\s*=.*,', ',', line) - line = re.sub(r'\s*=.*\)', ')', line) - line = template_line + line - self.stack_template.pop() - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def handle_class_member_func(self, line, template_string): - template_line = '' - x = '' - if template_string != '': - template_string = re.sub(r'\s*template', 'template', template_string) - template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) - template_string = re.sub(r'\s*=.*,', ',', template_string) - template_string = re.sub(r'\s*=.*', '', template_string) - if self.stack_template[-1] != '': - if not (re.search(r'<\s*>', stack_template[-1])): - template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) - if not (re.search(r'<.*>', self.stack_class[-1])): - # for x we get like template -> - x = re.sub(r'template\s*<', '<', template_line) # remove template -> - x = re.sub(r'\n', '', x) - x = re.sub(r'\s*=.*,', ',', x) - x = re.sub(r'\s*=.*\>', '>', x) - x = x.rstrip() # remove \n - x = re.sub(r'(class|typename)\s+|(|\s*class)', '', - x) # remove class,typename -> - x = re.sub(r'<\s+', '<', x) - x = re.sub(r'\s+>', '>', x) - x = re.sub(r'\s+,', ',', x) - x = re.sub(r',\s+', ', ', x) - line = re.sub(r'\s*=\s+0', '', line) - line = re.sub(r'\s*=\s+.*,', ',', line) - line = re.sub(r'\s*=\s+.*\)', ')', line) - logging.info("x[%s]\nline[%s]", x, line) - # if the function is long, void ABC::foo() - # breaks into two lines void ABC::\n foo() - temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) - if len(temp_line) > max_code_len_per_line: - line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) - else: - line = temp_line - logging.info("line[%s]", line) - # add template as the above if there is one - template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) - template_line = re.sub(r'\s*=.*,', ',', template_line) - template_line = re.sub(r'\s*=.*', '', template_line) - line = template_line + template_string + line - func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() - logging.info("line[%s]", line) - logging.info("func_name[%s]", func_name) - return line, func_name - - def write_func_content(self, content, func_name, need_generate): - if not (func_name in self.func_list_exist) and need_generate: - self.output_fd.write(content) - self.func_list_exist.append(func_name) - logging.info('add func:[%s]', func_name) - - def gen_comment(self, start_i): - comment_line = '' - # Function comments are on top of function declarations, copy them over - k = start_i - 1 # one line before this func start - if pattern_template.search(self.input_content[k]): - k -= 1 - if pattern_comment_2_end.search(self.input_content[k]): - comment_line = self.input_content[k].lstrip() - while not pattern_comment_2_start.search(self.input_content[k]): - k -= 1 - comment_line = self.input_content[k].lstrip() + comment_line - else: - for j in range(k, 0, -1): - c_line = self.input_content[j] - if pattern_comment.search(c_line): - c_line = re.sub(r'\s*//', '//', c_line) - comment_line = c_line + comment_line - else: - break - return comment_line - - @staticmethod - def implement_function(func): - function_def = '' - function_def += '{\n' - - all_items = func.split() - start = 0 - return_type = all_items[start] - if return_type == "const": - start += 1 - return_type = all_items[start] - if return_type.startswith(('std::map', 'std::set', 'std::vector')): - return_type = "std::map" - if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): - return_type = "Ptr" - if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): - return_type += "&" - if RETURN_STATEMENTS.__contains__(return_type): - function_def += RETURN_STATEMENTS[return_type] - else: - logging.warning("Unhandled return type[%s]", return_type) - - function_def += '\n' - function_def += '}\n' - function_def += '\n' - return function_def - - -def collect_header_files(path): - """ - :param path: - :return: - """ - header_files = [] - shared_includes_content = [] - for root, dirs, files in os.walk(path): - files.sort() - for file in files: - if file.find("git") >= 0: - continue - if not file.endswith('.h'): - continue - file_path = os.path.join(root, file) - file_path = file_path.replace('\\', '/') - header_files.append(file_path) - include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) - shared_includes_content.append(include_str) - return header_files, shared_includes_content - - -def generate_stub_file(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - target_header_files, shared_includes_content = collect_header_files(inc_dir) - for header_file in target_header_files: - if not file_endswith_white_list_suffix(header_file): - continue - cc_file = re.sub('.h*$', '.cc', header_file) - h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) - h_2_cc.h2cc() - - -def gen_code(inc_dir, out_cc_dir): - """ - :param inc_dir: - :param out_cc_dir: - :return: - """ - if not inc_dir.endswith('/'): - inc_dir += '/' - if not out_cc_dir.endswith('/'): - out_cc_dir += '/' - for include_dir_key_word in include_dir_key_words: - generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) - - -if __name__ == '__main__': - inc_dir = sys.argv[1] - out_cc_dir = sys.argv[2] - gen_code(inc_dir, out_cc_dir) diff --git a/tests/depends/cce/src/op_kernel_registry.cc b/tests/depends/cce/src/op_kernel_registry.cc index 9bb32a31..5ccd1391 100644 --- a/tests/depends/cce/src/op_kernel_registry.cc +++ b/tests/depends/cce/src/op_kernel_registry.cc @@ -1,19 +1,3 @@ -/** - * Copyright 2019 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - #include "register/op_kernel_registry.h" namespace ge { diff --git a/tests/st/resnet50/common.cc b/tests/st/resnet50/common.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/graph/passes/flow_ctrl_pass_unittest.cc b/tests/ut/ge/graph/passes/flow_ctrl_pass_unittest.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/graph/passes/folding_kernel/expanddims_kernel_unittest.cc b/tests/ut/ge/graph/passes/folding_kernel/expanddims_kernel_unittest.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/graph/passes/merge_pass_unittest.cc b/tests/ut/ge/graph/passes/merge_pass_unittest.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/graph/passes/net_output_pass_unittest.cc b/tests/ut/ge/graph/passes/net_output_pass_unittest.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/graph/passes/snapshot_pass_unittest.cc b/tests/ut/ge/graph/passes/snapshot_pass_unittest.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/single_op/single_op_manager_unittest.cc b/tests/ut/ge/single_op/single_op_manager_unittest.cc old mode 100755 new mode 100644 diff --git a/tests/ut/ge/single_op/single_op_model_unittest.cc b/tests/ut/ge/single_op/single_op_model_unittest.cc old mode 100755 new mode 100644 diff --git a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h index 35134faa..023812dd 100644 --- a/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h +++ b/third_party/fwkacllib/inc/cce/fwk_adpt_struct.h @@ -37,6 +37,7 @@ enum FWKAdptAPIRetCode { FWK_ADPT_SESSION_NOT_EXIST = 10, // session id not exist FWK_ADPT_SESSION_ALREADY_EXIST = 11, // session id alread exist for create session FWK_ADPT_NATIVE_END_OF_SEQUENCE = 12, // end of sequence + FWK_ADPT_EXTEND_TYPE_NOT_EXIST = 13, // extend info type not exist FWK_ADPT_UNKNOWN_ERROR = 99 // unknown error code }; @@ -55,9 +56,17 @@ enum FWKTaskExtInfoType { FWK_ADPT_EXT_SHAPE_TYPE = 0, FWK_ADPT_EXT_INPUT_SHAPE, FWK_ADPT_EXT_OUTPUT_SHAPE, + FWK_ADPT_EXT_UPDATE_ADDR, FWK_ADPT_EXT_INVALID }; +enum FWKExtUpdateAddrType { + FWK_ADPT_UPDATE_NULL = 0, + FWK_ADPT_UPDATE_INPUT, + FWK_ADPT_UPDATE_OUTPUT, + FWK_ADPT_UPDATE_INPUT_OUTPUT +}; + // API Parameter Structure struct StrFWKKernel { FWKOperateType opType; diff --git a/third_party/fwkacllib/inc/hccl/base.h b/third_party/fwkacllib/inc/hccl/base.h index 74163baf..1d83d7bf 100644 --- a/third_party/fwkacllib/inc/hccl/base.h +++ b/third_party/fwkacllib/inc/hccl/base.h @@ -102,6 +102,11 @@ typedef enum tagHcclDataType { HCCL_DATA_TYPE_RESERVED /**< reserved */ } hcclDataType_t; +constexpr u32 HCCL_UNIQUE_ID_BYTES = 2060; // 2060: unique id length +using hcclUniqueId = struct hcclUniqueIdDef { + char internal[HCCL_UNIQUE_ID_BYTES]; +}; + const u32 HCCL_MAX_SEGMENT_NUM = 8; // The max number of gradient segments. /** @@ -120,6 +125,12 @@ enum GradSplitForceMode { FORCE_RESERVED /**< reserved */ }; +enum OriginalGraphShapeType { + KNOWN_SHAPE, + UNKNOWN_SHAPE, + SHAPE_RESERVED /**< reserved */ +}; + /** * @brief stream handle. */ diff --git a/third_party/fwkacllib/inc/hccl/hcom.h b/third_party/fwkacllib/inc/hccl/hcom.h index a448d411..19bf4fb3 100644 --- a/third_party/fwkacllib/inc/hccl/hcom.h +++ b/third_party/fwkacllib/inc/hccl/hcom.h @@ -22,7 +22,6 @@ #ifndef HCOM_H_ #define HCOM_H_ -#include #include #ifdef __cplusplus @@ -246,8 +245,9 @@ hcclResult_t hcom_receive(const char *tag, void *outputPtr, u64 count, hcclDataT * @param segmentIdx A list identifying the index of end gradient in each segment. * @return hcclResult_t */ -hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, - u32 maxSegmentNum, u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE); +hcclResult_t hcom_get_split_strategy(const char *group, const struct model_feature *feature, u32 maxSegmentNum, + u32 *segmentNum, u32 *segmentIdx, GradSplitForceMode force = FORCE_NONE, + OriginalGraphShapeType shapeType = KNOWN_SHAPE); /** * @brief Set the gradient split strategy with in the group, according to gradient index. diff --git a/third_party/fwkacllib/inc/mmpa/mmpa_api.h b/third_party/fwkacllib/inc/mmpa/mmpa_api.h index f1e30538..ce1c9720 100644 --- a/third_party/fwkacllib/inc/mmpa/mmpa_api.h +++ b/third_party/fwkacllib/inc/mmpa/mmpa_api.h @@ -20,7 +20,7 @@ #define LINUX 0 #define WIN 1 -#if(OS_TYPE == LINUX) +#if(OS_TYPE == LINUX) //lint !e553 #ifndef _GNU_SOURCE #define _GNU_SOURCE @@ -84,7 +84,7 @@ #endif -#if(OS_TYPE == WIN) +#if(OS_TYPE == WIN) //lint !e553 #include #include #include "Windows.h" diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h index ce83d143..6ac8f8f6 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_linux.h @@ -344,6 +344,8 @@ extern INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); extern INT32 mmDup2(INT32 oldFd, INT32 newFd); +extern INT32 mmDup(INT32 fd); + extern INT32 mmUnlink(const CHAR *filename); extern INT32 mmChmod(const CHAR *filename, INT32 mode); diff --git a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h index ef15f371..68a70c27 100644 --- a/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h +++ b/third_party/fwkacllib/inc/mmpa/sub_inc/mmpa_win.h @@ -378,6 +378,7 @@ _declspec(dllexport) INT32 mmGetRealPath(CHAR *path, CHAR *realPath); _declspec(dllexport) INT32 mmRealPath(const CHAR *path, CHAR *realPath, INT32 realPathLen); _declspec(dllexport) INT32 mmDup2(INT32 oldFd, INT32 newFd); +_declspec(dllexport) INT32 mmDup(INT32 fd); _declspec(dllexport) INT32 mmUnlink(const CHAR *filename); _declspec(dllexport) INT32 mmChmod(const CHAR *filename, INT32 mode); _declspec(dllexport) INT32 mmFileno(FILE *stream); diff --git a/third_party/fwkacllib/inc/ops/all_ops.h b/third_party/fwkacllib/inc/ops/all_ops.h index 031e955c..c30bf32b 100644 --- a/third_party/fwkacllib/inc/ops/all_ops.h +++ b/third_party/fwkacllib/inc/ops/all_ops.h @@ -31,7 +31,9 @@ #include "functional_ops.h" #include "get_data_ops.h" #include "hcom_ops.h" +#include "hvd_ops.h" #include "image_ops.h" +#include "internal_ops.h" #include "linalg_ops.h" #include "logging_ops.h" #include "lookup_ops.h" diff --git a/third_party/fwkacllib/inc/ops/array_ops.h b/third_party/fwkacllib/inc/ops/array_ops.h index 0d2a05a3..7c6f9b2c 100644 --- a/third_party/fwkacllib/inc/ops/array_ops.h +++ b/third_party/fwkacllib/inc/ops/array_ops.h @@ -1084,6 +1084,43 @@ REG_OP(TransShape) .ATTR(outShape,ListInt ,{}) .OP_END_FACTORY_REG(TransShape); +/** +*@brief Computes the (possibly normalized) Levenshtein Edit Distance. + +*@par Inputs: +*@li hypothesis_indices: The indices of the hypothesis list SparseTensor.\n +This is an N x R int64 matrix. +*@li hypothesis_shape: The values of the hypothesis list SparseTensor.\n +This is an N-length vector. +*@li hypothesis_shape: The shape of the hypothesis list SparseTensor.\n +This is an R-length vector. +*@li truth_indices: The indices of the truth list SparseTensor.\n +This is an M x R int64 matrix. +*@li truth_shape: The values of the truth list SparseTensor.\n +This is an M-length vector. +*@li truth_shape: The shape of the truth list SparseTensor.\n +This is an R-length vector + +*@par Attributes: +*@li normalize: boolean (if true, edit distances are normalized by length of truth). + +*@par Outputs: +*@li output: A dense float tensor with rank R - 1. + +*@par Third-party framework compatibility +* Compatible with TensorFlow EditDistance operator. +*/ +REG_OP(EditDistance) + .INPUT(hypothesis_indices, TensorType({DT_INT64})) + .INPUT(hypothesis_values, TensorType::BasicType()) + .INPUT(hypothesis_shape, TensorType({DT_INT64})) + .INPUT(truth_indices, TensorType({DT_INT64})) + .INPUT(truth_values, TensorType::BasicType()) + .INPUT(truth_shape, TensorType({DT_INT64})) + .ATTR(normalize, Bool, true) + .OUTPUT(output, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(EditDistance) + } // namespace ge #endif // GE_OP_ARRAY_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/ctc_ops.h b/third_party/fwkacllib/inc/ops/ctc_ops.h index 00485a14..74b797f3 100644 --- a/third_party/fwkacllib/inc/ops/ctc_ops.h +++ b/third_party/fwkacllib/inc/ops/ctc_ops.h @@ -50,7 +50,6 @@ If not specified, defaults to true *@par Third-party framework compatibility * Compatible with TensorFlow CTCLoss operator. */ - REG_OP(CTCLoss) .INPUT(inputs, TensorType({DT_FLOAT, DT_DOUBLE})) .INPUT(labels_indices, TensorType({DT_INT64})) @@ -63,6 +62,77 @@ REG_OP(CTCLoss) .ATTR(ignore_longer_outputs_than_inputs, Bool, false) .OP_END_FACTORY_REG(CTCLoss) +/** +*@brief Performs greedy decoding on the logits given in inputs. + +*@par Inputs: +*@li inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +*@li sequence_length: A vector containing sequence lengths, size `(batch_size)`. + +*@par Attributes: +*@li merge_repeated: If True, merge repeated classes in output. + +*@par Outputs: +*@li decoded_indices: Indices matrix, size `(total_decoded_outputs x 2)`,\n +of a `SparseTensor`. The rows store: [batch, time]. +*@li decoded_values: Values vector, size: `(total_decoded_outputs)`,\n +of a `SparseTensor`. The vector stores the decoded classes. +*@li decoded_shape: Shape vector, size `(2)`, of the decoded SparseTensor.\n +Values are: `[batch_size, max_decoded_length]`. +*@li log_probability: Matrix, size `(batch_size x 1)`, containing sequence\n +log-probabilities. + +*@par Third-party framework compatibility +* Compatible with TensorFlow CTCGreedyDecoder operator. +*/ +REG_OP(CTCGreedyDecoder) + .INPUT(inputs, TensorType({DT_FLOAT, DT_DOUBLE})) + .INPUT(sequence_length, TensorType({DT_INT32})) + .ATTR(merge_repeated, Bool, false) + .OUTPUT(decoded_indices, TensorType({DT_INT64})) + .OUTPUT(decoded_values, TensorType({DT_INT64})) + .OUTPUT(decoded_shape, TensorType({DT_INT64})) + .OUTPUT(log_probability, TensorType({DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(CTCGreedyDecoder) + +/** +*@brief Performs beam search decoding on the logits given in input. + +*@par Inputs: +*@li inputs: 3-D, shape: `(max_time x batch_size x num_classes)`, the logits. +*@li sequence_length: A vector containing sequence lengths, size `(batch_size)`. + +*@par Attributes: +*@li merge_repeated: If True, merge repeated classes in output. + +*@par Outputs: +*@li decoded_indices: A list (length: top_paths) of indices matrices. Matrix j,\n +size `(total_decoded_outputs[j] x 2)`, has indices of a\n +`SparseTensor`. The rows store: [batch, time]. +*@li decoded_values: A list (length: top_paths) of values vectors. Vector j,\n +size `(length total_decoded_outputs[j])`, has the values of a\n +`SparseTensor`. The vector stores the decoded classes for beam j. +*@li decoded_shape: A list (length: top_paths) of shape vector. Vector j,\n +size `(2)`, stores the shape of the decoded `SparseTensor[j]`.\n +Its values are: `[batch_size, max_decoded_length[j]]`. +*@li log_probability: A matrix, shaped: `(batch_size x top_paths)`. The\n +sequence log-probabilities. + +*@par Third-party framework compatibility +* Compatible with TensorFlow CTCBeamSearchDecoder operator. +*/ +REG_OP(CTCBeamSearchDecoder) + .INPUT(inputs, TensorType({DT_FLOAT, DT_DOUBLE})) + .INPUT(sequence_length, TensorType({DT_INT32})) + .REQUIRED_ATTR(beam_width, Int) + .REQUIRED_ATTR(top_paths, Int) + .ATTR(merge_repeated, Bool, true) + .DYNAMIC_OUTPUT(decoded_indices, TensorType({DT_INT64})) + .DYNAMIC_OUTPUT(decoded_values, TensorType({DT_INT64})) + .DYNAMIC_OUTPUT(decoded_shape, TensorType({DT_INT64})) + .OUTPUT(log_probability, TensorType({DT_FLOAT, DT_DOUBLE})) + .OP_END_FACTORY_REG(CTCBeamSearchDecoder) + } // namespace ge #endif //GE_OP_CTC_OPS_H \ No newline at end of file diff --git a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h index 04e1cea3..1022880f 100644 --- a/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/elewise_calculation_ops.h @@ -483,9 +483,9 @@ REG_OP(Equal) *x: A Tensor. Must be one of the following types: float16, float32, double, complex64, complex128. *@par Attributes: -*@li base: An optional attribute of type float32, specifying the base gamma. Defaults to "-1". -*@li scale: An optional attribute of type float32, specifying the scale alpha. Defaults to "1". -*@li shift: An optional attribute of type float32, specifying the shift beta. Defaults to "0". +*@li base: An optional attribute of type float32, specifying the base gamma. Defaults to "-1.0". +*@li scale: An optional attribute of type float32, specifying the scale alpha. Defaults to "1.0". +*@li shift: An optional attribute of type float32, specifying the shift beta. Defaults to "0.0". *@par Outputs: *y: A Tensor of the same type as "x". @@ -1016,17 +1016,17 @@ REG_OP(BesselI1e) * y = log_base(shift + scale * x), with "base" > 0. * @par Inputs: -* @li x: A Tensor of type UnaryDataType. +* @li x: A Tensor of type complex64, complex128, float16, float32 or double. * @par Attributes: -* @li base: An optional float32, specifying the base "e". Defaults to "-1" +* @li base: An optional float32, specifying the base "e". Defaults to "-1.0" * @li scale: An optional float32, specifying the scale of input "x". Defaults -* to "1" -* @li shift: An optional float32, specifying the shift. Defaults to "0" +* to "1.0" +* @li shift: An optional float32, specifying the shift. Defaults to "0.0" * @par Outputs: -* y: A Tensor of type UnaryDataType. +* y: A Tensor has same type as "x". * @attention Constraints: * @li "base" is supposed to be greater than 0. Retaining the default @@ -1100,12 +1100,6 @@ REG_OP(SqrtGrad) .OUTPUT(z, TensorType(UnaryDataType)) .OP_END_FACTORY_REG(SqrtGrad) -REG_OP(Multiply) - .INPUT(x, TensorType({DT_FLOAT,DT_UINT8,DT_INT8,DT_UINT16,DT_INT16,DT_INT32,DT_INT64,DT_DOUBLE,DT_FLOAT16})) - .INPUT(y, TensorType({DT_FLOAT,DT_UINT8,DT_INT8,DT_UINT16,DT_INT16,DT_INT32,DT_INT64,DT_DOUBLE,DT_FLOAT16})) - .OUTPUT(z, TensorType({DT_FLOAT,DT_UINT8,DT_INT8,DT_UINT16,DT_INT16,DT_INT32,DT_INT64,DT_DOUBLE,DT_FLOAT16})) - .OP_END_FACTORY_REG(Multiply) - /** *@brief Returns x + y element-wise. *@par Inputs: @@ -2262,7 +2256,7 @@ REG_OP(ArgMinD) *dtype: The output type, either "int32" or "int64". Defaults to "int64". *@par Outputs: -*y: A multi-dimensional Tensor of type int32, specifying the index with the largest value. The dimension is one less than that of "x". +*y: A multi-dimensional Tensor of type int32 or int64, specifying the index with the largest value. The dimension is one less than that of "x". *@attention Constraints: *@li x: If there are multiple maximum values, the index of the first maximum value is used. @@ -2398,8 +2392,8 @@ REG_OP(ArgMinWithValue) *y: A Tensor. Has the same type and format as "x". *@par Attributes: -*@li N: A required attribute. the number of input x, max size is 32. -*@li model: An optional attribute. Defaults to "1". +*@li N: A required attribute. the number of input x, max size is 32. Type is int. +*@li model: An optional attribute. Type is int. Defaults to "1". * "0": product, "1": sum, "2": max. *@li coeff: A required attribute. Must met all of following rules: * size of "coeff" must be equal to len("x") or is null. @@ -2692,6 +2686,86 @@ REG_OP(AdamApplyOne) .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) .OP_END_FACTORY_REG(AdamApplyOne) +/** +*@brief A fusion operator for bert lamb. + +*@par Inputs: +*Eleven inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul4_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. + +*/ +REG_OP(AdamApplyOneWithDecayAssign) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul4_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOneWithDecayAssign) + +/** +*@brief A fusion operator for bert lamb. + +*@par Inputs: +*Ten inputs, including: +* @li input0: A Tensor. Must be one of the following types: float16, float32. +* @li input1: A Tensor. Must be one of the following types: float16, float32. +* @li input2: A Tensor. Must be one of the following types: float16, float32. +* @li input3: A Tensor. Must be one of the following types: float16, float32. +* @li input4: A Tensor. Must be one of the following types: float16, float32. +* @li mul0_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul1_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul2_x: A Tensor. Must be one of the following types: float16, float32. +* @li mul3_x: A Tensor. Must be one of the following types: float16, float32. +* @li add2_y: A Tensor. Must be one of the following types: float16, float32. + +*@par Outputs: +*Three outputs, including: +* @li output0: A Tensor. Must be one of the following types: float16, float32. +* @li output1: A Tensor. Must be one of the following types: float16, float32. +* @li output2: A Tensor. Must be one of the following types: float16, float32. + +*/ +REG_OP(AdamApplyOneAssign) + .INPUT(input0, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input1, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input2, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input3, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(input4, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul0_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul1_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul2_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(mul3_x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(add2_y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output0, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output1, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(output2, TensorType({DT_FLOAT16,DT_FLOAT})) + .OP_END_FACTORY_REG(AdamApplyOneAssign) + /** *@brief Confuse select, maximum, greater and sqrt. @@ -2781,22 +2855,19 @@ REG_OP(SquareSumAll) *@brief Confuse broadcast, addn and mul. *@par Inputs: -*Five inputs, including: -* @li x1: A Tensor. Must be one of the following types:int32 float16, float32. +*Three inputs, including: +* @li x1: A Tensor. Must be one of the following types:int32, int16, float16, float32. * @li x2: A Tensor of the same type as "x1". * @li x3: A Tensor of the same type as "x1". *@par Outputs: -*@li y: A Tensor. Has the same type as "x1". - -*@par Third-party framework compatibility: -* Compatible with the TensorFlow operator LRN. +* y: A Tensor. Has the same type as "x1". */ REG_OP(FusedMulAddN) - .INPUT(x1, TensorType::NumberType()) - .INPUT(x2, TensorType::NumberType()) - .INPUT(x3, TensorType::NumberType()) - .OUTPUT(y, TensorType::NumberType()) + .INPUT(x1, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .INPUT(x2, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .INPUT(x3, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) .OP_END_FACTORY_REG(FusedMulAddN) /** @@ -3042,6 +3113,22 @@ REG_OP(KLDiv) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OP_END_FACTORY_REG(KLDiv) +/** +*@brief copy data from x to y.. + +*@par Inputs: +*One inputs, including: +* @li x: A Tensor. Must be one of the following types: float16, float32, int8, uint8, int32, bool. + +*@par Outputs: +*y: A Tensor. Has the same type as "x". + +*@par Third-party framework compatibility +*/ +REG_OP(TensorMove) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8, DT_BOOL})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT8, DT_UINT8, DT_BOOL})) + .OP_END_FACTORY_REG(TensorMove) } // namespace ge diff --git a/third_party/fwkacllib/inc/ops/hvd_ops.h b/third_party/fwkacllib/inc/ops/hvd_ops.h new file mode 100644 index 00000000..09748b8e --- /dev/null +++ b/third_party/fwkacllib/inc/ops/hvd_ops.h @@ -0,0 +1,77 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OP_HVD_OPS_H_ +#define GE_OP_HVD_OPS_H_ + +#include "graph/operator_reg.h" + +namespace ge { +/** + * @brief Outputs a tensor gathering all input tensors. + * @par Inputs: + * x: A tensor. Must be one of the following types: uint8, int8, uint16, int16, int32, + * int64, float16, bool. + * @par Attributes: + * @li rank_size: A required integer identifying the number of ranks + * participating in the op. + * @par Outputs: + * y: A Tensor. Has the same type as "x". + */ +REG_OP(HorovodAllgather) + // GE not support float64 currently + .INPUT(x, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + .OUTPUT(y, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + // add rank_size attr + .REQUIRED_ATTR(rank_size, Int) + .OP_END_FACTORY_REG(HorovodAllgather) + +/** + * @brief Outputs a tensor containing the reduction across all input tensors + * passed to op. + * @par Inputs: + * x: A tensor. Must be one of the following types: int32, int64, float16, float32 + * @par Attributes: + * @li reduce_op: A required int identifying the reduction operation to + * perform.The supported operation are: "sum", "max", "min", "prod". + * @par Outputs: + * y: A Tensor. Has the same type as "x". + */ +REG_OP(HorovodAllreduce) + .INPUT(x, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT})) + .REQUIRED_ATTR(reduce_op, Int) + .OP_END_FACTORY_REG(HorovodAllreduce) + +/** + * @brief Broadcasts the input tensor in root rank to all ranks. + * @par Inputs: + * x: A list of dynamic input tensor. Must be one of the following types: + * int8, int32, float16, float32. + * @par Attributes: + * @li root_rank: A required integer identifying the root rank in the op + * input of this rank will be broadcast to other ranks. + * @par Outputs: + * y: A list of dynamic output tensor. Has the same type and length as "x". + */ +REG_OP(HorovodBroadcast) + .INPUT(x, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + .OUTPUT(y, TensorType({DT_UINT8, DT_INT8, DT_UINT16, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_BOOL})) + .REQUIRED_ATTR(root_rank, Int) + .OP_END_FACTORY_REG(HorovodBroadcast) + +} // namespace ge +#endif // GE_OP_HVD_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/image_ops.h b/third_party/fwkacllib/inc/ops/image_ops.h index f5ddaf5e..59b99841 100644 --- a/third_party/fwkacllib/inc/ops/image_ops.h +++ b/third_party/fwkacllib/inc/ops/image_ops.h @@ -990,6 +990,40 @@ REG_OP(ResizeBilinearV2D) .REQUIRED_ATTR(size, ListInt) .OP_END_FACTORY_REG(ResizeBilinearV2D) +/** +*@brief Resizes "images" to "size" using bilinear interpolation and keep ration at the time. + +*@par Inputs: +* One input: +*images: An NC1HWC0 Tensor. \n +* Must be one of the following types: float16, float32. + +*@par Attributes: +*@li min_dimension: A required int32 attribute for the min dimension for the images. +* No default value. +*@li max_dimension: A required int32 attribute for the max dimension for the images. +* No default value. +*@li align_corners: An optional bool. If "true", the centers of the corner +* pixels of the input and output tensors are aligned. Defaults to "false". +*@li half_pixel_centers: indicates if the offset coordinates are normalized +* Defaults to "false". + +*@par Outputs: +*y: A Tensor with type float32 and the same format as input "images". + +*@attention Constraints: +* The input "images" must be a tensor of 5 elements: images[2] <= 2048, \n +images[3] <= 2048. +*/ +REG_OP(KeepRationResizeBilinear) + .INPUT(images, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT})) + .REQUIRED_ATTR(min_dimension, Int) + .REQUIRED_ATTR(max_dimension, Int) + .ATTR(align_corners, Bool, false) + .ATTR(half_pixel_centers, Bool, false) + .OP_END_FACTORY_REG(KeepRationResizeBilinear) + /** *@brief Resizes "images" to "size" using nearest neighbor interpolation. diff --git a/third_party/fwkacllib/inc/ops/internal_ops.h b/third_party/fwkacllib/inc/ops/internal_ops.h new file mode 100644 index 00000000..8c261382 --- /dev/null +++ b/third_party/fwkacllib/inc/ops/internal_ops.h @@ -0,0 +1,61 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef GE_OP_INTERNAL_OPS_H_ +#define GE_OP_INTERNAL_OPS_H_ + +#include "graph/operator_reg.h" +#include "graph/operator.h" + +namespace ge { + +/** +*@brief aicpu assit help op for auxiliary matrix generation. + +*@par Inputs: +*The input is dynamic for attribute func_name \n + +*@par Attributes: +*@li func_name:An required param, for example "topkv2". \n + +*@par Outputs: +*The output is dynamic for attribute func_name. +*/ +REG_OP(AssistHelp) + .DYNAMIC_INPUT(x, TensorType({ DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE })) + .DYNAMIC_OUTPUT(y, TensorType({ DT_FLOAT, DT_FLOAT16, DT_INT8, DT_INT16, DT_UINT16, + DT_UINT8, DT_INT32, DT_INT64, DT_UINT32, DT_UINT64, DT_BOOL, DT_DOUBLE})) + . REQUIRED_ATTR (func_name, String) + . OP_END_FACTORY_REG(AssistHelp) + +/** +*@brief aicpu cache help for lhisi cache flush. + +*@par Inputs: +*The input is dynamic for attribute func_name \n + +*@par Outputs: +*The output is dynamic for attribute func_name. +*/ +REG_OP(CacheUpdate) + .INPUT(x, TensorType::BasicType()) + .OUTPUT(x, TensorType::BasicType()) + .OP_END_FACTORY_REG(CacheUpdate) + +} // namespace ge + +#endif // GE_OP_INTERNAL_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/math_ops.h b/third_party/fwkacllib/inc/ops/math_ops.h index 5d34804c..b0c35c28 100644 --- a/third_party/fwkacllib/inc/ops/math_ops.h +++ b/third_party/fwkacllib/inc/ops/math_ops.h @@ -29,9 +29,9 @@ namespace ge { * x: A Tensor of type float16 or float32. *@par Attributes: -*@li power: Optional. Defaults to 1.0. -*@li scale: Optional. Defaults to 1.0. -*@li shift: Optional. Defaults to 0.0. +*@li power: Optional. Must be one of the following types: float32. Defaults to 1.0. +*@li scale: Optional. Must be one of the following types: float32. Defaults to 1.0. +*@li shift: Optional. Must be one of the following types: float32. Defaults to 0.0. *@par Outputs: * y: A Tensor. Has the same type and shape as "x". diff --git a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h index 29cf0df3..7cb24ee7 100644 --- a/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/matrix_calculation_ops.h @@ -167,50 +167,6 @@ REG_OP(BatchMatMul) .ATTR(adj_x2, Bool, false) .OP_END_FACTORY_REG(BatchMatMul) -REG_OP(MeanCCE) - .INPUT(x, TensorType::ALL()) - .INPUT(indices, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .ATTR(keep_dims, Bool, false) - .ATTR(value1, ListInt, {}) - .ATTR(mode, Int, 3) // 0:max pooling or 1:avg pooling - .ATTR(pad_mode, Int, 0) - .ATTR(global_pooling, Bool, true) // tensorflow have no attr, set default value - .ATTR(window, ListInt, {1,1}) // kernel size - .ATTR(pad, ListInt, {0,0,0,0}) // pad size - .ATTR(stride, ListInt, {1,1}) // stride size - .ATTR(ceil_mode, Int, 0) - .ATTR(data_mode, Int, 1) - .ATTR(nan_opt, Int, 0) - .ATTR(fomart, Int, 0) - .OP_END_FACTORY_REG(MeanCCE) - -REG_OP(MeanGrad) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .ATTR(mode, Int, 1) // 0:max pooling or 1:avg pooling - .ATTR(pad_mode, Int, 0) - .ATTR(global_pooling, Bool, false) - .ATTR(window, ListInt, {1,1}) // kernel size - .ATTR(pad, ListInt, {0,0,0,0}) // pad size - .ATTR(stride, ListInt, {1,1}) // stride size - .ATTR(ceil_mode, Int, 0) - .ATTR(data_mode, Int, 1) - .ATTR(nan_opt, Int, 0) - .ATTR(mean_grad_output_shape_value, ListInt, {1,1,1,1}) - .ATTR(mean_grad_output_shape_format, Int, 1) //must be NHWC - .OP_END_FACTORY_REG(MeanGrad) - -REG_OP(MatMulCCE) - .INPUT(x1, TensorType({DT_FLOAT})) - .INPUT(x2, TensorType({DT_FLOAT})) - .OPTIONAL_INPUT(x3, TensorType({DT_FLOAT})) - .OUTPUT(y, TensorType({DT_FLOAT})) - .ATTR(transpose_a, Bool, false) - .ATTR(transpose_b, Bool, false) - .ATTR(has_bias, Bool, false) - .OP_END_FACTORY_REG(MatMulCCE) - /** *@brief Computes half the L2 norm of a tensor without the sqrt. @@ -673,8 +629,9 @@ REG_OP(DiagPart) *@par Attributes: *@li num_output: Reserved. -*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false". -*@li axis: Optional. A int. 1 or 2. +*@li transpose: A bool, specifying weight whether to transpose, either "true" or "false". Defaults to "false". +*@li axis: Optional. A int, 1 or 2, specifying which dimension the input "K" starts from. Defaults to 1. + * The product of the subsequent dimensions starting form first dimension or the second dimension is "K". *@li offset_x: Reserved. *@par Outputs: @@ -698,6 +655,45 @@ REG_OP(FullyConnection) .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(FullyConnection) +/** +*@brief Also known as a "fully-connected-compress" layer, computes an inner product with a set of learned weights, and (optionally) adds biases. + +*@par Inputs: +* Four inputs, including: +*@li x: A Tensor of type uint8, int8. +*@li w: A weight matrix of type int8, int8. +*@li w: A compress index matrix of type int8, int8. +*@li b: A Tensor of type float16, int32, int32. +*@li offset_w: A Tensor of type int8.i + +*@par Attributes: +*@li num_output: Reserved. +*@li transpose: A bool, specifying whether to transpose, either "true" or "false". Defaults to "false". +*@li axis: Reserved. +*@li offset_x: Reserved. + +*@par Outputs: +*y: The result tensor of type int32. + +*@par Third-party framework compatibility +* Compatible with the Caffe operator InnerProduct. + +*@par Quantization supported or not +* Yes +*/ +REG_OP(FullyConnectionCompress) + .INPUT(x, TensorType({DT_UINT8, DT_INT8})) + .INPUT(w, TensorType({DT_INT8})) + .INPUT(comress_index, TensorType({DT_INT8})) + .OPTIONAL_INPUT(b, TensorType({DT_INT32})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_INT32})) + .REQUIRED_ATTR(num_output, Int) + .ATTR(transpose, Bool, false) + .ATTR(axis, Int, 1) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(FullyConnectionCompress) + /** *@brief Computes the confusion matrix from predictions and labels. diff --git a/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h index e8eb4769..296dd63c 100644 --- a/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_batch_norm_ops.h @@ -21,95 +21,6 @@ namespace ge { -/** -*@brief A fusion operator for batchnorm. - -*@par Inputs: -*Ten inputs, including: -* @li x: A Tensor. Must be one of the following types: float32. -* @li scale: A Tensor. Must be one of the following types: float32. -* @li b: A Tensor. Must be one of the following types: float32. -* @li mean: A Tensor. Must be one of the following types: float32. -* @li variance: A Tensor. Must be one of the following types: float32. - -*@par Attributes: -* @li mode: A Tensor. Must be one of the following types: int. -* @li epsilon: A Tensor. Must be one of the following types: float32. -* @li momentum: A Tensor. Must be one of the following types: float32. -* @li is_training: A Tensor. Must be one of the following types: bool. -* @li is_training_fusion: A Tensor. Must be one of the following types: bool. -* @li moving_average_fraction: A Tensor. Must be one of the following types: float32. - -*@par Outputs: -*Three outputs, including: -* @li y: A Tensor. Must be one of the following types: float32. -* @li running_mean: A Tensor. Must be one of the following types: float32. -* @li running_variance: A Tensor. Must be one of the following types: float32. -* @li save_mean: A Tensor. Must be one of the following types: float32. -* @li save_inv_variance: A Tensor. Must be one of the following types: float32. -* @li save_inv_variance1: A Tensor. Must be one of the following types: float32. - -*/ -REG_OP(FusedBatchNorm) - .INPUT(x, TensorType{DT_FLOAT}) - .INPUT(scale, TensorType{DT_FLOAT}) - .INPUT(b, TensorType{DT_FLOAT}) - .INPUT(mean, TensorType{DT_FLOAT}) - .INPUT(variance, TensorType{DT_FLOAT}) - .OUTPUT(y, TensorType{DT_FLOAT}) - .OUTPUT(running_mean, TensorType{DT_FLOAT}) - .OUTPUT(running_variance, TensorType{DT_FLOAT}) - .OUTPUT(save_mean, TensorType{DT_FLOAT}) - .OUTPUT(save_inv_variance, TensorType{DT_FLOAT}) - .OUTPUT(save_inv_variance1, TensorType{DT_FLOAT}) - .ATTR(mode, Int, 1) - .ATTR(epsilon, Float, 1e-5f) - .ATTR(momentum, Float, 0.9) - .ATTR(is_training, Bool, true) - .ATTR(is_training_fusion, Bool, true) - .ATTR(moving_average_fraction, Float, 0.00300002098) - .OP_END_FACTORY_REG(FusedBatchNorm) - -/** -*@brief A fusion operator for batchnorm. - -*@par Inputs: -*Ten inputs, including: -* @li dy: A Tensor. Must be one of the following types: float32. -* @li x: A Tensor. Must be one of the following types: float32. -* @li scale: A Tensor. Must be one of the following types: float32. -* @li save_mean: A Tensor. Must be one of the following types: float32. -* @li save_inv_variance: A Tensor. Must be one of the following types: float32. -* @li save_inv_variance1: A Tensor. Must be one of the following types: float32. - -*@par Attributes: -* @li epsilon: A Tensor. Must be one of the following types: float32. -* @li momentum: A Tensor. Must be one of the following types: float32. - -*@par Outputs: -*Three outputs, including: -* @li dx: A Tensor. Must be one of the following types: float32. -* @li bn_scale: A Tensor. Must be one of the following types: float32. -* @li bn_bias: A Tensor. Must be one of the following types: float32. - -*@par Third-party framework compatibility -* Compatible with the L2 scenario of PyTorch operator Normalize. -*/ - -REG_OP(FusedBatchNormGrad) - .INPUT(dy, TensorType{DT_FLOAT}) - .INPUT(x, TensorType{DT_FLOAT}) - .INPUT(scale, TensorType{DT_FLOAT}) - .INPUT(save_mean, TensorType{DT_FLOAT}) - .INPUT(save_inv_variance, TensorType{DT_FLOAT}) - .INPUT(save_inv_variance1, TensorType{DT_FLOAT}) - .OUTPUT(dx, TensorType{DT_FLOAT}) - .OUTPUT(bn_scale, TensorType{DT_FLOAT}) - .OUTPUT(bn_bias, TensorType{DT_FLOAT}) - .ATTR(epsilon, Float, 0.0) - .ATTR(momentum, Float, 0.0) - .OP_END_FACTORY_REG(FusedBatchNormGrad) - /** *@brief Normalizes elements of a specific dimension of eigenvalues (L2). @@ -361,14 +272,14 @@ REG_OP(BatchNormGradExt2) *@par Inputs: *@li x: A 4D or 5D Tensor of type float16 or float32, with format NHWC or NCHW for 4D or NC1HWC0 for 5D. *@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. -*@li variance: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the variance used for inference. -*@li momentum: A Tensor of type float32 or float16, represents the mean and the variance's scale factor +*@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. +*@li momentum: A Tensor,represents the mean and the variance's scale factor *@li scale: An optional tensor of type float16 or float32, no use *@li offset: An optional tensor of type float16 or float32, no use *@par Attributes: *@li epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero. Defaults to "0.00001". *@li use_global_stats: mean inference mode , only can be "True". -*@li mode: An optional attr, not use +*@li mode: An optional input, not use *@par Outputs:\n *@li y: A 4D or 5D Tensor of type float16 or float32 for the normalized "x" */ @@ -391,11 +302,11 @@ REG_OP(BNInference) *@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. *@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. -*@li momentum: A Tensor of type float32 or float16, the mean and the variance's Scale factor +*@li momentum: An optional float, mean and variance's Scale factor *@par Attributes: *@li epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero. Defaults to "0.00001". *@li use_global_stats: mean inference mode , only can be "True". -*@li mode: An optional inpout, not use +*@li mode: An optional attr, not use *@par Outputs: *@li alpha: A Tensor of type float16 or float32 for the cpu calculate mean *@li beta: A Tensor of type float16 or float32 for the cpu calculate variance @@ -418,8 +329,8 @@ REG_OP(BnHost) *@par Inputs: *@li x: A 4D or 5D Tensor of type float16 or float32, with format NHWC or NCHW for 4D or NC1HWC0 for 5D. -*@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. -*@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. +*@li mean: A Tensor of type float32 or float16. Must be 1D if input "x" Specifies the mean used for inference. +*@li variance: A Tensor of type float32 or float16 . Must be 1D if input "x" Specifies the variance used for inference. *@li scale: An optional tensor of type float16 or float32, no use *@li offset: An optional tensor of type float16 or float32, no use *@par Attributes: diff --git a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h index 3529e9ca..e9180332 100644 --- a/third_party/fwkacllib/inc/ops/nn_calculation_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_calculation_ops.h @@ -143,31 +143,29 @@ REG_OP(DepthwiseConv2DBackpropFilterD) * @par Inputs: * Three inputs include: \n * @li input_size: 4D shape of input tensor [N, C, H, W] or [N, H, W, C], -* support int32 -* @li filter: 4D filter tensor with shape of [H, W, C, K], support float16, -* float32, double +* support int32, int64 +* @li filter: 4D filter tensor with shape of [H, W, C, K], support float16. * @li out_backprop: 4D tensor with shape [N, C, H, W] or [N, H, W, C]. -* Must be one of the following types: float16, float32, double. +* Must be one of the following types: float16. * @par Attributes: -* @li strides: A required list or tuple. The stride of the sliding window for +* @li strides: A required list or tuple of int32. The stride of the sliding window for * height and width of input "x" of the convolution. * Must be with shape [1, 1, stride_height, stride_width] or [1, stride_height, * stride_width, 1]. -* @li dilations: An optional list or tuple. The dilation factor for each -* dimension of input "x". +* @li dilations: An optional list or tuple of int32. The dilation factor for each +* dimension of input "x". Defaults to "[1, 1, 1, 1]". * If set to k > 1, there will be k-1 skipped cells between each filter element * on that dimension. Must be with shape [1, 1, dilation_height, dilation_width] * or [1, dilation_height, dilation_width, 1]. -* @li pads: A required list or tuple. Padding added to each dimension of the +* @li pads: A required list or tuple of int32. Padding added to each dimension of the * input. * @li data_format: An optional string. Input data format, either "NHWC" or -* "NCHW". +* "NCHW". Defaults to "NHWC". * @par Outputs: * input_grad: Gradient of the deep convolution relative to the input with shape -* [N, C, H, W] or [N, H, W, C] Must be one of the following types: float16, -* float32, double. +* [N, C, H, W] or [N, H, W, C] Must be one of the following types: float16. * @attention Constraints:\n * The feature map is 4D with shape [N, C, Hi, Wi] or [N, Hi, Wi, C], but @@ -259,8 +257,8 @@ REG_OP(DepthwiseConv2DBackpropInputD) *@par Inputs: *Two required inputs and two optional inputs, including: \n -* @li x: A 4D tensor of type float16, with shape [N, C, H, W] or [N, H, W, C] -* @li filter: A 4D tensor of type float16, with shape [H, W, C, K] +* @li x: A 4D tensor of type float16 or int8, with shape [N, C, H, W] or [N, H, W, C] +* @li filter: A 4D tensor of type float16 or int8, with shape [H, W, C, K] * @li bias: An optional tensor of type float16 or int32 * @li offset_w: An optional float16 or int8, used for quantized inference @@ -273,8 +271,8 @@ REG_OP(DepthwiseConv2DBackpropInputD) * dimension of input "x". * If set to k > 1, there will be k-1 skipped cells between each filter element * on that dimension. Must be with shape [1, 1, dilation_height, dilation_width] -* or [1, dilation_height, dilation_width, 1]. -* @li pads: A required list or tuple. Padding added to each dimension of the +* or [1, dilation_height, dilation_width, 1]. Defaults to "[1, 1, 1, 1]". +* @li pads: A required list or tuple of int32. Padding added to each dimension of the * input. * @li data_format: An optional string. Input data format, either "NHWC" or * "NCHW". Defaults to "NHWC". @@ -282,7 +280,7 @@ REG_OP(DepthwiseConv2DBackpropInputD) * Defaults to 0. * @par Outputs: -* y: 4D tensor of type float16, with shape [N, C, H, W] or [N, H, W, C] +* y: 4D tensor of type float16 or int32, with shape [N, C, H, W] or [N, H, W, C] * @attention Constraints:\n * The feature map is 4D with shape [N, C, Hi, Wi] or [N, Hi, Wi, C], but @@ -314,53 +312,6 @@ REG_OP(DepthwiseConv2D) .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(DepthwiseConv2D) -REG_OP(Conv2DCCE) - .INPUT(x, TensorType{DT_FLOAT}) // The input tensor - .INPUT(w, TensorType({DT_FLOAT, DT_INT8})) // The weight tensor ,If QuantType =1 ,shall use type""tensor(int8) - .OPTIONAL_INPUT(b, TensorType{DT_FLOAT}) // Optional 1D bias to be added to the convolution, has size of M. - .OUTPUT(y, TensorType{DT_FLOAT}) // The output tensor - .ATTR(mode, Int, 1) - .ATTR(group, Int, 1) // number of groups input channels and output channels are divided into - .ATTR(num_output, Int, 0) // number of output tensor - .ATTR(pad, ListInt, {0, 0, 0, 0}) // Padding for the beginning and ending along each axis - .ATTR(kernel, ListInt, {0, 0}) - .ATTR(stride, ListInt, {1, 1}) // Stride along each axis. - .ATTR(dilation, ListInt, {1, 1}) // dilation value along each axis of the filter. - .ATTR(pad_mode, Int, 0) // pad mode, 0:NOTSET, 1:SAME_UPPER, SAME_LOWER or 2:VALID.defaul default value is 0:NOTSET - .ATTR(algo, Int, 2) - .OP_END_FACTORY_REG(Conv2DCCE) - -REG_OP(Conv2DBackpropFilterCCE) - .INPUT(x, TensorType{DT_FLOAT}) - .INPUT(filter_sizes, TensorType{DT_INT8}) - .INPUT(out_backprop, TensorType{DT_FLOAT}) - .OUTPUT(y, TensorType{DT_FLOAT}) - .ATTR(conv_grad_filter_output_shape, ListInt, {0, 0, 0, 0}) - .ATTR(mode, Int, 1) - .ATTR(group, Int, 1) - .ATTR(pad, ListInt, {0, 0, 0, 0}) - .ATTR(stride, ListInt, {1, 1}) - .ATTR(dilation, ListInt, {1, 1}) - .ATTR(padding, Int, 0) //pad_mode:same valid - .ATTR(algo, Int, 0) - .OP_END_FACTORY_REG(Conv2DBackpropFilterCCE) - -REG_OP(Conv2DBackpropInputCCE) - .INPUT(input_sizes, TensorType{DT_INT8}) - .INPUT(filter, TensorType{DT_FLOAT}) - .INPUT(out_backprop, TensorType{DT_FLOAT}) - .OUTPUT(output, TensorType{DT_FLOAT}) - .ATTR(conv_grad_input_output_shape, ListInt, {0, 0, 0, 0}) - .ATTR(mode, Int, 1) - .ATTR(format, Int, 0) - .ATTR(group, Int, 1) - .ATTR(pad_mode, Int, 0) - .ATTR(stride, ListInt, {1, 1}) - .ATTR(dilation, ListInt, {1, 1}) - .ATTR(pad, ListInt, {0, 0, 0, 0}) - .ATTR(algo, Int, 0) - .OP_END_FACTORY_REG(Conv2DBackpropInputCCE) - /** *@brief Performs the the backward operation for "BiasAdd" on the "bias" tensor. * It accumulates all the values from out_backprop into the feature @@ -462,24 +413,24 @@ REG_OP(Conv2DBackpropInputD) * @li x: A Tensor. Must have the same type as "filter". 4D with shape * [batch, out_channels, out_height, out_width]. Gradients with respect * to the output of the convolution. - * @li filter: A Tensor of type float16. + * @li filter: A Tensor of type float16, float32, double or int8. * 4D with shape [out_channels, in_channel, filter_height, filter_width].\n * Two optional inputs: - * @li bias: An optional tensor of type float16 - * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved.\n + * @li bias: An optional tensor of type float16, float32, int32 or int64. + * @li offset_w: An optional 1D tensor for quantized deconvolution. Type is int8. Reserved.\n *@par Attributes: * Six attributes: * @li strides: A tuple or list of 2 integers. The stride of the sliding window - * for H/W dimension. + * for H/W dimension. Defaults to [1, 1, 1, 1]. * @li pads: A tuple or list of 4 integers. The [top, bottom, left, right] - * padding on the feature map + * padding on the feature map. Defaults to [0, 0, 0, 0]. * @li dilations: A tuple or list of 4 integers. The dilation factor for each * dimension of input. Must be [1, 1, 1, 1]. * @li groups: Number of blocked connections from input channels to - * output channels. - * @li data_format: An optional string from: "NCHW". Defaults to "NCHW".\n + output channels. Defaults to "1". + * @li data_format: An optional string from: "NCHW". Defaults to "NCHW". \n Specify the data format of the input and output data. - * @li offset_x: An optional integer for quantized deconvolution. + * @li offset_x: An optional integer for quantized deconvolution. Defaults to "0". *@par Outputs: * y: A Tensor. Has the same type as "filter". 4D tensor with shape * [batch, channels, height, width]. @@ -577,17 +528,17 @@ REG_OP(Conv2DBackpropFilterD) * * The input and output tensor attributes are listed as follows: * @verbatim - Tensor | x | filter | bias | offset_w | y + |Tensor | x | filter | bias | offset_w | y -----------|---------|---------|---------|----------|-------- - Data Type | float16 | float16 | float16 | _ | float16 - |---------|---------|---------|----------|-------- - | float32 | float32 | float32 | _ | float32 - |---------|---------|---------|----------|-------- - | int8 | int8 | int32 | int8 | int32 + |Data Type | float16 | float16 | float16 | _ | float16 + | |---------|---------|---------|----------|-------- + | | float32 | float32 | float32 | _ | float32 + | |---------|---------|---------|----------|-------- + | | int8 | int8 | int32 | int8 | int32 -----------|---------|---------|---------|----------|-------- - Format | NCHW | NCHW | ND | ND | NCHW - | NHWC | NHWC | | | NHWC - | | HWCN | | | + |Format | NCHW | NCHW | ND | ND | NCHW + | | NHWC | NHWC | | | NHWC + | | | HWCN | | | @endverbatim * It should be noted that the data types must correspond to each other, but the * format does not need to. @@ -602,10 +553,10 @@ REG_OP(Conv2DBackpropFilterD) * for dilated convolution. Has the same dimension order and value as "strides". * @li groups: Number of blocked connections from input channels to output * channels. Input channels and output channels must both be divisible by -* "groups". -* @li offset_x: An optional integer for quantized convolution. +* "groups".Type is int32. +* @li offset_x: An optional integer for quantized convolution. Type is int32. Defaults to "0". * @li data_format: An optional string from: "NHWC", "NCHW". Specifying the -* data format of the input and output images. Reserved. +* data format of the input and output images. Type is string. Defaults to "NHWC". Reserved. *@par Outputs: * @li y: A 4D Tensor of output images. @@ -613,23 +564,23 @@ REG_OP(Conv2DBackpropFilterD) *@attention * @li The parameter scope is listed as follows: * @verbatim - Name | Field | Scope + |Name | Field | Scope ------------------|--------------|---------- - Input Image Size | H dimension | [1, 4096] - | W dimension | [1, 4096] + |Input Image Size | H dimension | [1, 4096] + | | W dimension | [1, 4096] ------------------|--------------|---------- - Filter Size | H dimension | [1, 255] - | W dimension | [1, 255] + |Filter Size | H dimension | [1, 255] + | | W dimension | [1, 255] ------------------|--------------|---------- - Stride Size | H dimension | [1, 63] - | W dimension | [1, 63] + |Stride Size | H dimension | [1, 63] + | | W dimension | [1, 63] ------------------|--------------|---------- - Padding Size | top side | [0, 255] - | bottom side | [0, 255] - | left side | [0, 255] - | right side | [0, 255] + |Padding Size | top side | [0, 255] + | | bottom side | [0, 255] + | | left side | [0, 255] + | | right side | [0, 255] ------------------|--------------|---------- - Dilation Size | H dimension | [1, 255] + |Dilation Size | H dimension | [1, 255] | W dimension | [1, 255] @endverbatim @@ -684,36 +635,46 @@ REG_OP(Conv2DCompress) /** *@brief Computes a 3D convolution given 5D "x" and "filter" tensors. -*@par Inputs: -*@li x: A 5D tensor. Must be one of the following types: float16, float32, float64. The format is NCDHW or NDHWC. -*@li filter: A 5D tensor of the same type as "x". The format is NCDHW, NDHWC or DHWCN. -*@li bias: An optional 1D tensor of the same type as "x". + *@par Inputs: + * @li x: A 5D tensor. Must be one of the following types: float16, float32, float64. The format is NCDHW or NDHWC. + * @li filter: A 5D tensor of the same type as "x". The format is NCDHW, NDHWC or DHWCN. + +*@par Optional input: + * @li bias: An optional 1D tensor of the same type as "x". + * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved. + +*@par Required Attributes: +* @li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". +* @li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. *@par Attributes: -*@li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". -*@li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. -*@li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. -*@li dilations: A list of 5 ints. Specifies the dilation factor for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + * @li groups: Number of blocked connections from input channels to output channels. + * @li data_format: An optional string from: "NDHWC", "NCDHW". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li dilations: A list of 5 ints. Specifies the dilation factor for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + * @li offset_x: An optional int. Input offset, used for quantized inference. Defaults to 0. *@par Outputs: -*y: A Tensor. Has the same type as "x". + *y: A Tensor. Has the same type as "x". -*@attention Constraints:\n -*The image size after padding is greater than the filter size.\n +*@attention Constraints: + *The image size after padding is greater than the filter size. *@par Third-party framework compatibility -*@li Compatible with the TensorFlow operator conv3d. -*@li Compatible with the Caffe operator Convolution. + * @li Compatible with the TensorFlow operator conv3d. + * @li Compatible with the Caffe operator Convolution. */ REG_OP(Conv3D) .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .ATTR(strides, ListInt, {1, 1, 1, 1, 1}) - .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) - .ATTR(data_format, String, "NDHWC") + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NDHWC") + .ATTR(offset_x, Int, 0) .OP_END_FACTORY_REG(Conv3D) /** @@ -723,28 +684,35 @@ REG_OP(Conv3D) * @li input_size: A Tensor of type int32, int64. An integer vector representing the shape of input, * where input is a 5-D tensor [batch, depth, height, width, channels] or [batch, channels, depth, height, width]. * @li filter: A Tensor. Must be one of the following types: float16, float32, float64. - * @li grads: A Tensor. Must have the same type as filter. 5-D with shape [batch, depth, out_height, out_width, out_channels] + * @li out_backprop: A Tensor. Must have the same type as filter. 5-D with shape [batch, depth, out_height, out_width, out_channels] * or [batch, out_channels, depth, out_height, out_width]. Gradients with respect to the output of the convolution. + +*@par Required Attributes: + * @li strides: A list of 5 ints. Specifies the stride of the sliding window for each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + * @li pads: A list of 6 ints. Supports only padding along the D, H and W dimensions in sequence of head, tail, top, bottom, left and right. + *@par Attributes: - * Four attributes: - * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. - * @li pads: A tuple/list of 6 integers - * @li dilations: A tuple/list of 6 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] + * Three attributes: + * @li groups: Number of blocked connections from input channels to output channels. * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li dilations: A tuple/list of 6 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] + *@par Outputs: * y: A Tensor. Has the same type as filter,and has same format as input_size + *@par Third-party framework compatibility * Compatible with Tensorflow's conv3d_backprop_input */ REG_OP(Conv3DBackpropInput) .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) - .INPUT(grads, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(out_backprop, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) .REQUIRED_ATTR(strides, ListInt) - .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) - .ATTR(data_format, String, "NDHWC") + .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NDHWC") .OP_END_FACTORY_REG(Conv3DBackpropInput) /** @@ -752,46 +720,56 @@ REG_OP(Conv3DBackpropInput) *@par Inputs: * Two inputs: * @li filter: A Tensor. Types is float16. - * @li grads: A Tensor. Must have the same type as filter. + * @li out_backprop: A Tensor. Must have the same type as filter. + +*@par Required Attributes: + *@li strides: A list of 5 ints. Specifies the stride of the sliding window for + each dimension of "x". The N and C dimensions must be 1. Has the same format as "x". + *@li pads: A list of 6 ints. Supports only padding along the D, H and W + dimensions in sequence of head, tail, top, bottom, left and right. + *@li input_size: A Tensor of type int32, int64. An integer vector representing the shape of input, + * where input is a 5-D tensor [batch, depth, height, width, channels] or [batch, channels, depth, height, width]. + *@par Attributes: - * Five attributes: - * @li input_size A Tensor of type int32. An integer vector representing the shape of input, - * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. - * @li pads: A tuple/list of 4 integers - * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] + * Three attributes: + * @li groups: Number of blocked connections from input channels to output channels. * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] *@par Outputs: * y: A Tensor. Has the same type as filter *@par Third-party framework compatibility * Compatible with Tensorflow's conv3d_backprop_input */ + + REG_OP(Conv3DBackpropInputD) .INPUT(filter, TensorType({DT_FLOAT16})) - .INPUT(grads, TensorType({DT_FLOAT16})) + .INPUT(out_backprop, TensorType({DT_FLOAT16})) .OUTPUT(y, TensorType({DT_FLOAT16})) .REQUIRED_ATTR(input_size, ListInt) .REQUIRED_ATTR(strides, ListInt) - .ATTR(pads, ListInt, {0, 0, 0, 0, 0, 0}) - .ATTR(data_format, String, "NDHWC") + .REQUIRED_ATTR(pads, ListInt) .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NDHWC") .OP_END_FACTORY_REG(Conv3DBackpropInputD) REG_OP(LSTM) - .INPUT(x, TensorType({DT_FLOAT16})) - .INPUT(cont, TensorType({DT_FLOAT32,DT_FLOAT16})) - .INPUT(w_x, TensorType({DT_FLOAT16})) - .INPUT(bias, TensorType({DT_FLOAT16,DT_FLOAT32,DT_INT16,DT_INT32})) - .INPUT(w_h, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(x_static, TensorType({DT_FLOAT16})) - .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16,DT_FLOAT32})) - .OPTIONAL_INPUT(c_0, TensorType({DT_FLOAT16,DT_FLOAT32})) - .OPTIONAL_INPUT(w_x_static, TensorType({DT_FLOAT16})) - .OUTPUT(h, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) - .OUTPUT(c_t, TensorType({DT_FLOAT16, DT_FLOAT})) - .ATTR(num_output, Int, 0) - .ATTR(expose_hidden, Bool, false) - .OP_END_FACTORY_REG(LSTM) + .INPUT(x, TensorType({DT_FLOAT16})) + .INPUT(cont, TensorType({DT_FLOAT32,DT_FLOAT16})) + .INPUT(w_x, TensorType({DT_FLOAT16})) + .INPUT(bias, TensorType({DT_FLOAT16,DT_FLOAT32,DT_INT16,DT_INT32})) + .INPUT(w_h, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(x_static, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(h_0, TensorType({DT_FLOAT16,DT_FLOAT32})) + .OPTIONAL_INPUT(c_0, TensorType({DT_FLOAT16,DT_FLOAT32})) + .OPTIONAL_INPUT(w_x_static, TensorType({DT_FLOAT16})) + .OUTPUT(h, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(h_t, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(c_t, TensorType({DT_FLOAT16, DT_FLOAT})) + .ATTR(num_output, Int, 0) + .ATTR(expose_hidden, Bool, false) + .OP_END_FACTORY_REG(LSTM) /** *@brief Computes the gradients of convolution3D with respect to the filter @@ -851,6 +829,8 @@ REG_OP(Conv3DBackpropFilter) *@par Third-party framework compatibility * Compatible with Tensorflow's conv3d_backprop_filter */ + + REG_OP(Conv3DBackpropFilterD) .INPUT(x, TensorType({DT_FLOAT16})) .INPUT(out_backprop, TensorType({DT_FLOAT16})) @@ -862,5 +842,86 @@ REG_OP(Conv3DBackpropFilterD) .ATTR(groups, Int, 1) .ATTR(data_format, String, "NDHWC") .OP_END_FACTORY_REG(Conv3DBackpropFilterD) + +/** +*@brief Computes the transpose of convolution 3d with respect to the input. +*@par Inputs: + * Five inputs: + * @li input_size: A Tensor of type int32. An integer vector representing the shape of input + * @li x: A Tensor. + * @li filter: A Tensor. Types is float16. + * @li bias: An optional 1D tensor of the same type as "x". + * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved. + +*@par Required Attributes: + * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. + * @li pads: A tuple/list of 6 integers +*@par Attributes: + * Five attributes: + * @li groups: Number of blocked connections from input channels to output channels. + * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] + * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li output_padding: The size will be added in the output shape. + * @li offset_x: Input offset_x value +*@par Outputs: + * y: A Tensor. Has the same type as filter +*/ +REG_OP(Conv3DTranspose) + .INPUT(input_size, TensorType({DT_INT32, DT_INT64})) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(filter, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NDHWC") + .ATTR(output_padding, ListInt, {0, 0, 0, 0, 0}) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv3DTranspose) + +/** +*@brief Computes the transpose of convolution 3d with respect to the input. +*@par Inputs: + * Four inputs: + * @li x: A Tensor. + * @li filter: A Tensor. Types is float16. + * @li bias: An optional 1D tensor of the same type as "x". + * @li offset_w: An optional 1D tensor for quantized deconvolution. Reserved. + +*@par Required Attributes: + * @li input_size: A Tensor of type int32. An integer vector representing the shape of input + * @li strides: A tuple/list of 3 integers. The stride of the sliding window for D/H/W dimension. + * @li pads: A tuple/list of 6 integers +*@par Attributes: + * Five attributes: + * @li dilations: A tuple/list of 5 integers, The dilation factor for each dimension of input, now only support [1,1,1,1,1] + * @li groups: Number of blocked connections from input channels to output channels. + * @li data_format: An optional string from: "NDHWC", "NCHWD". Defaults to "NDHWC". Specify the data format of the input and output data. + * @li output_padding: The size will be added in the output shape. + * @li offset_x: Input offset_x value +*@par Outputs: + * y: A Tensor. Has the same type as filter +*/ + + +REG_OP(Conv3DTransposeD) + .INPUT(x, TensorType({DT_FLOAT16})) + .INPUT(filter, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(bias, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(offset_w, TensorType({DT_INT8})) + .OUTPUT(y, TensorType({DT_FLOAT16})) + .REQUIRED_ATTR(input_size, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dilations, ListInt, {1, 1, 1, 1, 1}) + .ATTR(groups, Int, 1) + .ATTR(data_format, String, "NDHWC") + .ATTR(output_padding, ListInt, {0, 0, 0, 0, 0}) + .ATTR(offset_x, Int, 0) + .OP_END_FACTORY_REG(Conv3DTransposeD) + } // namespace ge #endif // GE_OP_NN_CALCULATION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_detect_ops.h b/third_party/fwkacllib/inc/ops/nn_detect_ops.h index 5dca8a9d..0a91e237 100644 --- a/third_party/fwkacllib/inc/ops/nn_detect_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_detect_ops.h @@ -187,14 +187,15 @@ REG_OP(ROIAlignGrad) *@li features: A 5HD Tensor of type float32 or float16. *@li rois: ROI position. A 2D Tensor of float32 or float16 with shape (N, 5). "N" indicates the number of ROIs, the value "5" indicates the indexes of images where the ROIs are located, * "x0", "y0", "x1", and "y1". -*@li rois_n: An optional input, specifying the number of valid ROIs. This parameter is reserved. +*@li rois_n: An optional input of type int32, specifying the number of valid ROIs. This parameter is reserved. *@par Attributes: -*@li spatial_scale: A required attribute of type float, specifying the scaling ratio of "features" to the original image. -*@li pooled_height: A required attribute of type int, specifying the H dimension. -*@li pooled_width: A required attribute of type int, specifying the W dimension. -*@li sample_num: An optional attribute of type int, specifying the horizontal and vertical sampling frequency of each output. If this attribute is set to "0", +*@li spatial_scale: A required attribute of type float32, specifying the scaling ratio of "features" to the original image. +*@li pooled_height: A required attribute of type int32, specifying the H dimension. +*@li pooled_width: A required attribute of type int32, specifying the W dimension. +*@li sample_num: An optional attribute of type int32, specifying the horizontal and vertical sampling frequency of each output. If this attribute is set to "0", * the sampling frequency is equal to the rounded up value of "rois", which is a floating point number. Defaults to "2". +*@li roi_end_mode: An optional attribute of type int32. Defaults to "1". *@par Outputs: * output: Outputs the feature sample of each ROI position. The format is 5HD Tensor of type float32 or float16. The axis N is the number of input ROIs. Axes H, W, and C are consistent @@ -362,15 +363,15 @@ REG_OP(PSROIPooling) *@li im_info: An ND tensor of type float16 or float32, specifying the Image information. *@li actual_rois_num: An optional NCHW tensor of type int32, specifying the number of valid boxes per batch. *@par Attributes: -*@li batch_rois: An optional int32, specifying the number of images to be predicted. +*@li batch_rois: An optional int32, specifying the number of images to be predicted. Defaults to "1". *@li num_classes: An required int32, specifying the number of classes to be predicted. The value must be greater than 0. *@li score_threshold: An required float32, specifying the threshold for box filtering. The value range is [0.0, 1.0]. *@li iou_threshold: An required float32, specifying the confidence threshold for box filtering, which is the output "obj" of operator Region. The value range is (0.0, 1.0). *@par Outputs: -*@li box: An NCHW tensor of type float16 or float32, describing the information of each output box, including the coordinates, class, and confidence. -Proposal of actual output, with output shape [batch, numBoxes,8], 8 means [x1, y1, x2, y2, score, label, batchID, NULL], the maximum value of numBoxes is 1024. +*@li box: A tensor of type float16 or float32 for proposal of actual output, with output shape [batch, numBoxes,8]. +* 8 means [x1, y1, x2, y2, score, label, batchID, NULL], the maximum value of numBoxes is 1024. That is, take min (the maximum number of input boxes, 1024) -*@li actual_bbox_num: An NCHW tensor of type int32 With shape [bacth, num_classes], specifying the number of output boxes. +*@li actual_bbox_num: A tensor of type int32 With shape [bacth, num_classes], specifying the number of output boxes. *@attention Constraints:\n *@li totalnum < max_rois_num * batch_rois. @@ -414,9 +415,9 @@ REG_OP(FSRDetectionOutput) *@li confidence_threshold: An optional float32, specify the topk filter threshold. Only consider detections with confidence greater than the threshold *@li kernel_name: An optional string, specifying the operator name. Defaults to "ssd_detection_output". *@par Outputs: -*@li out_boxnum: An NCHW tensor of type int32, specifying the number of output boxes. -*@li y: An NCHW tensor of type float16 or float32 with shape [batch,keep_top_k, 8], describing the information of each output box, including the coordinates, -* class, and confidence. In output shape, 8 means (batchID, label(classID), score (class probability), xmin, ymin, xmax, ymax, null) +*@li out_boxnum: A tensor of type int32, specifying the number of output boxes. +*@li y: A tensor of type float16 or float32 with shape [batch,keep_top_k, 8], describing the information of each output box. +* In output shape, 8 means (batchID, label(classID), score (class probability), xmin, ymin, xmax, ymax, null) * It is a custom operator. It has no corresponding operator in Caffe. */ REG_OP(SSDDetectionOutput) @@ -447,10 +448,10 @@ REG_OP(SSDDetectionOutput) *@li boxes: A required int32, specifying the number of anchor boxes. Defaults to "5" for V2 or "3" for V3. *@li coords: An int32, specifying the number of parameters required for locating an object. The value is fixed at "4", corresponding to (x,y,w,h). *@li classes: An int32, specifying the number of prediction classes. Defaults to "80". The value range is [1, 1024]. -*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3". -*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". -*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". -*@li softmaxtree: A bool, Fixed to False, defined in Lite, but not used. +*@li yolo_version: A string, specifying the YOLO version, either "V2" or "V3".Defaults to "V3" +*@li softmax: A bool, specifying whether to perform softmax, valid only when "yolo_version = V2". Defaults to "false". +*@li background: A bool, specifying the operation types of the obj and classes, used in conjunction with "softmax" and valid only when "yolo_version = V2". Defaults to "false". +*@li softmaxtree: A bool, Fixed to False, defined in Lite, but not used. Defaults to "false". *@par Outputs: *@li coord_data: A float16 or float32 with shape [N, boxes*coords, ceilx(height*width*2+32, 32)/2], where "ceil" indicates that a detected box is aligned upwards with the second parameter. Specifies the coordinates of a detected box. @@ -501,10 +502,10 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16 or float32 with shape [batch,6,post_nms_topn]. describing the information of each output box, including the coordinates, class, -and confidence. In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. It means only the first one of the 8 numbers is valid, -the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn]. describing the information of each output box, +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. It means only the first one of the 8 numbers is valid, +* the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 * *@attention Constraints:\n *@li This operator applies only to the YOLO v2 network. @@ -561,10 +562,10 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -With shape [batch,6,post_nms_topn], 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. -With shape [batch,8,1,1], means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn]. describing the information of each output box, +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. It means only the first one of the 8 numbers is valid, +* the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 * *@attention Constraints:\n *@li This operator applies only to the YOLO v2 network. @@ -621,11 +622,11 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box, including the coordinates, class, and confidence. -In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. -The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 - +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box. +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. +* The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +* *@attention Constraints:\n *@li This operator applies only to the YOLO v3 network. *@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. @@ -688,12 +689,11 @@ and the actual image height and width. *@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". * *@par Outputs: -*@li boxout: An NCHW tensor of type float16, describing the information of each output box, including the coordinates, class, and confidence. -With shape [batch,6,post_nms_topn], 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. -*@li boxoutnum: An NCHW tensor of type int32, specifying the number of output boxes. -With shape [batch,8,1,1], means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box. +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. +* The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 * - *@attention Constraints:\n *@li This operator applies only to the YOLO v3 network. *@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. @@ -734,6 +734,65 @@ REG_OP(YoloV3DetectionOutputD) .OUTPUT(box_out_num, TensorType({DT_INT32})) .OP_END_FACTORY_REG(YoloV3DetectionOutputD) +/** +*@brief Performs YOLO V3 detection. + +*@par Inputs: +*16 Input, including: +*@li The outputs of operator Yolo at the preceding layer (that is, three Yolo operators on YOLO v3) are used as the inputs of operator Yolov3DetectionOutput. \n +A Yolo operator has three outputs: "coords", "obj", and "class". For details, see the description of operator Yolo. +*@li imginfo: A float16, describing the image information including the required image height and width \n +and the actual image height and width. +*@li windex: A windex tensor with shape [height,weight]. Has the same type as the inputs. [[0,1,2...(weight-1)],[0,1,2...(w-1)]...[0,1,2...(weight-1)]] consisting of h groups of [0, 1, 2...(weight-1)] is formed for the three Yolo outputs, respectively. + +*@li hindex: A hindex tensor with shape [height,weight]. Has the same type as the inputs. [[0,0...0],[1,1...1],[2,2...2]...[height-1,height-1...,height-1]] is formed for the three Yolo outputs, respectively. + +* +*@par Attributes: +*@li biases: A required float32. "biases = Number of Yolo operators at the preceding layer x 2 x boxes" +*@li boxes: A required int32, specifying the number of anchor boxes predicted for each Yolo layer. +*@li coords: Specifies the number of coordinate parameters. Must be 4. +*@li classes: A required int32, specifying the number of classes to be predicted. The value range is [1, 80]. +*@li relative: An optional bool. Defaults to and must be "true". +*@li obj_threshold: A required float, specifying the confidence threshold for box filtering, which is the output "obj" of operator Yolo). The value range is [0.0, 1.0]. +*@li post_nms_topn: An optional int32. This attribute is reserved. +*@li score_threshold: A required float, specifying the class score threshold for box filtering, which is the output "class" of operator Yolo). The value range is [0.0, 1.0]. +*@li iou_threshold: A required float, specifying the intersection-over-union (IOU) threshold for box filtering. The value range is [0.0, 1.0].\n +*@li pre_nms_topn: An optional int, specifying the number of boxes for non-maximum suppression (NMS). Defaults to "512". +* +*@par Outputs: +*@li boxout: A tensor of type float16 or float32 with shape [batch,6,post_nms_topn], describing the information of each output box. +* In output shape, 6 means x1, y1, x2, y2, score, label(class). Output by the number of box_out_num. +*@li boxoutnum: A tensor of type int32 with shape [batch,8,1,1], specifying the number of output boxes. +* The output shape means only the first one of the 8 numbers is valid, the number of valid boxes in each batch, the maximum number of valid boxes in each batch is 1024 +* +*@attention Constraints:\n +*@li This operator applies only to the YOLO v3 network. +*@li The preceding layer of operator Yolov3DetectionOutput must be three Yolo operators. +*@see Yolo() +*@par Third-party framework compatibility +* It is a custom operator. It has no corresponding operator in Caffe. +*/ +REG_OP(YoloV3DetectionOutputV2) + .DYNAMIC_INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .DYNAMIC_INPUT(windex, TensorType({DT_FLOAT16,DT_FLOAT})) + .DYNAMIC_INPUT(hindex, TensorType({DT_FLOAT16,DT_FLOAT})) + .REQUIRED_ATTR(biases, ListFloat) + .ATTR(boxes, Int, 3) + .ATTR(coords, Int, 4) + .ATTR(classes, Int, 80) + .ATTR(relative, Bool, true) + .ATTR(obj_threshold, Float, 0.5) + .ATTR(post_nms_topn, Int, 512) + .ATTR(score_threshold, Float, 0.5) + .ATTR(iou_threshold, Float, 0.45) + .ATTR(pre_nms_topn, Int, 512) + .ATTR(N, Int, 10) + .ATTR(resize_origin_img_to_net, Bool, false) + .OUTPUT(box_out, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(box_out_num, TensorType({DT_INT32})) + .OP_END_FACTORY_REG(YoloV3DetectionOutputV2) + /** *@brief Spatial Pyramid Pooling, multi-level pooling. * Pooling out(n, sigma(c*2^i*2^i)) tensor, i in range[0,pyramid_height). @@ -1084,6 +1143,131 @@ REG_OP(DecodeWheelsTarget) .OUTPUT(boundary_encoded, TensorType({DT_FLOAT16})) .OP_END_FACTORY_REG(DecodeWheelsTarget) +/** +*@brief Computes nms for input boxes and score, support multiple batch and classes. +* will do clip to window, score filter, top_k, and nms + +*@par Inputs: +* Four inputs, including: \n +*@li boxes: boxes, a 4D Tensor of type float16 with +* shape (batch, num_anchors, num_classes, 4). "batch" indicates the batch size of image, +* and "num_anchors" indicates num of boxes, and "num_classes" indicates classes of detect. +* and the value "4" refers to "x0", "x1", "y0", and "y1". +*@li scores: boxes, a 4D Tensor of type float16 with +* shape (batch, num_anchors, num_classes). +*@li clip_window: window size, a 2D Tensor of type float16 with +* shape (batch, 4). 4" refers to "anchor_x0", "anchor_x1", "anchor_y0", and "anchor_y1". +*@li num_valid_boxes: valid boxes number for each batch, a 1D Tensor of type int32 with +* shape (batch,). + +*@par Attributes: +*@li score_threshold: A required attribute of type float32, specifying the score filter iou iou_threshold. +*@li iou_threshold: A required attribute of type float32, specifying the nms iou iou_threshold. +*@li max_size_per_class: A required attribute of type int, specifying the nms output num per class. +*@li max_total_size: A required attribute of type int, specifying the the nms output num per batch. +*@li change_coordinate_frame: A required attribute of type bool, whether to normalize coordinates after clipping. +*@li transpose_box: A required attribute of type bool, whether inserted transpose before this op. + +*@par Outputs: +*@li nmsed_boxes: A 3D Tensor of type float16 with shape (batch, max_total_size, 4), +* specifying the output nms boxes per batch. +*@li nmsed_scores: A 2D Tensor of type float16 with shape (N, 4), +* specifying the output nms score per batch. +*@li nmsed_classes: A 2D Tensor of type float16 with shape (N, 4), +* specifying the output nms class per batch. +*@li nmsed_num: A 1D Tensor of type float16 with shape (N, 4), specifying the valid num of nmsed_boxes. + +*@attention Constraints: +* Only computation of float16 data is supported. +*/ +REG_OP(BatchMultiClassNonMaxSuppression) + .INPUT(boxes, TensorType({DT_FLOAT16})) + .INPUT(scores, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(clip_window, TensorType({DT_FLOAT16})) + .OPTIONAL_INPUT(num_valid_boxes, TensorType({DT_INT32})) + .OUTPUT(nmsed_boxes, TensorType({DT_FLOAT16})) + .OUTPUT(nmsed_scores, TensorType({DT_FLOAT16})) + .OUTPUT(nmsed_classes, TensorType({DT_FLOAT16})) + .OUTPUT(nmsed_num, TensorType({DT_INT32})) + .REQUIRED_ATTR(score_threshold, Float) + .REQUIRED_ATTR(iou_threshold, Float) + .REQUIRED_ATTR(max_size_per_class, Float) + .REQUIRED_ATTR(max_total_size, Float) + .ATTR(change_coordinate_frame, Bool, false) + .ATTR(transpose_box, Bool, false) + .OP_END_FACTORY_REG(BatchMultiClassNonMaxSuppression) + +/** +* @brief To absolute the bounding box. + +* @par Inputs: +* @li normalized_boxes: A 3D Tensor of type float16 or float32. +* @li shape_hw: A 1D Tensor of type int32. + +* @par Attributes: +* @li reversed_box: An optional bool, specifying the last two dims is "4,num" or +* "num,4", "true" for "4,num", "false" for "num,4". Defaults to "false". + +* @par Outputs: +* y: A Tensor. Has the same type and shape as "normalized_boxes". + +* @attention Constraints: +* "normalized_boxes"'s shape must be (batch,num,4) or (batch,4,num). +* "shape_hw"'s shape must be (4,) +*/ +REG_OP(ToAbsoluteBBox) + .INPUT(normalized_boxes, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(shape_hw, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .ATTR(reversed_box, Bool, false) + .OP_END_FACTORY_REG(ToAbsoluteBBox) + +/** +*@brief Computes Normalize bbox function. +* +*@par Inputs: +*Inputs include: +* @li boxes: A Tensor. Must be float16 or float32. +* @li shape_hw: A Tensor. Must be int32. +* +*@par Attributes: +* reversed_box: optional, bool. Defaults to "False" +* +*@par Outputs: +* y: A Tensor. Must have the same type and shape as boxes. +*/ +REG_OP(NormalizeBBox) + .INPUT(boxes, TensorType({DT_FLOAT16, DT_FLOAT})) + .INPUT(shape_hw, TensorType({DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .ATTR(reversed_box, Bool, false) + .OP_END_FACTORY_REG(NormalizeBBox) + +/** +*@brief Computes decode bboxv2 function. +* +*@par Inputs: +*Inputs include: +* @li boxes: A Tensor. Must be float16 or float32. +* @li anchors: A Tensor. Must be int32. +* +*@par Attributes: +* @li scales: optional, listfloat, . +* @li decode_clip: optional, float, threahold of decode process. +* @li reversed_boxes: optional, bool,. +* +*@par Outputs: +* y: A Tensor. Must have the same type as box_predictions. +*/ +REG_OP(DecodeBboxV2) + .INPUT(boxes, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(anchors, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .ATTR(scales, ListFloat, {1.0, 1.0, 1.0, 1.0}) + .ATTR(decode_clip, Float, 0.0) + .ATTR(reversed_box, Bool, false) + .OP_END_FACTORY_REG(DecodeBboxV2) + } // namespace ge #endif // GE_OP_NN_DETECT_OPS_H_ diff --git a/third_party/fwkacllib/inc/ops/nn_norm_ops.h b/third_party/fwkacllib/inc/ops/nn_norm_ops.h index d4db7cf0..f5b20cdd 100644 --- a/third_party/fwkacllib/inc/ops/nn_norm_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_norm_ops.h @@ -44,16 +44,6 @@ REG_OP(LogSoftmaxGrad) .ATTR(axis, ListInt, {-1}) .OP_END_FACTORY_REG(LogSoftmaxGrad) -REG_OP(SparseSoftmaxCrossEntropyWithLogitsCCE) - .INPUT(features, TensorType{DT_FLOAT}) - .INPUT(labels, TensorType{DT_FLOAT}) - .OUTPUT(out, TensorType{DT_FLOAT}) - .OUTPUT(non, TensorType{DT_FLOAT}) - .ATTR(cross_entropy_is_grad, Bool, 0) - .ATTR(cross_entropy_mode, Int, 1) - .ATTR(softmax_cross_entropy_lossscale_div_batch, Float, 1.0) - .OP_END_FACTORY_REG(SparseSoftmaxCrossEntropyWithLogitsCCE) - /** *@brief Computes sparse softmax cross entropy cost and gradients to backpropagate. @@ -291,8 +281,8 @@ REG_OP(BinaryCrossEntropyGrad) * double. Should be a Variable Tensor. *@par Attributes: -*axes: A list of ints. The dimension softmax would be performed on. Defaults -* to "{-1}". +*axes: A list of int. The dimension softmax would be performed on. Defaults +* to "[-1]". *@par Outputs: *y: A Tensor. Has the same dimensionality and shape as the "x" with values in @@ -330,22 +320,6 @@ REG_OP(LogSoftmaxV2) .ATTR(axes, ListInt, {-1}) .OP_END_FACTORY_REG(LogSoftmaxV2) -REG_OP(FusedBatchNormV2) - .INPUT(x, TensorType{DT_FLOAT}) /* Input data tensor from the previous operator"" */ - .INPUT(scale, TensorType{DT_FLOAT}) /* If spatial is true, the dimension of bias is (C) If spatial is false, the dimensions of scale are (C x D1 x ... x Dn)*/ - .INPUT(b, TensorType{DT_FLOAT}) /* If spatial is true, the dimension of bias is (C) If spatial is false, the dimensions of scale are (C x D1 x ... x Dn)*/ - .OPTIONAL_INPUT(mean, TensorType{DT_FLOAT}) /* If spatial is true, the dimension of the running mean (training) or the estimated mean (testing) is (C).If spatial is false, the dimensions of the running mean (training) or the estimated mean (testing) are (C x D1 x ... x Dn)*/ - .OPTIONAL_INPUT(variance, TensorType{DT_FLOAT}) /* If spatial is true, the dimension of the running variance(training) or the estimated variance (testing) is (C). If spatial is false, the dimensions of the running variance(training) or the estimated variance (testing) are (C x D1 x ... x Dn).*/ - .OUTPUT(y, TensorType{DT_FLOAT}) /* The output tensor of the same shape as X */ - .ATTR(momentum, Float, 0.9) // Factor used in computing the running mean and variance. - .ATTR(epsilon, Float, 1e-5f) // The epsilon value to use to avoid division by zero - .ATTR(mode, Int, 1) // 1 means using "CC_BATCHNORM_SPATIAL"; 0 means using "CC_BATCHNORM_PER_ACTIVATION"; only support 1 now - .ATTR(use_global_stats, Bool, true) - .ATTR(alpha, Float, 1) - .ATTR(beta, Float, 0) - .OP_END_FACTORY_REG(FusedBatchNormV2) - - /** *@brief Confuse mul, sum and sub. @@ -632,7 +606,7 @@ REG_OP(DropOutDoMask) * Three inputs, including: *@li x: An ND tensor of type float16 or float32. *@li scale: An ND tensor of type float16 or float32. -*@li bias: An ND tensor of type float16 or float32. +*@li bias: An optional ND tensor of type float16 or float32. *@par Attributes: *@li axis: An optional int32 used to compute the shape of scale and bias input from the online bottoms. Defaults to "1". @@ -679,11 +653,11 @@ REG_OP(Scale) * depth_radius = (local_size - 1) / 2. local_size is the number of channels to sum over (for ACROSS_CHANNELS) * or the side length of the square region to sum over (for WITHIN_CHANNEL). *@li bias: An optional float32. An offset, usually > 0 to avoid dividing by 0. -* Defaults to "1". +* Defaults to "1.0". *@li alpha: An optional float32. A scaling factor, usually positive. -* Defaults to "1". +* Defaults to "1.0". *@li beta: An optional float32. An exponent. Defaults to "0.75" for the caffe framework, Defaults to "0.5" for others. -*@li norm_region: An optional string. A mode option. "ACROSS_CHANNELS":0, "WITHIN_CHANNEL":1. Defaults to "ACROSS_CHANNELS". +*@li norm_region: An optional string. A mode option. "ACROSS_CHANNELS":0. Defaults to "ACROSS_CHANNELS". *@par Outputs: *y: A Tensor. Has the same data type and shape as "x". @@ -836,6 +810,56 @@ REG_OP(GroupNorm) .ATTR(num_groups, Int, 2) .OP_END_FACTORY_REG(GroupNorm) +/** +*@brief Performs instance normalization. + +*@par Inputs:\n +* Five inputs, including: (NC1HWC0, supported) +*@li x: A 5D Tensor of type float16 or float32, NC1HWC0. +*@li gamma: A Tensor of type float32. +A 5D Tensor for scaling factor, to scale the normalized x. +*@li beta: A Tensor of type float32. +A 5D Tensor for offset, to shift to the normalized x. +*@li mean: A Tensor of type float32. +A 5D Tensor Specifies the mean used for inference. Reserved. +*@li variance: A Tensor of type float32. +A 5D Tensor Specifies the variance used for inference. Reserved. + +*@par Attributes: +*@li is_training: An optional bool, specifying if the operation is used for \n +training or inference. Defaults to "True". +*@li momentum: An optional float32, \n +the value used for the running_mean and running_var computation. Default: "0.1". +*@li epsilon: An optional float32, specifying the small value added to \n +variance to avoid dividing by zero. Defaults to "0.00001". + +*@par Outputs:\n +* Three outputs, including: (NHWC, NCHW NC1HWC0 supported) +*@li y: A 5D tensor of type float16 or float32 for the normalized "x", \n +*@li batch_mean: A Tensor of type float32. +Specifies the mean of "x". +*@li batch_variance: A Tensor of type float32. +Specifies the variance of "x". + +*@par Third-party framework compatibility +*@li Compatible with the PyTorch operator InstanceNorm. +*/ +REG_OP(InstanceNormV2) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT})) + .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT})) + + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT})) + .OUTPUT(batch_mean, TensorType({DT_FLOAT})) + .OUTPUT(batch_variance, TensorType({DT_FLOAT})) + + .ATTR(is_training, Bool, true) + .ATTR(momentum, Float, 0.1) + .ATTR(epsilon, Float, 0.00001) + .OP_END_FACTORY_REG(InstanceNormV2) + } // namespace ge #endif //GE_OP_NN_NORM_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h index 5eb11445..98c4b246 100644 --- a/third_party/fwkacllib/inc/ops/nn_pooling_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_pooling_ops.h @@ -101,6 +101,42 @@ REG_OP(AvgPool) .ATTR(data_format, String, "NHWC") .OP_END_FACTORY_REG(AvgPool) +/** +*@brief Performs average pooling on the input. + +*@par Inputs: +*x: A 5-D Tensor of shape [batch, depth, height, width, channels] and type float16, float32, double. + +*@par Attributes: +*@li ksize: List of ints that has length 1, 3 or 5. The size of the window for each dimension of the input tensor. +*@li strides:List of ints that has length 1, 3 or 5. The stride of the sliding window for each dimension of the input tensor. +*@li pads: List of ints, implicit zero paddings on both sides of the input. +*@li ceil_mode: When true, will use ceil instead of floor in the formula to compute the output shape. +*@li count_include_pad: When true, will include the zero-padding in the averaging calculation. +*@li divisor_override: if specified, it will be used as divisor, otherwise size of the pooling region will be used. +*@li data_format: A string, format of input data. + +*@par Outputs: +*y: The average pooled output tensor. + +*@attention Constraints: +*@li "ksize" is in the range [1, 255]. "strides" is in the range [1, 63] + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator AvgPool3D. +*/ +REG_OP(AvgPool3D) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) + .REQUIRED_ATTR(ksize, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(ceil_mode, Bool, false) + .ATTR(count_include_pad, Bool, true) + .ATTR(divisor_override, Int, 0) + .ATTR(data_format, String, "NDHWC") + .OP_END_FACTORY_REG(AvgPool3D) + /** *@brief Performs max_pool_ext2 on the input. @@ -184,17 +220,62 @@ REG_OP(MaxPool) .OP_END_FACTORY_REG(MaxPool) REG_OP(MaxPool3D) - .INPUT(x, TensorType({DT_FLOAT16})) - .OUTPUT(y, TensorType({DT_FLOAT16})) + .INPUT(x, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_FLOAT16, DT_FLOAT32, DT_DOUBLE})) .REQUIRED_ATTR(ksize, ListInt) .REQUIRED_ATTR(strides, ListInt) .REQUIRED_ATTR(padding, String) .ATTR(pads, ListInt, {0,0,0}) - .ATTR(dilation, ListInt, {0,0,0}) + .ATTR(dilation, ListInt, {1,1,1}) .ATTR(ceil_mode, Int, 0) .ATTR(data_format, String, "NDHWC") .OP_END_FACTORY_REG(MaxPool3D) + +/** +* @brief Computes second-order gradients of the maxpooling3d function. + +* @par Inputs: +* @li orig_x: Original forward input tensor(NDC1HWC0) of type float16 +* @li orig_y: Original forward output tensor(NDC1HWC0) of type float16 +* @li grads: Gradient tensor(NDC1HWC0) of type float16 +* @li assist: Assist tensor(NDC1HWC0) of type float16 + +* @par Attributes: +* @li ksize: A required list or tuple, +* specifying the size of the sliding window. +* @li strides: A required list or tuple, +* specifying the stride of the sliding window. +* @li pads: A required list or tuple +* @li padding: A required string, window sliding mode. Either SAME or VALID. +* @li data_format: An optional string. +* Format of the original input, either NCDHW or NDHWC. Defaults to NDHWC. + +* @attention Constraints: +* @li Only the Ascend 910 platform is supported. +* @li "orig_x" and "grads" must have the same shape. +* @li "orig_y" and "y" must have the same shape. Otherwise, an error is reported. +* @li "orig_x", "orig_y", "grads", and "y" must be NDC1HWC0 tensors. + +* @par Outputs: +* @li y: Result tensor of type float16 + +* @par Third-party framework compatibility +* @li Compatible with the TensorFlow operator MaxPool3DGradGrad. +*/ + +REG_OP(MaxPool3DGradGrad) + .INPUT(orig_x, TensorType::RealNumberType()) + .INPUT(orig_y, TensorType::RealNumberType()) + .INPUT(grads, TensorType::RealNumberType()) + .OUTPUT(y, TensorType::RealNumberType()) + .REQUIRED_ATTR(ksize, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(data_format, String, "NDHWC") + .OP_END_FACTORY_REG(MaxPool3DGradGrad) + + /** * @brief Computes gradients of the maxpooling function. @@ -239,9 +320,10 @@ REG_OP(MaxPoolGrad) * @brief Computes second-order gradients of the maxpooling function. * @par Inputs: -* @li x1: Original forward input tensor of type RealNumberType -* @li x2: Original forward output tensor of type RealNumberType -* @li grad: Gradient tensor of type RealNumberType +* @li x1: Original forward input tensor. Supported type:float, double, int32, + * uint8, int16, int8, int64, uint16, half, uint32, uint64. +* @li x2: Has the same type and format as input "x1". +* @li grad:Has the same type and format as input "x1". * @par Attributes: * @li ksize: A required list or tuple, @@ -262,7 +344,7 @@ REG_OP(MaxPoolGrad) * @li Other dimensions of ksize and strides is 1. * @par Outputs: -* @li y: Result tensor of type RealNumberType +* @li y: Has the same type and format as input "x1". * @par Third-party framework compatibility * @li Compatible with the TensorFlow operator MaxPoolGradGrad. @@ -397,19 +479,56 @@ REG_OP(MaxPoolGradWithArgmax) .REQUIRED_ATTR(padding, String) .OP_END_FACTORY_REG(MaxPoolGradWithArgmax) +/** +*@brief Performs transform mask to argmax. + +*@par Inputs: +* Two input: +*x: An NC1HWC0 Tensor of type float16. +*mask: An NC1HWC0 Tensor of type uint16. + +*@par Attributes: +*@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for each dimension of the input tensor. No default value. +*@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for each dimension of the input tensor. No default value. +*@li padding: A required string. No default value. + +*@par Outputs: +*argmax: An NC1HWC0 Tensor of type int32. + +*@attention Constraints: +*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, ksize[1] * ksize[2] <= 255. +*@li "stride is a list that has length 4: strides[0] = 1 or strides[3] = 1, strides[1] <= 63, strides[0] >= 1, strides[2] <= 63, strides[2] >= 1. +*@li "padding" is either "SAME" or "VALID". + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator Mask2Argmax. +*/ +REG_OP(Mask2Argmax) + .INPUT(x, TensorType::RealNumberType()) + .INPUT(mask, TensorType::IndexNumberType()) + .OUTPUT(argmax, TensorType::IndexNumberType()) + .REQUIRED_ATTR(ksize, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(padding, String) + .REQUIRED_ATTR(originshape, ListInt) + .OP_END_FACTORY_REG(Mask2Argmax) + /** * @brief Computes second-order gradients of the maxpooling function. * @par Inputs: -* @li x: Original forward input tensor of type RealNumberType -* @li grad: Gradient tensor of type RealNumberType -* @li argmax: An tensor of type IndexNumberType +* @li x: Original forward input tensor. Supported type: float, double, int32, + * uint8, int16, int8, int64, uint16, half, uint32, uint64. +* @li grad: Gradient tensor. Supported type: float, double, int32, + * uint8, int16, int8, int64, uint16, half, uint32, uint64. +* @li argmax: An tensor of type int32 or int64. * @par Attributes: * @li ksize: A required list, specifying the size of the sliding window. * @li strides: A required list, specifying the stride of the sliding window. * @li padding: A required string, window sliding mode. Either SAME or VALID. * @par Outputs: -* @li y:Result tensor of type RealNumberType +* @li y:Result tensor. Supported type: float, double, int32, + * uint8, int16, int8, int64, uint16, half, uint32, uint64 * @attention Constraints: * @li Only the cloud platform is supported. @@ -495,35 +614,7 @@ REG_OP(AvgPoolGradD) .OP_END_FACTORY_REG(AvgPoolGradD) -REG_OP(MaxPoolWithArgmaxCCE) - .INPUT(x, TensorType::ALL()) - .OUTPUT(y, TensorType::ALL()) - .OUTPUT(argmax, TensorType::ALL()) - .ATTR(mode, Int, 0) - .ATTR(pad_mode, Int, 0) - .ATTR(window, ListInt, {1,1}) - .ATTR(stride, ListInt, {1,1}) - .ATTR(pad, ListInt, {0,0,0,0}) - .ATTR(ceil_mode, Int, 0) - .ATTR(data_mode, Int, 1) - .ATTR(nan_opt, Int, 0) - .OP_END_FACTORY_REG(MaxPoolWithArgmaxCCE) - -REG_OP(MaxPoolGradWithArgmaxCCE) - .INPUT(x, TensorType::ALL()) - .INPUT(grad,TensorType::ALL()) - .INPUT(arg,TensorType::ALL()) - .OUTPUT(output,TensorType::ALL()) - .ATTR(mode, Int, 0) - .ATTR(max_pool_grad_output_shape, ListInt, {0,0,0,0}) - .ATTR(pad_mode, Int, 0) - .ATTR(window, ListInt, {1,1}) - .ATTR(stride, ListInt, {1,1}) - .ATTR(pad, ListInt, {0,0,0,0}) - .ATTR(ceil_mode, Int, 0) - .ATTR(data_mode, Int, 1) - .ATTR(nan_opt, Int, 0) - .OP_END_FACTORY_REG(MaxPoolGradWithArgmaxCCE) + /** *@brief :upsample the layer @@ -531,7 +622,7 @@ REG_OP(MaxPoolGradWithArgmaxCCE) * one input, including: *@li x: A tensor of type float16 or float32. *@par Attributes: -*@li scale: A optional float, scale factor of x. Defaults to "1.0". +*@li scale: A optional float32, scale factor of x. Defaults to "1.0". *@li stride_h: An optional int32, broadcast the axis of h. Defaults to "2". *@li stride_w: An optional int32, broadcast the axis of w. Defaults to "2". *@par Outputs: @@ -749,7 +840,186 @@ REG_OP(DataFormatVecPermute) .ATTR(dst_format, String, "NCHW") .OP_END_FACTORY_REG(DataFormatVecPermute) +/** +* @brief Computes gradients of the MaxPool3D function. + +* @par Inputs: +* @li orig_x: A mutable NDC1HWC0 tensor of type float16. +* @li orig_y: A mutable NDC1HWC0 tensor of type float16. +* @li grads: A mutable NDC1HWC0 tensor of type float16. +* @par Attributes: +* @li ksize: A required tuple or list, specifying the size of the window for +* each dimension of the input tensor. +* @li strides: A required tuple or list, specifying the stride of the sliding +* window for each dimension of the input tensor. +* @li pads: A list of 6 ints. Supports only padding along the D, +* H and W dimensions in sequence of head, tail, top, bottom, left and right. +* to use. +* @li data_format: An optional string, Specify the data format of the input and +* output data. With the default format "NDHWC". + +* @par Outputs: +* y: A mutable tensor. Has the same shape as "orig_x", but type is float32. + +* @par Third-party framework compatibility +* Compatible with the TensorFlow operator MaxPool3DGrad. +*/ +REG_OP(MaxPool3DGrad) + .INPUT(orig_x, TensorType::RealNumberType()) + .INPUT(orig_y, TensorType::RealNumberType()) + .INPUT(grads, TensorType::RealNumberType()) + .OUTPUT(y, TensorType::RealNumberType()) + .REQUIRED_ATTR(ksize, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(data_format, String, "NDHWC") + .OP_END_FACTORY_REG(MaxPool3DGrad) + +/** +*@brief Performs AvgPool1D on the input. + +*@par Inputs: +*x: A Tensor. Must be one of the following types: int8, uint8, int16, int32, int64, float16, float32, float64. + +*@par Attributes: +*@li ksize: An required int, specifying the size of the window. +*@li strides: An required int. +*@li pads: A required tuple or list. +*@li ceil_mode: An optional bool. Defaults to False. +*@li count_include_pad: An optional bool. Defaults to False. + +*@par Outputs: +*y: A Tensor. Has the same type as x. + +*@par Third-party framework compatibility +*@li compatible with pytorch AvgPool1D operator. +*/ +REG_OP(AvgPool1D) + .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(ksize, Int) + .REQUIRED_ATTR(strides, Int) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(ceil_mode, Bool, false) + .ATTR(count_include_pad, Bool, false) + .OP_END_FACTORY_REG(AvgPool1D) + +/** +*@brief Performs AvgPool1D on the input. + +*@par Inputs: +*x: A Tensor. Must be one of the following types: int8, uint8, int16, int32, int64, float16, float32, float64. + +*@par Attributes: +*@li ksize: An required int, specifying the size of the window. +*@li strides: An required int. +*@li pads: A required tuple or list. +*@li ceil_mode: An optional bool. Defaults to False. +*@li count_include_pad: An optional bool. Defaults to False. + +*@par Outputs: +*y: A Tensor. Has the same type as x. + +*@par Third-party framework compatibility +*@li compatible with pytorch AvgPool1D operator. +*/ +REG_OP(AvgPool1DD) + .INPUT(x, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .INPUT(assist_matrix, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .OUTPUT(y, TensorType({DT_INT8, DT_UINT8, DT_INT16, DT_INT32, DT_INT64, DT_FLOAT16, DT_FLOAT, DT_DOUBLE})) + .REQUIRED_ATTR(ksize, Int) + .REQUIRED_ATTR(strides, Int) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(ceil_mode, Bool, false) + .ATTR(count_include_pad, Bool, false) + .OP_END_FACTORY_REG(AvgPool1DD) +/** +*@brief Performs max pooling on the input and outputs both max values and indices. + +*@par Inputs: +* One input: +*x: An NC1HWC0 Tensor of type float16. +*@par Attributes: +*@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for +* each dimension of the input tensor. No default value. +*@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for +* each dimension of the input tensor. No default value. +*@li pads: A required string. No default value. +*@li dtype: A optional int. default value is 3. +*@li dilation: A optional list of int8, int16, int32, or int64 values. +*@li ceil_mode: A optional bool. default value is false. + +*@par Outputs: +*y: A Tensor. Has the same type and format as input "x". +*argmax: A Tensor. type:uint16, format:NC1HWC0. +*@attention Constraints: +*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, ksize[1] * ksize[2] <= 255. +*@li "strides is a list that has length 4: strides[0] = 1 or strides[3] = 1, strides[1] <= 63, strides[0] >= 1, +* strides[2] <= 63, strides[2] >= 1. +*@li "dilation" is a list that has length 4. +*@li "ceil_mode" is a bool, default is false. + +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator MaxPoolWithArgmax. +*/ +REG_OP(MaxPoolWithArgmaxV2) + .INPUT(x, TensorType({DT_FLOAT16})) + .OUTPUT(y, TensorType({DT_FLOAT16})) + .OUTPUT(argmax, TensorType({DT_UINT16})) + .REQUIRED_ATTR(ksize, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dtype, Int, 3) + .ATTR(dilation, ListInt, {1, 1, 1, 1}) + .ATTR(ceil_mode, Bool, false) + .OP_END_FACTORY_REG(MaxPoolWithArgmaxV2) + +/** +*@brief Performs the backpropagation of MaxPoolWithArgmaxV2. + +*@par Inputs: +* Three inputs, including: +*@li x: An NC1HWC0 tensor of type float16. +*@li grad: An NC1HWC0 tensor of type float16. +*@li argmx: An NC1HWC0 tensor of type uint16 or int64. + +*@par Attributes: +*@li ksize: A required list of int8, int16, int32, or int64 values, specifying the size of the window for + * each dimension of the input tensor. No default value. +*@li strides: A required list of int8, int16, int32, or int64 values, specifying the stride of the sliding window for + * each dimension of the input tensor. No default value. +*@li pads: A required string. No default value. +*@li dtype: A optional int. default value is 3. +*@li dilation: A optional list of int8, int16, int32, or int64 values. +*@li ceil_mode: A optional bool. default value is false. + +*@par Outputs: +*y: A Tensor. Has the same type and format as input "x". + +*@attention Constraints: +*@li "ksize" is a list that has length 4: ksize[0] = 1 or ksize[3] = 1, ksize[1] * ksize[2] <= 255. +*@li "strides" is a list that has length 4: strides[0] = 1 or strides[3] = 1 +*@li "dilation" is a list that has length 4. +*@li "ceil_mode" is a bool, default is false. + +*@see max_pool_grad_with_argmaxv2 +*@par Third-party framework compatibility +* Compatible with the TensorFlow operator MaxPoolGradWithArgmaxV2. +*/ + +REG_OP(MaxPoolGradWithArgmaxV2) + .INPUT(x, TensorType({DT_FLOAT16})) + .INPUT(grad, TensorType({DT_FLOAT16})) + .INPUT(argmax, TensorType({DT_UINT16})) + .OUTPUT(y, TensorType({DT_FLOAT16})) + .REQUIRED_ATTR(ksize, ListInt) + .REQUIRED_ATTR(strides, ListInt) + .REQUIRED_ATTR(pads, ListInt) + .ATTR(dtype, Int, 3) + .ATTR(dilation, ListInt, {1,1,1,1}) + .ATTR(ceil_mode, Bool, false) + .OP_END_FACTORY_REG(MaxPoolGradWithArgmaxV2) } // namespace ge #endif // GE_OP_NN_POOLING_OPS_H diff --git a/third_party/fwkacllib/inc/ops/nn_training_ops.h b/third_party/fwkacllib/inc/ops/nn_training_ops.h index 1c9aa516..cc17103c 100644 --- a/third_party/fwkacllib/inc/ops/nn_training_ops.h +++ b/third_party/fwkacllib/inc/ops/nn_training_ops.h @@ -307,16 +307,6 @@ REG_OP(ApplyMomentum) .ATTR(use_locking, Bool, false) .OP_END_FACTORY_REG(ApplyMomentum) -REG_OP(ApplyMomentumCCE) - .INPUT(var, TensorType::NumberType()) - .INPUT(accum, TensorType::NumberType()) - .INPUT(lr, TensorType::NumberType()) - .INPUT(grad, TensorType::NumberType()) - .INPUT(momentum, TensorType::NumberType()) - .OUTPUT(var, TensorType::NumberType()) - .ATTR(use_nesterov, Bool, false) - .ATTR(use_locking, Bool, false) - .OP_END_FACTORY_REG(ApplyMomentumCCE) /** *@brief Updates "var" according to the momentum scheme. Set use_nesterov = True if you @@ -1508,7 +1498,7 @@ REG_OP(ApplyProximalAdagradD) *@par Attributes: *use_locking: An optional bool. Defaults to "False".\n * If "True", updating of the var and accum tensors will be protected by a lock; \n -* If "False", the behavior is undefined, but may exhibit less contention. +* If "False", the behavior is undefined, but may exhibit less contention. *@par Outputs: *var: A mutable Tensor. Has the same type as "var". diff --git a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h index 1405fdb7..a01073cf 100644 --- a/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h +++ b/third_party/fwkacllib/inc/ops/nonlinear_fuc_ops.h @@ -83,7 +83,7 @@ REG_OP(TanhGrad) *@par Inputs: *One input: -*x: A Tensor. Must be one of the following types: float16, float32, complex64, complex128, int32, int64 +*x: A Tensor. Must be one of the following types: float16, float32, complex64, complex128, double. *@par Outputs: *y: A Tensor. Has the same type as "x". @@ -184,7 +184,7 @@ REG_OP(Relu6Grad) * @brief Compute sigmoid of "x" element-wise. * @par Inputs: -* A Tensor of type UnaryDataType. +* A Tensor of type complex64, complex128, float16, float32 or double. * @par Outputs: * A Tensor. Has the same type as "x". @@ -220,7 +220,7 @@ REG_OP(SigmoidGrad) *if x>0, x+log(1+exp(-x)); otherwise log(1+exp(x)). *@par Inputs: -*x: A Tensor of type float16 or float32. +*x: A Tensor of type double, float16 or float32. *@par Outputs: *y: A tensor. Has the same type and format as input "x". @@ -442,7 +442,7 @@ REG_OP(PReluGrad) *x: A float16, float32 or double, for the input data type. *@par Attributes: -*alpha: A float. Defines at which negative value the ELU saturates. Defaults to "1.0". +*alpha: A float32. Defines at which negative value the ELU saturates. Defaults to "1.0". *@par Outputs: *y: A float16, float32 or double, for the normalized result. diff --git a/third_party/fwkacllib/inc/ops/reduce_ops.h b/third_party/fwkacllib/inc/ops/reduce_ops.h index 8819d2d5..8cf9f342 100644 --- a/third_party/fwkacllib/inc/ops/reduce_ops.h +++ b/third_party/fwkacllib/inc/ops/reduce_ops.h @@ -673,7 +673,7 @@ REG_OP(ReduceAnyD) *@par Attributes: *@li operation: An optional int32 from 1(SUM), 2(ASUM), 3(SUMSQ), and 4(MEAN), -*specifying the reduction algorithm. Defaults to 1. +*specifying the reduction algorithm. Defaults to "1". *@li axis: An optional int32, specifying the first axis to reduce. Defaults to "0". *The value range is [-N, N-1], where N is the input tensor rank. *@li coeff: An optional float32, specifying the scale coefficient. Defaults to "1.0". @@ -745,7 +745,190 @@ REG_OP(EuclideanNormD) .ATTR(keep_dims, Bool, false) .OP_END_FACTORY_REG(EuclideanNormD) -} //namespace ge +/** +*@brief Performs instance normalization for inference. + +*@par Inputs:\n +* Five inputs, including: (NC1HWC0 supported) +*@li x: A Tensor of type float16 or float32. +*@li gamma: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling gamma. +*@li beta: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling beta. +*@li mean: A [N, C1, 1, 1, C0] ensor of type float32, for the mean. +*@li variance: A [N, C1, 1, 1, C0] Tensor of type float32, for the variance. + +*@par Attributes: +*epsilon: An optional float32, specifying the small value added to variance to avoid dividing by zero. +Defaults to "0.00001". + +*@par Outputs:\n +*y: A Tensor of type float16 or float32 for the normalized "x". +*batch_mean: A Tensor of type float32 for the result mean. +*batch_ variance: A Tensor of type float32 for the result variance. + +*@attention Constraints: +*For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. +*/ +REG_OP(INInferV2) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT})) + .ATTR(epsilon, Float, 0.00001) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(batch_mean, TensorType({DT_FLOAT})) + .OUTPUT(batch_variance, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(INInferV2) + +/** +*@brief Performs reduced instance normalization. + +*@par Inputs:\n +*x: A Tensor of type float16 or float32, with format NC1HWC0. + +*@par Outputs: +*@li sum: A Tensor of type float32 for SUM reduced "x". +*@li square_sum: A Tensor of type float32 for SUMSQ reduced "x". + +*@attention Constraints:\n +* This operator is a InstanceNorm fusion operator for updating the moving averages for training. \n +* This operator is used in conjunction with INTrainingUpdateV2. +*/ +REG_OP(INTrainingReduceV2) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(sum, TensorType({DT_FLOAT})) + .OUTPUT(square_sum, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(INTrainingReduceV2) + + +/** +*@brief Performs update instance normalization. + +*@par Inputs:\n +* Seven inputs, including: (NC1HWC0supported) +*@li x: A Tensor of type float16 or float32. +*@li sum: A T [N, C1, 1, 1, C0] ensor of type float32 for the output of operator INTrainingReduceV2. +*@li square_sum: A [N, C1, 1, 1, C0] Tensor of type float32 for the output of operator INTrainingReduceV2. +*@li gamma: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling gamma. +*@li beta: A [N, C1, 1, 1, C0] Tensor of type float32, for the scaling beta. +*@li mean: A [N, C1, 1, 1, C0] Tensor of type float32, for the updated mean. +*@li variance: A [N, C1, 1, 1, C0] Tensor of type float32, for the updated variance. + +*@par Attributes: +*@li momentum: A required float32, specifying the momentum to update mean and var. +*@li epsilon: A required float32, specifying the small value added to variance to avoid dividing by zero. + +*@par Outputs:\n +* Three outputs, including: (NC1HWC0 supported) +*@li y: A Tensor of type float16 or float32, for normalized "x". +*@li batch_mean: A Tensor of type float32, for the updated mean. +*@li batch_variance: A Tensor of type float32, for the updated variance. + +*@attention Constraints: +*@li This operator is a InstanceNorm fusion operator for updating the moving averages for training. \n +* This operator is used in conjunction with INTrainingReduceV2. +*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. +*/ +REG_OP(INTrainingUpdateV2) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(sum, TensorType({DT_FLOAT})) + .INPUT(square_sum, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(gamma, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(beta, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT})) + .ATTR(momentum, Float, 0.1) + .ATTR(epsilon, Float, 0.00001) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(batch_mean, TensorType({DT_FLOAT})) + .OUTPUT(batch_variance, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(INTrainingUpdateV2) + + +/** +*@brief Performs reduced group normalization. + +*@par Inputs:\n +*x: A Tensor of type float16 or float32, with format NCHW NHWC. + +*@par Outputs: +*@li sum: A Tensor of type float32 for SUM reduced "x". +*@li square_sum: A Tensor of type float32 for SUMSQ reduced "x". + + +*@par Attributes: +*@li num_groups: Int, specifying the num of groups. required, same to GNTrainingUpdate. + +*@attention Constraints:\n +* This operator is a GroupNorm fusion operator for updating the moving averages for training. \n +* This operator is used in conjunction with GNTrainingUpdate. +*/ +REG_OP(GNTrainingReduce) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(sum, TensorType({DT_FLOAT})) + .OUTPUT(square_sum, TensorType({DT_FLOAT})) + .ATTR(num_groups, Int, 2) + .OP_END_FACTORY_REG(GNTrainingReduce) + + +/** +*@brief Performs update group normalization. + +*@par Inputs:\n +* Eight inputs, including: (NCHW NHWC supported) +*@li x: A Tensor of type float16 or float32. +*@li sum: A 5D Tensor of type float32, +shape is [N, G, 1, 1, 1] for NCHW, [N, 1, 1, G, 1] for NHWC +for the output of operator GNTrainingReduce. +*@li square_sum: A 5D Tensor of type float32, +shape is [N, G, 1, 1, 1] for NCHW, [N, 1, 1, G, 1] for NHWC +for the output of operator GNTrainingReduce. +*@li scale: A 5D Tensor of type float32, +shape is [1, G, 1, 1, 1] for NCHW, [1, 1, 1, G, 1] for NHWC +is for the scaling gamma. +*@li offset: A 5D Tensor of type float32, +shape is [1, G, 1, 1, 1] for NCHW, [1, 1, 1, G, 1] for NHWC +for the scaling beta. +*@li mean: A 5D Tensor of type float32, +shape is [N, G, 1, 1, 1] for NCHW, [N, 1, 1, G, 1] for NHWC +for the updated mean. +*@li variance: A 5D Tensor of type float32, +shape is [N, G, 1, 1, 1] for NCHW, [N, 1, 1, G, 1] for NHWC +for the updated variance. + + +*@par Attributes: +*@li epsilon: A float32, specifying the small value added to variance to avoid dividing by zero. +*@li num_groups: Int, specifying the num of groups. required, same to GNTrainingReduce + +*@par Outputs:\n +* Three outputs, including: (NC1HWC0 supported) +*@li y: A Tensor of type float16 or float32, for normalized "x". +*@li batch_mean: A Tensor of type float32, for the updated mean. +*@li batch_variance: A Tensor of type float32, for the updated variance. + +*@attention Constraints: +*@li This operator is a InstanceNorm fusion operator for updating the moving averages for training. \n +* This operator is used in conjunction with GNTrainingUpdate. +*@li For Ascend 310, the result accuracy fails to reach 1‰ due to the square root instruction. +*/ +REG_OP(GNTrainingUpdate) + .INPUT(x, TensorType({DT_FLOAT16,DT_FLOAT})) + .INPUT(sum, TensorType({DT_FLOAT})) + .INPUT(square_sum, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(scale, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(offset, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(mean, TensorType({DT_FLOAT})) + .OPTIONAL_INPUT(variance, TensorType({DT_FLOAT})) + .ATTR(num_groups, Int, 2) + .ATTR(epsilon, Float, 0.0001) + .OUTPUT(y, TensorType({DT_FLOAT16,DT_FLOAT})) + .OUTPUT(batch_mean, TensorType({DT_FLOAT})) + .OUTPUT(batch_variance, TensorType({DT_FLOAT})) + .OP_END_FACTORY_REG(GNTrainingUpdate) + +} //namespace ge + #endif /* GE_OP_REDUCE_OPS_H */ diff --git a/third_party/fwkacllib/inc/ops/rnn.h b/third_party/fwkacllib/inc/ops/rnn.h index c4d64b0a..b72d9a79 100644 --- a/third_party/fwkacllib/inc/ops/rnn.h +++ b/third_party/fwkacllib/inc/ops/rnn.h @@ -67,6 +67,13 @@ REG_OP(BasicLSTMCell) .ATTR(activation, String, "tanh") .OP_END_FACTORY_REG(BasicLSTMCell) +REG_OP(DynamicLSTM) + .INPUT(x, TensorType({DT_FLOAT32})) + .INPUT(w, TensorType({DT_FLOAT32})) + .INPUT(b, TensorType({DT_FLOAT32})) + .OUTPUT(output_h, TensorType({DT_FLOAT32})) + .OP_END_FACTORY_REG(DynamicLSTM) + /** *@brief: Basic LSTM Cell backward calculation.Calculate the gradient of input and hidden state. *@par Inputs: @@ -87,7 +94,7 @@ REG_OP(BasicLSTMCellInputGrad) .INPUT(dgate, TensorType({DT_FLOAT16})) .INPUT(w, TensorType({DT_FLOAT16})) .OPTIONAL_INPUT(dropout_mask, TensorType({DT_UINT8})) - .OUTPUT(dxt, TensorType({DT_FLOAT16})) + .OUTPUT(dxt, TensorType({DT_FLOAT16, DT_FLOAT32})) .OUTPUT(dht, TensorType({DT_FLOAT16, DT_FLOAT32})) .ATTR(keep_prob, Float, 1.0) .OP_END_FACTORY_REG(BasicLSTMCellInputGrad) diff --git a/third_party/fwkacllib/inc/ops/selection_ops.h b/third_party/fwkacllib/inc/ops/selection_ops.h index aafcece0..c2e6f13a 100644 --- a/third_party/fwkacllib/inc/ops/selection_ops.h +++ b/third_party/fwkacllib/inc/ops/selection_ops.h @@ -89,7 +89,8 @@ REG_OP(RangeD) *@par Inputs: *Two inputs, including: -* @li x: A Tensor of type TensorType::BasicType(). +* @li x: A Tensor. +* Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32. * @li multiples: A 1D Tensor of type int32 or int64. * The length must be the same as the number of dimensions in "input" @@ -496,7 +497,7 @@ REG_OP(UnsortedSegmentSumD) *@par Inputs: * Two inputs, including:\n *@li x: An ND Tensor (up to 8D). \n -*Must be one of the following types: int8, uint8, int16, uint16, int32, int64, bool, float32, double +*Must be one of the following types: int8, uint8, int16, uint16, int32, int64, bool, float16, float32, double, complex64, complex128, string. *@li axis: A 1D Tensor.\n *Must be one of the following types: int32, int64 @@ -1003,9 +1004,8 @@ REG_OP(StridedSliceAssign) * @par Inputs: * Two inputs, including: -* @li var: A mutable ND Tensor of type BasicType. -* @li input_value: A mutable ND "Tensor" of type BasicType. - +* @li var: A mutable ND Tensor of the following types:int32, int16, float16, float32. +* @li input_value: A mutable ND "Tensor" of the following types:int32, int16, float16, float32. * @par Attributes: * @li begin: A required list of ints. @@ -1029,9 +1029,9 @@ REG_OP(StridedSliceAssign) * @see StridedSlice() */ REG_OP(StridedSliceAssignD) - .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) - .INPUT(input_value, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32})) - .OUTPUT(var, TensorType(BasicType)) + .INPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT16})) + .INPUT(input_value, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT16})) + .OUTPUT(var, TensorType({DT_FLOAT16, DT_FLOAT, DT_INT32, DT_INT16})) .REQUIRED_ATTR(begin, ListInt) .REQUIRED_ATTR(end, ListInt) .REQUIRED_ATTR(strides, ListInt) @@ -1396,24 +1396,23 @@ REG_OP(UnsortedSegmentMin) * @brief Computes the minimum along segments of a tensor. * @par Inputs: -* Three inputs, including: -* @li x: A Tensor of type RealNumberType. -* @li segment_ids: A 1D Tensor of type IndexNumberType, whose shape is a prefix +* Two inputs, including: +* @li x: A Tensor of the following types:int32, int16, float16, float32. +* @li segment_ids: A 1D Tensor of type int32, whose shape is a prefix * of "x.shape". -* @li k: A Tensor. * @par Attributes: * num_segments: A required int32, specifying the number of distinct segment IDs. * @par Outputs: -* y: A Tensor of type RealNumberType. +* y: A Tensor.Must have the same type as input "x". * @see UnsortedSegmentProdD(), */ REG_OP(UnsortedSegmentMinD) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) .INPUT(segment_ids, TensorType({DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) .REQUIRED_ATTR(num_segments, Int) .OP_END_FACTORY_REG(UnsortedSegmentMinD) @@ -1446,24 +1445,23 @@ REG_OP(UnsortedSegmentProd) * @brief Computes the product along segments of a tensor. * @par Inputs: -* Three inputs, including: -* @li x: A Tensor of type RealNumberType. -* @li segment_ids: A 1D Tensor of type IndexNumberType, whose shape is a prefix +* Two inputs, including: +* @li x: A Tensor of the following types:int32, int16, float16, float32. +* @li segment_ids: A 1D Tensor of type int32, whose shape is a prefix * of "x.shape". -* @li k: A Tensor. * @par Attributes: * num_segments: An int32, specifying the number of distinct segment IDs. * @par Outputs: -* y: A Tensor of type RealNumberType. +* y: A Tensor.Must have the same type as input "x". * @see UnsortedSegmentMinD() */ REG_OP(UnsortedSegmentProdD) - .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .INPUT(x, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) .INPUT(segment_ids, TensorType({DT_INT32})) - .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32})) + .OUTPUT(y, TensorType({DT_FLOAT, DT_FLOAT16, DT_INT32, DT_INT16})) .REQUIRED_ATTR(num_segments, Int) .OP_END_FACTORY_REG(UnsortedSegmentProdD) @@ -1559,14 +1557,14 @@ REG_OP(ProposalD) * If reverse=false: (N, H, W, C)->(N, H/stride, W/stride, C*(stride*stride)) *@par Inputs: -*x: An (N, H, W, C) tensor. All types except double are supported. +*x: An (N, H, W, C) tensor. Type is float16, float32, int8, uint8, int16, uint16, int32, uint32, int64 or uint64.. *@par Attributes: *@li stride: An optional int32, specifying the plane or channel scaling factor. Defaults to "2". *@li reverse: An optional bool, specifying the conversion mode. If "true", depth to space conversion is performed. If "false", space to depth conversion is performed. Defaults to "false". *@par Outputs: -*y: An (N, H, W, C) tensor. All types except double are supported. +*y: An (N, H, W, C) tensor. Has same type as "x". *@attention Constraints: *@li If reverse=true: C/(stride*stride) yields an integer result. If reverse=false: W/stride and H/stride yield integer results. @@ -1593,7 +1591,7 @@ REG_OP(PassThrough) * @li x: A required Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32,int64, uint64. * @li size: A required Tensor. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. *@par Attributes: -*@li axis: A required int32, specifying the first dimension to crop. +*@li axis: A required int32, specifying the first dimension to crop. Defaults to "2". *@li offset: A required array, specifying the shift for all/each dimension to align the cropped bottom with the reference bottom. Must be one of the following types: float16, float32, int8, uint8, int16, uint16, int32, uint32, int64, uint64. *@par Outputs: *y: A required Tensor. Has the same type and shape as "size". @@ -1774,6 +1772,6 @@ REG_OP(CumulativeLogsumexpD) .ATTR(exclusive, Bool, false) .ATTR(reverse, Bool, false) .OP_END_FACTORY_REG(CumulativeLogsumexpD) - } // namespace ge + #endif // GE_OP_SELECTION_OPS_H diff --git a/third_party/fwkacllib/inc/ops/split_combination_ops.h b/third_party/fwkacllib/inc/ops/split_combination_ops.h index 700d34b7..7e4428d0 100644 --- a/third_party/fwkacllib/inc/ops/split_combination_ops.h +++ b/third_party/fwkacllib/inc/ops/split_combination_ops.h @@ -25,11 +25,11 @@ namespace ge { *@par Inputs: * Two inputs, including: *@li x: An ND Tensor. -*Must be one of the following types: float16, float32, int32, int8, int16, int64, uint8, uint16, uint32, uint64 +*Must be one of the types:float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32. *@li split_dim: Must be the following type:int32. Specifies the dimension along which to split. *@par Attributes: -*num_split: A required int8, int16, int32, or int64. Specifies the number of output tensors. No default value. +*num_split: A required int32. Specifies the number of output tensors. No default value. *@par Outputs: *y: Dynamic output.A list of output tensors. Has the same type and format as "x". @@ -186,6 +186,7 @@ REG_OP(ParallelConcat) *@par Attributes: *concat_dim: A required int8, int16, int32, or int64. Specifies the dimension along which to concatenate. No default value. +*N: An attribute int8, int16, int32, or int64. Specifies the number of elements in "x". Defaults to "1". *@par Outputs: *y: A Tensor. Has the same type and format as "x". @@ -267,7 +268,9 @@ REG_OP(ConcatD) *@par Inputs: * Two inputs, including: *@li x: Dynamic input.An NC1HWC0 or ND Tensor. -*Must be one of the following types: float16, float32, int32, int8, int16, int64, uint8, uint16, uint32, uint64 +*Must be one of the following types: float16, float32, double, int32, +* uint8, int16, int8, complex64, int64, qint8, quint8, qint32, uint16, +* complex128, uint32, uint64, qint16, quint16. *@li concat_dim: An int32, or int64. Specifies the dimension along which to concatenate. *@par Attributes: diff --git a/third_party/fwkacllib/inc/ops/transformation_ops.h b/third_party/fwkacllib/inc/ops/transformation_ops.h index 69951da9..5bbf1e78 100644 --- a/third_party/fwkacllib/inc/ops/transformation_ops.h +++ b/third_party/fwkacllib/inc/ops/transformation_ops.h @@ -20,6 +20,35 @@ #include "graph/operator_reg.h" namespace ge { +/** +*@brief This operation convert output dataType and shape + +*@par Inputs: +*The input handle must have the resource type. Inputs include: \n +*@li x:A list of Tensor objects. One or more tensors from which \n +the enqueued tensors should be taken. + +*@par Outputs: +*@li y:A list of Tensor objects. One or more tensors from which \n +the enqueued tensors should be taken. + +*@par Attributes: +*@li type: An optional ge::DataType. It refers to the target data type of outputs. + +*@par Third-party framework compatibility +*Compatible with tensorflow QueueIsClosed operator. +*/ + +REG_OP(Bitcast) + .INPUT(x, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8, + DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32})) + .OUTPUT(y, TensorType({DT_BOOL, DT_FLOAT16, DT_FLOAT, DT_INT8, DT_INT32, DT_UINT32, DT_UINT8, + DT_INT64, DT_UINT64, DT_INT16, DT_UINT16, DT_DOUBLE, DT_COMPLEX64, + DT_COMPLEX128, DT_QINT8, DT_QUINT8, DT_QINT16, DT_QUINT16, DT_QINT32})) + .REQUIRED_ATTR(type, Type) + .OP_END_FACTORY_REG(Bitcast) + /** *@brief Convert tensor format from HWCN to C1HWNCoC0. @@ -94,6 +123,13 @@ REG_OP(Transpose) .OUTPUT(y, TensorType::BasicType()) .OP_END_FACTORY_REG(Transpose) +REG_OP(TransData) + .INPUT(src, TensorType::BasicType()) + .OUTPUT(dst, TensorType::BasicType()) + .REQUIRED_ATTR(src_format, String) + .REQUIRED_ATTR(dst_format, String) + .OP_END_FACTORY_REG(TransData) + /** *@brief Permutes the dimensions according to order.\n The returned tensor's dimension i will correspond to the input dimension order[i]. @@ -102,7 +138,7 @@ REG_OP(Transpose) *x: A Tensor. Must be one of the following types: float16, float32. *@par Attributes: -*order: A permutation of the dimensions of "x".support any axis transformation +*order: A permutation of the dimensions of "x".Type is int32.support any axis transformation.Defaults to "{0}" *@par Outputs: *y: A Tensor. Has the same type as "x". @@ -291,7 +327,7 @@ REG_OP(DepthToSpace) *@brief Permutes data into spatial data blocks and then prunes them. *@par Inputs: -*@li x: A 4D Tensor with format NC1HWC0. +*@li x: A 4D Tensor with format NHWC. *@li crops: A 1D list or tuple of int32 or int64. *Must be one of the following types: float16, float32 @@ -300,7 +336,7 @@ REG_OP(DepthToSpace) *block_size: A required int8, int16, int32, or int64. No default value. *@par Outputs: -*y: A 4D Tensor with format NC1HWC0, +*y: A 4D Tensor with format NHWC, * of type float16 or float32. @@ -365,7 +401,7 @@ REG_OP(BatchToSpaceD) *@par Inputs: * Two inputs, including: -*@li x: An NC1HWC0 Tensor. Must be one of the following types: +*@li x: An NHWC Tensor. Must be one of the following types: * float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, * int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32. *@li paddings: A 2D tensor of type int, specifying the input. @@ -389,7 +425,7 @@ REG_OP(SpaceToBatch) *@brief Outputs a copy of the input tensor where values from the "height" and "width" dimensions are padded and rearranged to the "batch" dimension. *@par Inputs: -*x: An NC1HWC0 Tensor. Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32. +*x: An NHWC Tensor. Must be one of the following types: float16, float32, double, int64, int32, uint8, uint16, uint32, uint64, int8, int16, complex64, complex128, qint8, quint8, qint16, quint16, qint32. *@par Attributes: @@ -598,6 +634,13 @@ REG_OP(Compress) .OUTPUT(compress_index, TensorType({DT_INT8})) .REQUIRED_ATTR(compress_parameters, ListInt) .OP_END_FACTORY_REG(Compress) + +REG_OP(CompressFcOp) + .INPUT(weight, TensorType({DT_INT8})) + .OUTPUT(weight_compress, TensorType({DT_INT8})) + .OUTPUT(compress_index, TensorType({DT_INT8})) + .REQUIRED_ATTR(compress_parameters, ListInt) + .OP_END_FACTORY_REG(CompressFcOp) } // namespace ge #endif // GE_OP_TRANSFORMATION_OPS_H diff --git a/third_party/fwkacllib/inc/register/op_kernel_registry.h b/third_party/fwkacllib/inc/register/op_kernel_registry.h index 2c479e92..5fed8960 100644 --- a/third_party/fwkacllib/inc/register/op_kernel_registry.h +++ b/third_party/fwkacllib/inc/register/op_kernel_registry.h @@ -41,6 +41,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpKernelRegistry { private: OpKernelRegistry(); class OpKernelRegistryImpl; + /*lint -e148*/ std::unique_ptr impl_; }; } // namespace ge diff --git a/third_party/fwkacllib/inc/register/op_registry.h b/third_party/fwkacllib/inc/register/op_registry.h index 1fcdf9de..1dc14b8b 100644 --- a/third_party/fwkacllib/inc/register/op_registry.h +++ b/third_party/fwkacllib/inc/register/op_registry.h @@ -35,6 +35,7 @@ enum RemoveInputType { OMG_MOVE_TYPE_SCALAR_VALUE, OMG_REMOVE_TYPE_WITH_COND = 1000, OMG_REMOVE_INPUT_WITH_ORIGINAL_TYPE, + OMG_INPUT_REORDER, }; struct RemoveInputConfigure { @@ -43,6 +44,7 @@ struct RemoveInputConfigure { RemoveInputType moveType; bool attrValue = false; std::string originalType; + std::vector input_order; }; class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { @@ -57,11 +59,11 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { void GetOpTypeByImplyType(std::vector &vec_op_type, const domi::ImplyType &imply_type); - domi::ParseParamFunc GetParseParamFunc(const std::string &op_type); + domi::ParseParamFunc GetParseParamFunc(const std::string &op_type, const std::string &ori_type); - domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &op_type); + domi::ParseParamByOpFunc GetParseParamByOperatorFunc(const std::string &ori_type); - domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type); + domi::FusionParseParamFunc GetFusionParseParamFunc(const std::string &op_type, const std::string &ori_type); domi::ParseSubgraphFunc GetParseSubgraphPostFunc(const std::string &op_type); @@ -72,14 +74,13 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistry { bool GetOmTypeByOriOpType(const std::string &ori_optype, std::string &om_type); private: - std::unordered_map> op_ori_optype_map_; std::unordered_map op_run_mode_map_; - std::unordered_map opParseParamsFnMap_; + std::unordered_map op_parse_params_fn_map_; std::unordered_map parse_params_by_op_func_map_; - std::unordered_map fusionOpParseParamsFnMap_; + std::unordered_map fusion_op_parse_params_fn_map_; std::unordered_map op_types_to_parse_subgraph_post_func_; std::unordered_map> remove_input_configure_map_; - std::unordered_map originOpType2OmOpType_; + std::unordered_map origin_type_to_om_type_; }; } // namespace domi #endif // INC_REGISTER_OP_REGISTRY_H_ diff --git a/third_party/fwkacllib/inc/register/op_tiling.h b/third_party/fwkacllib/inc/register/op_tiling.h new file mode 100644 index 00000000..e9d19f94 --- /dev/null +++ b/third_party/fwkacllib/inc/register/op_tiling.h @@ -0,0 +1,133 @@ +/** + * Copyright 2019-2020 Huawei Technologies Co., Ltd + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#ifndef INC_OP_TILING_H_ +#define INC_OP_TILING_H_ + +#include "external/register/register_types.h" +#include "graph/debug/ge_attr_define.h" +#include "graph/node.h" + +#include +#include +#include +#include +#include +#include +#include +#include "graph/node.h" + +#define REGISTER_OP_TILING_FUNC(optype, opfunc) \ + REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, __COUNTER__) + +#define REGISTER_OP_TILING_FUNC_UNIQ_HELPER(optype, opfunc, counter) \ + REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) + +#define REGISTER_OP_TILING_FUNC_UNIQ(optype, opfunc, counter) \ + static OpTilingInterf g_##optype##TilingInterf##counter(#optype, opfunc) + +namespace optiling { + +enum TensorArgType { + TA_NONE, + TA_SINGLE, + TA_LIST, +}; + + +using ByteBuffer = std::stringstream; + +struct TeOpTensor { + std::vector shape; + std::vector ori_shape; + std::string format; + std::string ori_format; + std::string dtype; + std::map attrs; +}; + + +struct TeOpTensorArg { + TensorArgType arg_type; + std::vector tensor; +}; + +struct OpRunInfo { + uint32_t block_dim; + std::vector workspaces; + ByteBuffer tiling_data; +}; + + +using TeOpAttrArgs = std::vector; +using TeConstTensorData = std::tuple; + +struct TeOpParas { + std::vector inputs; + std::vector outputs; + std::map const_inputs; + TeOpAttrArgs attrs; +}; + + +using OpTilingFunc = std::function; + +using OpTilingFuncPtr = bool(*)(const std::string&, const TeOpParas&, const nlohmann::json& , OpRunInfo&); + +class FMK_FUNC_HOST_VISIBILITY OpTilingInterf +{ +public: + OpTilingInterf(std::string op_type, OpTilingFunc func); + ~OpTilingInterf() = default; + static std::map &RegisteredOpInterf(); +}; + + +template +ByteBuffer& ByteBufferPut(ByteBuffer &buf, const T &value) +{ + buf.write(reinterpret_cast(&value), sizeof(value)); + buf.flush(); + return buf; +} + +template +ByteBuffer& ByteBufferGet(ByteBuffer &buf, T &value) +{ + buf.read(reinterpret_cast(&value), sizeof(value)); + return buf; +} + +inline size_t ByteBufferGetAll(ByteBuffer &buf, char *dest, size_t dest_len) +{ + size_t nread = 0; + size_t rn = 0; + do { + rn = buf.readsome(dest + nread, dest_len - nread); + nread += rn; + } while (rn > 0 && dest_len > nread); + + return nread; +} + + +extern "C" ge::graphStatus OpParaCalculate(const ge::Node &node, OpRunInfo &run_info); +extern "C" ge::graphStatus OpAtomicCalculate(const ge::Node &node, OpRunInfo &run_info); + +} + +#endif // INC_OP_TILING_H_ diff --git a/third_party/fwkacllib/inc/runtime/base.h b/third_party/fwkacllib/inc/runtime/base.h index 49c9de6a..2d6503f9 100644 --- a/third_party/fwkacllib/inc/runtime/base.h +++ b/third_party/fwkacllib/inc/runtime/base.h @@ -62,12 +62,20 @@ typedef enum tagRtError { RT_ERROR_DEVICE_POWER_DOWN_FAIL = 0x16, RT_ERROR_FEATURE_NOT_SUPPROT = 0x17, RT_ERROR_KERNEL_DUPLICATE = 0x18, // register same kernel repeatly + RT_ERROR_STREAM_DUPLICATE = 0x19, // streamId Map is repeatly + RT_ERROR_STREAM_NOT_EXIST = 0x1a, // streamId is not exist + RT_ERROR_SQ_NO_EXIST_SQ_TO_REUSE = 0x1b, // no exist sq to reuse + RT_ERROR_SQID_FULL = 0x3C, RT_ERROR_MODEL_STREAM_EXE_FAILED = 0x91, // the model stream failed RT_ERROR_MODEL_LOAD_FAILED = 0x94, // the model stream failed RT_ERROR_END_OF_SEQUENCE = 0x95, // end of sequence RT_ERROR_NO_STREAM_CB_REG = 0x96, // no callback register info for stream RT_ERROR_DATA_DUMP_LOAD_FAILED = 0x97, // data dump load info fail RT_ERROR_CALLBACK_THREAD_UNSUBSTRIBE = 0x98, // callback thread unsubstribe + RT_ERROR_DEBUG_REGISTER_FAILED = 0x99, // debug register fail + RT_ERROR_DEBUG_UNREGISTER_FAILED = 0x9A, // debug unregister fail + RT_ERROR_GROUP_NOT_SET = 0x9B, + RT_ERROR_GROUP_NOT_CREATE = 0x9C, RT_ERROR_RESERVED } rtError_t; @@ -154,6 +162,12 @@ RTS_API rtError_t rtSetProfDirEx(const char *profDir, const char *address, const */ RTS_API rtError_t rtProfilerInit(const char *profdir, const char *address, const char *job_ctx); +/** + * @ingroup profiling_base + * @brief config rts profiler. + */ +RTS_API rtError_t rtProfilerConfig(uint16_t type); + /** * @ingroup profiling_base * @brief start rts profiler. @@ -184,14 +198,6 @@ RTS_API rtError_t rtGetLastError(); */ RTS_API rtError_t rtPeekAtLastError(); -/** - * @ingroup dvrt_base - * @brief set polling receive mode for task report - * @param [out] NA - * @return RT_ERROR_NONE for ok - */ -RTS_API rtError_t rtSetPollingMode(); - /** * @ingroup dvrt_base * @brief register callback for error code diff --git a/third_party/fwkacllib/inc/runtime/config.h b/third_party/fwkacllib/inc/runtime/config.h index 2e48cc57..c64ed16f 100644 --- a/third_party/fwkacllib/inc/runtime/config.h +++ b/third_party/fwkacllib/inc/runtime/config.h @@ -41,8 +41,7 @@ typedef enum tagRtChipType { CHIP_CLOUD, CHIP_MDC, CHIP_LHISI, - CHIP_OTHER_PHN, - CHIP_OTHER_OLD, + CHIP_DC, CHIP_END, } rtChipType_t; @@ -62,6 +61,7 @@ typedef enum tagRtPlatformType { PLATFORM_MINI_V2, PLATFORM_LHISI_ES, PLATFORM_LHISI_CS, + PLATFORM_DC, PLATFORM_END, } rtPlatformType_t; diff --git a/third_party/fwkacllib/inc/runtime/context.h b/third_party/fwkacllib/inc/runtime/context.h index ed1f13c2..70437b74 100644 --- a/third_party/fwkacllib/inc/runtime/context.h +++ b/third_party/fwkacllib/inc/runtime/context.h @@ -39,6 +39,17 @@ typedef enum tagCtxMode { RT_CTX_GEN_MODE = 1, } rtCtxMode_t; +typedef struct tagRtGroupInfo { + int32_t groupId; + int32_t flag; + uint32_t aicoreNum; + uint32_t aicpuNum; + uint32_t aivectorNum; + uint32_t sdmaNum; + uint32_t activeStreamNum; + void* extrPtr; +} rtGroupInfo_t; + /** * @ingroup rt_context * @brief create context and associates it with the calling thread @@ -115,17 +126,32 @@ RTS_API rtError_t rtGetPriCtxByDeviceId(int32_t device, rtContext_t *ctx); RTS_API rtError_t rtCtxGetDevice(int32_t *device); /** - * @ingroup rt_context - * @brief set ctx run mode: normal or dryrun - * @param [in] ctx: context - * @param [in] enable: set true means enable dryrun mode - * @param [in] flag: reserved - * @return RT_ERROR_NONE for ok + * @ingroup + * @brief set group id + * @param [in] groupid + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtSetGroup(int32_t groupId); + +/** + * @ingroup + * @brief get group info + * @param [in] groupid count + * @return RT_ERROR_NONE for ok, errno for failed */ -RTS_API rtError_t rtCtxSetDryRun(rtContext_t ctx, rtDryRunFlag_t enable, uint32_t flag); +RTS_API rtError_t rtGetGroupInfo(int32_t groupId, rtGroupInfo_t* groupInfo, uint32_t count); + +/** + * @ingroup + * @brief get group count + * @param [in] groupid count + * @return RT_ERROR_NONE for ok, errno for failed + */ +RTS_API rtError_t rtGetGroupCount(uint32_t *count); #ifdef __cplusplus } #endif -#endif // __CCE_RUNTIME_CONTEXT_H__ \ No newline at end of file + +#endif // __CCE_RUNTIME_CONTEXT_H__ diff --git a/third_party/fwkacllib/inc/runtime/dev.h b/third_party/fwkacllib/inc/runtime/dev.h index 928f2822..f79f060c 100644 --- a/third_party/fwkacllib/inc/runtime/dev.h +++ b/third_party/fwkacllib/inc/runtime/dev.h @@ -23,6 +23,9 @@ extern "C" { #endif +#define RT_CAPABILITY_SUPPORT (0x1) +#define RT_CAPABILITY_NOT_SUPPORT (0x0) + typedef struct tagRTDeviceInfo { uint8_t env_type; // 0: FPGA 1: EMU 2: ESL uint32_t ctrl_cpu_ip; @@ -32,6 +35,7 @@ typedef struct tagRTDeviceInfo { uint32_t ts_cpu_core_num; uint32_t ai_cpu_core_num; uint32_t ai_core_num; + uint32_t ai_core_freq; uint32_t ai_cpu_core_id; uint32_t ai_core_id; uint32_t aicpu_occupy_bitmap; @@ -46,6 +50,23 @@ typedef enum tagRtRunMode { RT_RUN_MODE_RESERVED } rtRunMode; +typedef enum tagRtAicpuDeployType { + AICPU_DEPLOY_CROSS_OS = 0x0, + AICPU_DEPLOY_CROSS_PROCESS = 0x1, + AICPU_DEPLOY_CROSS_THREAD = 0x2, + AICPU_DEPLOY_RESERVED +} rtAicpuDeployType_t; + +typedef enum tagRtFeatureType { + FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_RSV +} rtFeatureType_t; + +typedef enum tagMemcpyInfo { + MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, + MEMCPY_INFO_RSV +} rtMemcpyInfo_t; + /** * @ingroup dvrt_dev * @brief get total device number. @@ -62,15 +83,40 @@ RTS_API rtError_t rtGetDeviceCount(int32_t *count); * @return RT_ERROR_DRV_ERR for error */ RTS_API rtError_t rtGetDeviceIDs(uint32_t *devices, uint32_t len); + /** * @ingroup dvrt_dev - * @brief get total device infomation. + * @brief get device infomation. * @param [in] device the device id - * @param [out] info the device info + * @param [in] moduleType module type + typedef enum { + MODULE_TYPE_SYSTEM = 0, system info + MODULE_TYPE_AICPU, aicpu info + MODULE_TYPE_CCPU, ccpu_info + MODULE_TYPE_DCPU, dcpu info + MODULE_TYPE_AICORE, AI CORE info + MODULE_TYPE_TSCPU, tscpu info + MODULE_TYPE_PCIE, PCIE info + } DEV_MODULE_TYPE; + * @param [in] infoType info type + typedef enum { + INFO_TYPE_ENV = 0, + INFO_TYPE_VERSION, + INFO_TYPE_MASTERID, + INFO_TYPE_CORE_NUM, + INFO_TYPE_OS_SCHED, + INFO_TYPE_IN_USED, + INFO_TYPE_ERROR_MAP, + INFO_TYPE_OCCUPY, + INFO_TYPE_ID, + INFO_TYPE_IP, + INFO_TYPE_ENDIAN, + } DEV_INFO_TYPE; + * @param [out] value the device info * @return RT_ERROR_NONE for ok * @return RT_ERROR_NO_DEVICE for can not find any device */ -RTS_API rtError_t rtGetDeviceInfo(int32_t device, rtDeviceInfo_t *info); +RTS_API rtError_t rtGetDeviceInfo(uint32_t deviceId, int32_t moduleType, int32_t infoType, int64_t *value); /** * @ingroup dvrt_dev @@ -130,6 +176,25 @@ RTS_API rtError_t rtEnableP2P(uint32_t devIdDes, uint32_t phyIdSrc); */ RTS_API rtError_t rtDisableP2P(uint32_t devIdDes, uint32_t phyIdSrc); +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devIdDes the logical device id + * @param [in] phyIdSrc the physical device id + * @param [in|out] status status value + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_NO_DEVICE for can not find any device + */ +RTS_API rtError_t rtGetP2PStatus(uint32_t devIdDes, uint32_t phyIdSrc, uint32_t *status); + +/** + * @ingroup dvrt_dev + * @brief get value of current thread + * @param [in|out] pid value of pid + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtDeviceGetBareTgid(uint32_t *pid); + /** * @ingroup dvrt_dev * @brief get target device of current thread @@ -212,6 +277,15 @@ RTS_API rtError_t rtSetTSDevice(uint32_t tsId); */ RTS_API rtError_t rtGetRunMode(rtRunMode *mode); +/** + * @ingroup dvrt_dev + * @brief get aicpu deploy + * @param [out] aicpu deploy + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_DRV_ERR for can not get aicpu deploy + */ +RTS_API rtError_t rtGetAicpuDeploy(rtAicpuDeployType_t *deplyType); + /** * @ingroup dvrt_dev * @brief set chipType @@ -225,6 +299,35 @@ RTS_API rtError_t rtSetSocVersion(const char *version); * @return RT_ERROR_NONE for ok */ rtError_t rtGetSocVersion(char *version, const uint32_t maxLen); + +/** + * @ingroup dvrt_dev + * @brief get status + * @param [in] devId the logical device id + * @param [in] otherDevId the other logical device id + * @param [in] infoType info type + * @param [in|out] value pair info + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetPairDevicesInfo(uint32_t devId, uint32_t otherDevId, int32_t infoType, int64_t *value); + +/** + * @ingroup dvrt_dev + * @brief get capability infomation. + * @param [in] featureType feature type + typedef enum tagRtFeatureType { + FEATURE_TYPE_MEMCPY = 0, + FEATURE_TYPE_RSV, + } rtFeatureType_t; + * @param [in] infoType info type + typedef enum tagMemcpyInfo { + MEMCPY_INFO_SUPPORT_ZEROCOPY = 0, + MEMCPY_INFO _RSV, + } rtMemcpyInfo_t; + * @param [out] value the capability info + * @return RT_ERROR_NONE for ok + */ +RTS_API rtError_t rtGetRtCapability(rtFeatureType_t featureType, int32_t featureInfo, int64_t *value); #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/runtime/mem.h b/third_party/fwkacllib/inc/runtime/mem.h index 7c2a0728..e70ebd38 100644 --- a/third_party/fwkacllib/inc/runtime/mem.h +++ b/third_party/fwkacllib/inc/runtime/mem.h @@ -17,7 +17,9 @@ #ifndef __CCE_RUNTIME_MEM_H__ #define __CCE_RUNTIME_MEM_H__ +/*lint -e7*/ #include +/*lint +e7*/ #include "base.h" #include "config.h" #include "stream.h" @@ -222,17 +224,15 @@ RTS_API rtError_t rtMemAllocManaged(void **ptr, uint64_t size, uint32_t flag); * @return RT_ERROR_INVALID_DEVICE_POINTER for error device memory pointer */ RTS_API rtError_t rtMemFreeManaged(void *ptr); - /** * @ingroup dvrt_mem - * @brief advise memory - * @param [in] ptr memory pointer - * @param [in] size memory size - * @param [in] advise memory advise + * @brief alloc cached device memory + * @param [in| devPtr memory pointer + * @param [in] size memory size + * @param [in] type memory type * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_DEVICE_POINTER for error device memory pointer */ -RTS_API rtError_t rtMemAdvise(void *ptr, uint64_t size, uint32_t advise); +RTS_API rtError_t rtMallocCached(void **devPtr, uint64_t size, rtMemType_t type); /** * @ingroup dvrt_mem @@ -241,7 +241,7 @@ RTS_API rtError_t rtMemAdvise(void *ptr, uint64_t size, uint32_t advise); * @param [in] len memory size * @return RT_ERROR_NONE for ok, errno for failed */ -RTS_API rtError_t rtFlushCache(uint64_t base, uint32_t len); +RTS_API rtError_t rtFlushCache(void *base, size_t len); /** * @ingroup dvrt_mem @@ -250,7 +250,7 @@ RTS_API rtError_t rtFlushCache(uint64_t base, uint32_t len); * @param [in] len memory size * @return RT_ERROR_NONE for ok, errno for failed */ -RTS_API rtError_t rtInvalidCache(uint64_t base, uint32_t len); +RTS_API rtError_t rtInvalidCache(void *base, size_t len); /** * @ingroup dvrt_mem @@ -426,19 +426,6 @@ RTS_API rtError_t rtIpcCloseMemory(const void *ptr); */ RTS_API rtError_t rtRDMASend(uint32_t index, uint32_t wqe_index, rtStream_t stream); -/** - * @ingroup dvrt_mem - * @brief Set the memory readCount value - * @param [in] devPtr memory pointer - * @param [in] size memory size - * @param [in] readCount readCount value - * @return RT_ERROR_NONE for ok - * @return RT_ERROR_INVALID_VALUE for error input - * @return RT_ERROR_INVALID_RESOURCE_HANDLE for invalid resource handle - * @return RT_ERROR_DRV_ERR for driver error - */ -RTS_API rtError_t rtMemSetRC(const void *devPtr, uint64_t size, uint32_t readCount); - /** * @ingroup dvrt_mem * @brief Ipc set mem pid diff --git a/third_party/fwkacllib/inc/runtime/rt_model.h b/third_party/fwkacllib/inc/runtime/rt_model.h index 790492fc..5c85a3d7 100644 --- a/third_party/fwkacllib/inc/runtime/rt_model.h +++ b/third_party/fwkacllib/inc/runtime/rt_model.h @@ -65,6 +65,13 @@ typedef enum tagModelQueueFlag { #define EXECUTOR_TS ((uint32_t)0x01) #define EXECUTOR_AICPU ((uint32_t)0x02) +/* + * @ingroup rt_model + * @brief debug flag for kernel exception dump + */ +#define RT_DEBUG_FLAG_AICORE_OVERFLOW (0x1 << 0) +#define RT_DEBUG_FLAG_ATOMIC_ADD_OVERFLOW (0x1 << 1) + /** * @ingroup * @brief the type defination of aicpu model task command @@ -403,6 +410,26 @@ RTS_API rtError_t rtModelBindQueue(rtModel_t model, uint32_t queueId, rtModelQue */ RTS_API rtError_t rtModelGetId(rtModel_t model, uint32_t *modelId); +/* + * @ingroup rt_model + * @brief enable debug for dump overflow exception + * @param [in] addr: ddr address of kernel exception dumpped + * @param [in] model: model handle + * @param [in] flag: debug flag + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input handle + */ +rtError_t rtDebugRegister(rtModel_t model, uint32_t flag, const void *addr, uint32_t *streamId, uint32_t *taskId); + +/* + * @ingroup rt_model + * @brief disable debug for dump overflow exception + * @param [in] model: model handle + * @return RT_ERROR_NONE for ok + * @return RT_ERROR_INVALID_VALUE for error input handle + */ +RTS_API rtError_t rtDebugUnRegister(rtModel_t model); + #ifdef __cplusplus } #endif diff --git a/third_party/fwkacllib/inc/toolchain/slog.h b/third_party/fwkacllib/inc/toolchain/slog.h index f77df225..261fe866 100644 --- a/third_party/fwkacllib/inc/toolchain/slog.h +++ b/third_party/fwkacllib/inc/toolchain/slog.h @@ -91,6 +91,10 @@ extern "C" { * max log length */ #define MSG_LENGTH 1024 +#define DEBUG_LOG_MASK (0x00010000) +#define SECURITY_LOG_MASK (0x00100000) +#define RUN_LOG_MASK (0x01000000) +#define OPERATION_LOG_MASK (0x10000000) typedef struct tagDCODE { const char *cName; @@ -169,83 +173,11 @@ enum { PROCMGR, // Process Manager, Base Platform BBOX, AIVECTOR, + TBE, + FV, INVLID_MOUDLE_ID }; -#ifdef MODULE_ID_NAME - -/** - * @ingroup slog - * - * set module id to map - */ -#define SET_MOUDLE_ID_MAP_NAME(x) \ - { #x, x } - -static DCODE g_moduleIdName[] = {SET_MOUDLE_ID_MAP_NAME(SLOG), - SET_MOUDLE_ID_MAP_NAME(IDEDD), - SET_MOUDLE_ID_MAP_NAME(IDEDH), - SET_MOUDLE_ID_MAP_NAME(HCCL), - SET_MOUDLE_ID_MAP_NAME(FMK), - SET_MOUDLE_ID_MAP_NAME(HIAIENGINE), - SET_MOUDLE_ID_MAP_NAME(DVPP), - SET_MOUDLE_ID_MAP_NAME(RUNTIME), - SET_MOUDLE_ID_MAP_NAME(CCE), -#if (OS_TYPE == LINUX) - SET_MOUDLE_ID_MAP_NAME(HDC), -#else - SET_MOUDLE_ID_MAP_NAME(HDCL), -#endif // OS_TYPE - SET_MOUDLE_ID_MAP_NAME(DRV), - SET_MOUDLE_ID_MAP_NAME(MDCFUSION), - SET_MOUDLE_ID_MAP_NAME(MDCLOCATION), - SET_MOUDLE_ID_MAP_NAME(MDCPERCEPTION), - SET_MOUDLE_ID_MAP_NAME(MDCFSM), - SET_MOUDLE_ID_MAP_NAME(MDCCOMMON), - SET_MOUDLE_ID_MAP_NAME(MDCMONITOR), - SET_MOUDLE_ID_MAP_NAME(MDCBSWP), - SET_MOUDLE_ID_MAP_NAME(MDCDEFAULT), - SET_MOUDLE_ID_MAP_NAME(MDCSC), - SET_MOUDLE_ID_MAP_NAME(MDCPNC), - SET_MOUDLE_ID_MAP_NAME(MLL), - SET_MOUDLE_ID_MAP_NAME(DEVMM), - SET_MOUDLE_ID_MAP_NAME(KERNEL), - SET_MOUDLE_ID_MAP_NAME(LIBMEDIA), - SET_MOUDLE_ID_MAP_NAME(CCECPU), - SET_MOUDLE_ID_MAP_NAME(ASCENDDK), - SET_MOUDLE_ID_MAP_NAME(ROS), - SET_MOUDLE_ID_MAP_NAME(HCCP), - SET_MOUDLE_ID_MAP_NAME(ROCE), - SET_MOUDLE_ID_MAP_NAME(TEFUSION), - SET_MOUDLE_ID_MAP_NAME(PROFILING), - SET_MOUDLE_ID_MAP_NAME(DP), - SET_MOUDLE_ID_MAP_NAME(APP), - SET_MOUDLE_ID_MAP_NAME(TS), - SET_MOUDLE_ID_MAP_NAME(TSDUMP), - SET_MOUDLE_ID_MAP_NAME(AICPU), - SET_MOUDLE_ID_MAP_NAME(LP), - SET_MOUDLE_ID_MAP_NAME(TDT), - SET_MOUDLE_ID_MAP_NAME(FE), - SET_MOUDLE_ID_MAP_NAME(MD), - SET_MOUDLE_ID_MAP_NAME(MB), - SET_MOUDLE_ID_MAP_NAME(ME), - SET_MOUDLE_ID_MAP_NAME(IMU), - SET_MOUDLE_ID_MAP_NAME(IMP), - SET_MOUDLE_ID_MAP_NAME(GE), - SET_MOUDLE_ID_MAP_NAME(MDCFUSA), - SET_MOUDLE_ID_MAP_NAME(CAMERA), - SET_MOUDLE_ID_MAP_NAME(ASCENDCL), - SET_MOUDLE_ID_MAP_NAME(TEEOS), - SET_MOUDLE_ID_MAP_NAME(ISP), - SET_MOUDLE_ID_MAP_NAME(SIS), - SET_MOUDLE_ID_MAP_NAME(HSM), - SET_MOUDLE_ID_MAP_NAME(DSS), - SET_MOUDLE_ID_MAP_NAME(PROCMGR), - SET_MOUDLE_ID_MAP_NAME(BBOX), - SET_MOUDLE_ID_MAP_NAME(AIVECTOR), - { NULL, -1 }}; -#endif // MODULE_ID_NAME - #if (OS_TYPE == LINUX) /** * @ingroup slog @@ -386,6 +318,11 @@ extern int CheckLogLevel(int moduleId, int logLevel); DlogWithKVInner(moduleId, level, pstKVArray, kvNum, "[%s:%d]" fmt, __FILE__, __LINE__, ##__VA_ARGS__); \ } while (0) +/** + * @ingroup slog + * @brief DlogFlush: flush log buffer to file + */ +void DlogFlush(void); /** * @ingroup slog diff --git a/third_party/prebuild/aarch64/liberror_manager.so b/third_party/prebuild/aarch64/liberror_manager.so new file mode 100755 index 00000000..759d8e30 Binary files /dev/null and b/third_party/prebuild/aarch64/liberror_manager.so differ diff --git a/third_party/prebuild/aarch64/libslog.so b/third_party/prebuild/aarch64/libslog.so new file mode 100755 index 00000000..700fc118 Binary files /dev/null and b/third_party/prebuild/aarch64/libslog.so differ diff --git a/third_party/prebuild/x86_64/liberror_manager.so b/third_party/prebuild/x86_64/liberror_manager.so new file mode 100755 index 00000000..cd9ad8bc Binary files /dev/null and b/third_party/prebuild/x86_64/liberror_manager.so differ diff --git a/third_party/prebuild/x86_64/libslog.so b/third_party/prebuild/x86_64/libslog.so index b476618d..01b75e40 100755 Binary files a/third_party/prebuild/x86_64/libslog.so and b/third_party/prebuild/x86_64/libslog.so differ