| @@ -62,6 +62,9 @@ class GraphOptimizer { | |||||
| // optimize streamed Graph | // optimize streamed Graph | ||||
| virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; } | virtual Status OptimizeStreamGraph(ComputeGraph &graph, const RunContext &context) { return SUCCESS; } | ||||
| // op compile | |||||
| virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; } | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| /*lint +e148*/ | /*lint +e148*/ | ||||
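A minimal sketch of how a back-end optimizer could use the new hook; the class name is hypothetical and the remaining GraphOptimizer virtuals are assumed to be implemented elsewhere. The base-class default simply returns SUCCESS, so existing optimizers are unaffected.

namespace ge {
class SampleFusionOptimizer : public GraphOptimizer {
 public:
  // Called once the fused graph has been sliced; op compilation for the sliced graph goes here.
  Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) override {
    // Illustrative only: a real optimizer would walk the graph and trigger op compilation.
    return SUCCESS;
  }
  // ... other GraphOptimizer overrides omitted ...
};
}  // namespace ge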
| @@ -35,5 +35,7 @@ static const std::string ATTR_NAME_L2_FUSION_EXTEND_PTR = "l2_fusion_extend_cont | |||||
| static const std::string L1_OPTIMIZED = "l1_optimized"; | static const std::string L1_OPTIMIZED = "l1_optimized"; | ||||
| static const std::string L2_OPTIMIZED = "l2_optimized"; | static const std::string L2_OPTIMIZED = "l2_optimized"; | ||||
| static const std::string OP_SLICE_INFO = "_op_slice_info"; | |||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -34,6 +34,7 @@ class ScopeAllocator { | |||||
| bool HasScopeAttr(ge::ConstOpDescPtr opdef); | bool HasScopeAttr(ge::ConstOpDescPtr opdef); | ||||
| bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId); | bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId); | ||||
| bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId); | bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId); | ||||
| bool ResetScopeId(int64_t scopeId); | |||||
| private: | private: | ||||
| int64_t scopeId; | int64_t scopeId; | ||||
| @@ -40,6 +40,7 @@ const char *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode"; | |||||
| const char *const OPTION_EXEC_RANK_TABLE_FILE = "ge.exec.rankTableFile"; | const char *const OPTION_EXEC_RANK_TABLE_FILE = "ge.exec.rankTableFile"; | ||||
| const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; | const char *const GE_AICPU_FLAG = "ge.aicpuFlag"; | ||||
| const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | ||||
| // Dump flags and parameters | |||||

| const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | ||||
| const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | ||||
| const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | ||||
| @@ -48,7 +49,10 @@ const char *const OPTION_EXEC_ENABLE_DUMP_DEBUG = "ge.exec.enableDumpDebug"; | |||||
| const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; | const char *const OPTION_EXEC_DUMP_DEBUG_MODE = "ge.exec.dumpDebugMode"; | ||||
| const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | ||||
| const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | ||||
| const char *const OPTION_EXEC_ENABLE_EXCEPTION_DUMP = "ge.exec.enable_exception_dump"; | |||||
| const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; | const char *const OPTION_EXEC_ENABLE_SCOPE_FUSION_PASSES = "ge.exec.enableScopeFusionPasses"; | ||||
| const char *const OPTION_EXEC_PROFILING_FPPONIT_OPTIONS = "ge.exec.profilingFpPointOptions"; | |||||
| const char *const OPTION_EXEC_PROFILING_BPPONIT_OPTIONS = "ge.exec.profilingBpPointOptions"; | |||||
| // profiling flag | // profiling flag | ||||
| const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | const char *const OPTION_EXEC_PROFILING_MODE = "ge.exec.profilingMode"; | ||||
| const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | ||||
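The new keys are ordinary string options, so they can be passed through the usual options map. A small sketch follows; the values and their accepted formats are assumptions.

std::map<std::string, std::string> options;
options[ge::OPTION_EXEC_ENABLE_DUMP] = "1";
options[ge::OPTION_EXEC_ENABLE_EXCEPTION_DUMP] = "1";             // new: dump on op exception (value format assumed)
options[ge::OPTION_EXEC_PROFILING_FPPONIT_OPTIONS] = "fp_point";  // new: value format assumed
options[ge::OPTION_EXEC_PROFILING_BPPONIT_OPTIONS] = "bp_point";  // new: value format assumed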
| @@ -223,6 +223,7 @@ class OpReg { | |||||
| \ | \ | ||||
| private: \ | private: \ | ||||
| void __dy_input_##x() { \ | void __dy_input_##x() { \ | ||||
| Operator::DynamicInputRegister(#x, 0, true); \ | |||||
| (void)OpReg() | (void)OpReg() | ||||
| #define DYNAMIC_OUTPUT(x, t) \ | #define DYNAMIC_OUTPUT(x, t) \ | ||||
| @@ -242,6 +243,7 @@ class OpReg { | |||||
| \ | \ | ||||
| private: \ | private: \ | ||||
| void __dy_output_##x() { \ | void __dy_output_##x() { \ | ||||
| Operator::DynamicOutputRegister(#x, 0, true); \ | |||||
| (void)OpReg() | (void)OpReg() | ||||
| #define GRAPH(x) \ | #define GRAPH(x) \ | ||||
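With the DynamicInputRegister/DynamicOutputRegister calls added to the macro bodies, a dynamic port is registered (initial count 0) at the moment the op declares it. A hedged sketch of an op definition exercising both macros; the op name and tensor types are illustrative.

REG_OP(ConcatSample)
    .DYNAMIC_INPUT(x, TensorType::ALL())   // registers dynamic input "x" with count 0
    .DYNAMIC_OUTPUT(y, TensorType::ALL())  // registers dynamic output "y" with count 0
    .OP_END_FACTORY_REG(ConcatSample)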
| @@ -55,6 +55,28 @@ class Message; | |||||
| } // namespace google | } // namespace google | ||||
| namespace domi { | namespace domi { | ||||
| const int64_t kMaxNameLength = 1048576; // 1M | |||||
| enum DynamicType { kInvalid = 0, kInput = 1, kOutput = 2 }; | |||||
| struct DynamicInputOutputInfo { | |||||
| DynamicType type; // input/output | |||||
| const char *port_name; | |||||
| int64_t port_name_len; | |||||
| const char *attr_name; | |||||
| int64_t attr_name_len; | |||||
| DynamicInputOutputInfo() | |||||
| : type(kInvalid), port_name(nullptr), port_name_len(0), attr_name(nullptr), attr_name_len(0) {} | |||||
| DynamicInputOutputInfo(DynamicType type, const char *port_name, int64_t port_name_len, const char *attr_name, | |||||
| int64_t attr_name_len) | |||||
| : type(type), | |||||
| port_name(port_name), | |||||
| port_name_len(port_name_len), | |||||
| attr_name(attr_name), | |||||
| attr_name_len(attr_name_len) {} | |||||
| }; | |||||
| Status AutoMappingByOpFn(const ge::Operator &op_src, ge::Operator &op); | |||||
| Status AutoMappingByOpFnDynamic(const ge::Operator &op_src, ge::Operator &op, | |||||
| const vector<DynamicInputOutputInfo> &dynamic_name_attr_value); | |||||
| Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | Status AutoMappingFn(const google::protobuf::Message *op_src, ge::Operator &op); | ||||
| Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | Status AutoMappingFnDynamic(const google::protobuf::Message *op_src, ge::Operator &op, | ||||
| std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | std::map<std::string, std::pair<std::string, std::string>> dynamic_name_attr_value, | ||||
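A sketch of how the operator-based auto-mapping entry points might be used from a parse callback; the op, port, and attribute names are assumptions.

domi::Status ParsePackSample(const ge::Operator &op_src, ge::Operator &op) {
  // Dynamic input "x" takes its count from attribute "N" of the original operator.
  std::vector<domi::DynamicInputOutputInfo> dynamic_info;
  dynamic_info.emplace_back(domi::kInput, "x", 1, "N", 1);  // port/attr names and their lengths
  return domi::AutoMappingByOpFnDynamic(op_src, op, dynamic_info);
}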
| @@ -71,6 +93,7 @@ using ParseParamFunc = std::function<domi::Status(const google::protobuf::Messag | |||||
| using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>; | using ParseParamByOpFunc = std::function<domi::Status(const ge::Operator &, ge::Operator &)>; | ||||
| using FusionParseParamFunc = | using FusionParseParamFunc = | ||||
| std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | std::function<domi::Status(const std::vector<const google::protobuf::Message *>, ge::Operator &)>; | ||||
| using FusionParseParamByOpFunc = std::function<domi::Status(const std::vector<ge::Operator> &, ge::Operator &)>; | |||||
| using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | using ParseSubgraphFunc = std::function<Status(const std::string &subgraph_name, const ge::Graph &graph)>; | ||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | ||||
| @@ -91,6 +114,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | OpRegistrationData &FusionParseParamsFn(const FusionParseParamFunc &fusionParseParamFn); | ||||
| OpRegistrationData &FusionParseParamsFn(const FusionParseParamByOpFunc &fusion_parse_param_fn); | |||||
| OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); | OpRegistrationData &ParseSubgraphPostFn(const ParseSubgraphFunc &subgraph_post_fn); | ||||
| OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | OpRegistrationData &ImplyType(const domi::ImplyType &imply_type); | ||||
| @@ -108,6 +133,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpRegistrationData { | |||||
| ParseParamFunc GetParseParamFn() const; | ParseParamFunc GetParseParamFn() const; | ||||
| ParseParamByOpFunc GetParseParamByOperatorFn() const; | ParseParamByOpFunc GetParseParamByOperatorFn() const; | ||||
| FusionParseParamFunc GetFusionParseParamFn() const; | FusionParseParamFunc GetFusionParseParamFn() const; | ||||
| FusionParseParamByOpFunc GetFusionParseParamByOpFn() const; | |||||
| ParseSubgraphFunc GetParseSubgraphPostFn() const; | ParseSubgraphFunc GetParseSubgraphPostFn() const; | ||||
| private: | private: | ||||
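A hedged sketch of registering the new operator-based fusion parser; the op names are hypothetical, and the overload is selected by the callback's signature.

domi::Status ParseSampleFusion(const std::vector<ge::Operator> &src_ops, ge::Operator &dst_op) {
  // Read attributes from the original fused nodes and set them on dst_op here.
  return domi::SUCCESS;
}

REGISTER_CUSTOM_OP("SampleFusionOp")
    .FrameworkType(domi::TENSORFLOW)
    .OriginOpType("SampleOrigin")
    .FusionParseParamsFn(ParseSampleFusion);  // resolves to the FusionParseParamByOpFunc overload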
| @@ -21,6 +21,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <map> | #include <map> | ||||
| #include <unordered_map> | |||||
| #include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
| #include "register/register_error_codes.h" | #include "register/register_error_codes.h" | ||||
| #include "register/register_types.h" | #include "register/register_types.h" | ||||
| @@ -52,15 +53,16 @@ class ScopePassManager; | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { | class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { | ||||
| public: | public: | ||||
| explicit Scope(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); | |||||
| Scope(); | |||||
| Status Init(const std::string &name, const std::string &sub_type = "", Scope *father_scope = nullptr); | |||||
| ~Scope(); | ~Scope(); | ||||
| std::string Name() const; | |||||
| std::string SubType() const; | |||||
| std::map<std::string, ge::OperatorPtr> AllNodesMap() const; | |||||
| const std::string &Name() const; | |||||
| const std::string &SubType() const; | |||||
| const std::unordered_map<std::string, ge::OperatorPtr> &AllNodesMap() const; | |||||
| Scope *GetSubScope(const std::string &scope_name) const; | Scope *GetSubScope(const std::string &scope_name) const; | ||||
| std::string LastName() const; | |||||
| std::vector<Scope *> GetAllSubScopes() const; | |||||
| const std::string LastName() const; | |||||
| const std::vector<Scope *> &GetAllSubScopes() const; | |||||
| const Scope *GetFatherScope() const; | const Scope *GetFatherScope() const; | ||||
| private: | private: | ||||
| @@ -76,12 +78,13 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY Scope { | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { | class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY FusionScopesResult { | ||||
| public: | public: | ||||
| FusionScopesResult(); | FusionScopesResult(); | ||||
| Status Init(); | |||||
| ~FusionScopesResult(); | ~FusionScopesResult(); | ||||
| void SetName(const std::string &name); | void SetName(const std::string &name); | ||||
| void SetType(const std::string &type); | void SetType(const std::string &type); | ||||
| void SetDescription(const std::string &description); | void SetDescription(const std::string &description); | ||||
| std::string Name() const; | |||||
| std::vector<ge::OperatorPtr> Nodes() const; | |||||
| const std::string &Name() const; | |||||
| const std::vector<ge::OperatorPtr> &Nodes() const; | |||||
| void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map); | void InsertInputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map); | ||||
| void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map); | void InsertOutputs(const std::string &inner_op_name, const std::vector<int32_t> &index_map); | ||||
| @@ -136,7 +139,7 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeTree { | |||||
| ScopeTree &operator=(const ScopeTree &scopetree) = delete; | ScopeTree &operator=(const ScopeTree &scopetree) = delete; | ||||
| ~ScopeTree(); | ~ScopeTree(); | ||||
| std::vector<Scope *> GetAllScopes() const; | |||||
| const std::vector<Scope *> &GetAllScopes() const; | |||||
| private: | private: | ||||
| class ScopeTreeImpl; | class ScopeTreeImpl; | ||||
| @@ -154,7 +157,7 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeGraph { | |||||
| ~ScopeGraph(); | ~ScopeGraph(); | ||||
| const ScopeTree *GetScopeTree() const; | const ScopeTree *GetScopeTree() const; | ||||
| std::map<std::string, ge::OperatorPtr> GetNodesMap() const; | |||||
| const std::unordered_map<std::string, ge::OperatorPtr> &GetNodesMap() const; | |||||
| private: | private: | ||||
| class ScopeGraphImpl; | class ScopeGraphImpl; | ||||
| @@ -203,7 +206,7 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeOpTypeFeature : ScopeBa | |||||
| class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { | class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY NodeAttrFeature : ScopeBaseFeature { | ||||
| public: | public: | ||||
| NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue attr_value); | |||||
| NodeAttrFeature(std::string nodeType, std::string attr_name, ge::DataType datatype, ScopeAttrValue &attr_value); | |||||
| NodeAttrFeature(NodeAttrFeature const &feature); | NodeAttrFeature(NodeAttrFeature const &feature); | ||||
| NodeAttrFeature &operator=(NodeAttrFeature const &feature); | NodeAttrFeature &operator=(NodeAttrFeature const &feature); | ||||
| ~NodeAttrFeature(); | ~NodeAttrFeature(); | ||||
| @@ -258,16 +258,19 @@ struct ComputeGraphDescInfo { | |||||
| struct OpDescInfo { | struct OpDescInfo { | ||||
| std::string op_name; | std::string op_name; | ||||
| std::string op_type; | |||||
| uint32_t task_id; | uint32_t task_id; | ||||
| uint32_t stream_id; | uint32_t stream_id; | ||||
| std::vector<Format> input_format; | std::vector<Format> input_format; | ||||
| std::vector<std::vector<int64_t>> input_shape; | std::vector<std::vector<int64_t>> input_shape; | ||||
| std::vector<DataType> input_data_type; | std::vector<DataType> input_data_type; | ||||
| std::vector<void *> input_addrs; | std::vector<void *> input_addrs; | ||||
| std::vector<int64_t> input_size; | |||||
| std::vector<Format> output_format; | std::vector<Format> output_format; | ||||
| std::vector<std::vector<int64_t>> output_shape; | std::vector<std::vector<int64_t>> output_shape; | ||||
| std::vector<DataType> output_data_type; | std::vector<DataType> output_data_type; | ||||
| std::vector<void *> output_addrs; | std::vector<void *> output_addrs; | ||||
| std::vector<int64_t> output_size; | |||||
| }; | }; | ||||
| struct ModelDumpConfig { | struct ModelDumpConfig { | ||||
| std::string model_name; | std::string model_name; | ||||
| @@ -64,6 +64,7 @@ class ModelHelper { | |||||
| Status LoadWeights(OmFileLoadHelper& om_load_helper); | Status LoadWeights(OmFileLoadHelper& om_load_helper); | ||||
| Status LoadTask(OmFileLoadHelper& om_load_helper); | Status LoadTask(OmFileLoadHelper& om_load_helper); | ||||
| Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | Status LoadTBEKernelStore(OmFileLoadHelper& om_load_helper); | ||||
| Status LoadCustAICPUKernelStore(OmFileLoadHelper& om_load_helper); | |||||
| Status ReleaseLocalModelData() noexcept; | Status ReleaseLocalModelData() noexcept; | ||||
| Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | Status SaveModelPartition(std::shared_ptr<OmFileSaveHelper>& om_file_save_helper, ModelPartitionType type, | ||||
| const uint8_t* data, size_t size); | const uint8_t* data, size_t size); | ||||
| @@ -851,9 +851,9 @@ static constexpr int32_t PARTITION_TYPE_WEIGHTS = 1; | |||||
| static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; | static constexpr int32_t PARTITION_TYPE_TASK_INFO = 2; | ||||
| // number of partitions in the current model | // number of partitions in the current model | ||||
| static constexpr uint32_t PARTITION_SIZE = 4; | |||||
| static constexpr uint32_t PARTITION_SIZE = 5; | |||||
| enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS }; | |||||
| enum ModelPartitionType { MODEL_DEF = 0, WEIGHTS_DATA, TASK_INFO, TBE_KERNELS, CUST_AICPU_KERNELS }; | |||||
| struct ModelPartitionMemInfo { | struct ModelPartitionMemInfo { | ||||
| ModelPartitionType type; | ModelPartitionType type; | ||||
| @@ -108,11 +108,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeExecutor { | |||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Get current dynamic dims info by combined dims | /// @brief Get current dynamic dims info by combined dims | ||||
| /// @param [in] model_id: model id allocate from manager | /// @param [in] model_id: model id allocate from manager | ||||
| /// @param [in] combined_dims: array of combined dimensions | |||||
| /// @param [in] dynamic_dims: dynamic dims value of the current gear | |||||
| /// @param [out] cur_dynamic_dims: current dynamic dims | /// @param [out] cur_dynamic_dims: current dynamic dims | ||||
| /// @return execute result | /// @return execute result | ||||
| /// | /// | ||||
| ge::Status GetCurDynamicDims(uint32_t model_id, const std::vector<uint64_t> &combined_dims, | |||||
| ge::Status GetCurDynamicDims(uint32_t model_id, const std::vector<uint64_t> &dynamic_dims, | |||||
| std::vector<uint64_t> &cur_dynamic_dims); | std::vector<uint64_t> &cur_dynamic_dims); | ||||
| /// | /// | ||||
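An illustrative call with the renamed parameter; the model id and dimension values are placeholders.

ge::GeExecutor executor;
uint32_t model_id = 1;                               // id returned when the model was loaded (placeholder)
std::vector<uint64_t> dynamic_dims = {8, 224, 224};  // dims of the selected gear
std::vector<uint64_t> cur_dynamic_dims;
ge::Status ret = executor.GetCurDynamicDims(model_id, dynamic_dims, cur_dynamic_dims);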
| @@ -28,6 +28,7 @@ | |||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| #include "graph/detail/attributes_holder.h" | #include "graph/detail/attributes_holder.h" | ||||
| #include "omg/omg_inner_types.h" | |||||
| namespace ge { | namespace ge { | ||||
| class GeGenerator { | class GeGenerator { | ||||
| @@ -45,6 +46,7 @@ class GeGenerator { | |||||
| GeGenerator &operator=(const GeGenerator &) = delete; | GeGenerator &operator=(const GeGenerator &) = delete; | ||||
| Status Initialize(const std::map<std::string, std::string> &options); | Status Initialize(const std::map<std::string, std::string> &options); | ||||
| Status Initialize(const std::map<std::string, std::string> &options, OmgContext &context); | |||||
| Status Finalize(); | Status Finalize(); | ||||
| @@ -98,24 +98,14 @@ struct OmgContext { | |||||
| std::vector<std::string> out_top_names; | std::vector<std::string> out_top_names; | ||||
| // path for the aicpu custom operator so_file | // path for the aicpu custom operator so_file | ||||
| std::vector<std::string> aicpu_op_run_paths; | std::vector<std::string> aicpu_op_run_paths; | ||||
| // ddk version | |||||
| std::string ddk_version; | |||||
| // preferential format used by the entire network | // preferential format used by the entire network | ||||
| domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
| domi::FrameworkType type = domi::FRAMEWORK_RESERVED; | domi::FrameworkType type = domi::FRAMEWORK_RESERVED; | ||||
| RunMode run_mode = ONLY_PRE_CHECK; | RunMode run_mode = ONLY_PRE_CHECK; | ||||
| bool train_flag = false; | bool train_flag = false; | ||||
| // whether to use FP16 high precision | |||||
| int32_t fp16_high_precision = HIGH_PRECISION_DEFAULT; | |||||
| std::string output_type; | std::string output_type; | ||||
| // Save the name of the entire network: Some special operators are used to determine a network. Some operators in the | |||||
| // network require special processing based on the specific network. e.g:faster-rcnn, the FirstStageProcessor module | |||||
| // is determined as the Faster-R-CNN network based on the scope fusion. Then, the conv+reshape operators in the | |||||
| // FirstStageBoxPredictor/BoxEncodingPredictor scope are combined. The convolution kernel rearrangement reshape | |||||
| // operator needs to be deleted for the convolution kernel. | |||||
| std::string net_name; | |||||
| // Whether to use dynamic batch size or dynamic image size | // Whether to use dynamic batch size or dynamic image size | ||||
| bool is_dynamic_input = false; | bool is_dynamic_input = false; | ||||
| std::string dynamic_batch_size; | std::string dynamic_batch_size; | ||||
| @@ -93,6 +93,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| NodePtr AddNodeFront(const OpDescPtr &op); | NodePtr AddNodeFront(const OpDescPtr &op); | ||||
| NodePtr AddInputNode(NodePtr node); | NodePtr AddInputNode(NodePtr node); | ||||
| NodePtr AddOutputNode(NodePtr node); | NodePtr AddOutputNode(NodePtr node); | ||||
| NodePtr AddOutputNodeByIndex(NodePtr node, int32_t index); | |||||
| // insert node with specific pre_node | // insert node with specific pre_node | ||||
| NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); | NodePtr AddNodeAfter(OpDescPtr &op, const NodePtr &pre_node); | ||||
| NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); | NodePtr AddNodeAfter(NodePtr node, const NodePtr &pre_node); | ||||
| @@ -138,6 +139,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| graphStatus TopologicalSorting(); | graphStatus TopologicalSorting(); | ||||
| bool IsValid() const; | bool IsValid() const; | ||||
| void InValid() { is_valid_flag_ = false; } | |||||
| void Dump() const; | void Dump() const; | ||||
| void Swap(ComputeGraph &graph); | void Swap(ComputeGraph &graph); | ||||
| @@ -268,6 +270,7 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| friend class ModelSerializeImp; | friend class ModelSerializeImp; | ||||
| friend class GraphDebugImp; | friend class GraphDebugImp; | ||||
| friend class OnnxUtils; | friend class OnnxUtils; | ||||
| friend class TuningUtils; | |||||
| std::string name_; | std::string name_; | ||||
| uint32_t graph_id_ = 0; | uint32_t graph_id_ = 0; | ||||
| @@ -1031,6 +1031,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_FLAG; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_ADDR; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ENGINE_NAME_FOR_LX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_NEED_LX_FUSION; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OPTIMIZE_GROUP; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OP_COMPILE_STRATEGY; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; | |||||
| // for unregistered op | // for unregistered op | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; | ||||
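Since these are plain op attributes, they are read and written through AttrUtils; the value types used below are assumptions inferred from the attribute names, and the helper is hypothetical.

void MarkForLxFusion(const ge::NodePtr &node) {  // hypothetical helper
  ge::OpDescPtr op_desc = node->GetOpDesc();
  (void)ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_NEED_LX_FUSION, true);              // value type assumed
  (void)ge::AttrUtils::SetStr(op_desc, ge::ATTR_NAME_ENGINE_NAME_FOR_LX, "AIcoreEngine");  // engine name assumed
  std::string strategy;
  if (ge::AttrUtils::GetStr(op_desc, ge::ATTR_NAME_OP_COMPILE_STRATEGY, strategy)) {
    // honour the per-op compile strategy chosen by the fusion pass
  }
}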
| @@ -174,6 +174,9 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
| fusion_output_dataflow_list_ = fusion_output_list; | fusion_output_dataflow_list_ = fusion_output_list; | ||||
| } | } | ||||
| bool GetHostNode() const { return host_node_; } | |||||
| void SetHostNode(bool is_host) { host_node_ = is_host; } | |||||
| void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | void SetOrigNode(const NodePtr &orignode) { orig_node_ = orignode; } | ||||
| NodePtr GetOrigNode() { return orig_node_; } | NodePtr GetOrigNode() { return orig_node_; } | ||||
| @@ -192,6 +195,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
| OutControlAnchorPtr out_control_anchor_; | OutControlAnchorPtr out_control_anchor_; | ||||
| map<string, GeAttrValue> attrs_; // lint !e1073 | map<string, GeAttrValue> attrs_; // lint !e1073 | ||||
| bool has_init_{false}; | bool has_init_{false}; | ||||
| bool host_node_{false}; | |||||
| bool anchor_status_updated_{false}; | bool anchor_status_updated_{false}; | ||||
| std::vector<uint32_t> send_event_id_list_; | std::vector<uint32_t> send_event_id_list_; | ||||
| std::vector<uint32_t> recv_event_id_list_; | std::vector<uint32_t> recv_event_id_list_; | ||||
| @@ -202,6 +206,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
| NodePtr orig_node_; | NodePtr orig_node_; | ||||
| friend class NodeUtils; | friend class NodeUtils; | ||||
| friend class OnnxUtils; | friend class OnnxUtils; | ||||
| friend class TuningUtils; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define INC_GRAPH_OP_DESC_H_ | #define INC_GRAPH_OP_DESC_H_ | ||||
| #include <functional> | #include <functional> | ||||
| #include <algorithm> | |||||
| #include <map> | #include <map> | ||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| @@ -87,6 +88,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); | graphStatus AddInputDescMiddle(const string &name, const unsigned int num, size_t index); | ||||
| graphStatus AddOutputDescMiddle(const string &name, const unsigned int num, size_t index); | |||||
| graphStatus AddOutputDescForward(const string &name, const unsigned int num); | graphStatus AddOutputDescForward(const string &name, const unsigned int num); | ||||
| graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | graphStatus AddOptionalInputDesc(const string &name, const GeTensorDesc &input_desc); | ||||
| @@ -187,6 +190,14 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| graphStatus CommonVerify() const; | graphStatus CommonVerify() const; | ||||
| graphStatus AddRegisterInputName(const string &name); | |||||
| graphStatus AddRegisterOutputName(const string &name); | |||||
| vector<string> GetRegisterInputName() const; | |||||
| vector<string> GetRegisterOutputName() const; | |||||
| using AttrHolder::AddRequiredAttr; | using AttrHolder::AddRequiredAttr; | ||||
| using AttrHolder::DelAttr; | using AttrHolder::DelAttr; | ||||
| using AttrHolder::GetAllAttrNames; | using AttrHolder::GetAllAttrNames; | ||||
| @@ -297,9 +308,11 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
| vector<GeTensorDescPtr> inputs_desc_{}; | vector<GeTensorDescPtr> inputs_desc_{}; | ||||
| map<string, uint32_t> input_name_idx_{}; | map<string, uint32_t> input_name_idx_{}; | ||||
| vector<string> register_input_name_{}; | |||||
| std::unordered_set<string> optional_input_names_{}; | std::unordered_set<string> optional_input_names_{}; | ||||
| vector<GeTensorDescPtr> outputs_desc_{}; | vector<GeTensorDescPtr> outputs_desc_{}; | ||||
| map<string, uint32_t> output_name_idx_{}; | map<string, uint32_t> output_name_idx_{}; | ||||
| vector<string> register_output_name_{}; | |||||
| std::function<graphStatus(Operator &)> infer_func_ = nullptr; | std::function<graphStatus(Operator &)> infer_func_ = nullptr; | ||||
| std::function<graphStatus(Operator &)> infer_format_func_ = nullptr; | std::function<graphStatus(Operator &)> infer_format_func_ = nullptr; | ||||
| std::function<graphStatus(Operator &)> verifier_func_ = nullptr; | std::function<graphStatus(Operator &)> verifier_func_ = nullptr; | ||||
| @@ -42,6 +42,7 @@ class OpKernelBin { | |||||
| using OpKernelBinPtr = std::shared_ptr<OpKernelBin>; | using OpKernelBinPtr = std::shared_ptr<OpKernelBin>; | ||||
| const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; | const char *const OP_EXTATTR_NAME_TBE_KERNEL = "tbeKernel"; | ||||
| const char *const OP_EXTATTR_CUSTAICPU_KERNEL = "cust_aicpu_kernel"; | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_GRAPH_OP_KERNEL_BIN_H_ | #endif // INC_GRAPH_OP_KERNEL_BIN_H_ | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <mutex> | |||||
| namespace ge { | namespace ge { | ||||
| class OpsProtoManager { | class OpsProtoManager { | ||||
| @@ -30,14 +31,15 @@ class OpsProtoManager { | |||||
| static OpsProtoManager *Instance(); | static OpsProtoManager *Instance(); | ||||
| bool Initialize(const std::map<std::string, std::string> &options); | bool Initialize(const std::map<std::string, std::string> &options); | ||||
| void Finalize(); | void Finalize(); | ||||
| private: | |||||
| void LoadOpsProtoPluginSo(std::string &path); | void LoadOpsProtoPluginSo(std::string &path); | ||||
| private: | |||||
| std::string pluginPath_; | std::string pluginPath_; | ||||
| std::vector<void *> handles_; | std::vector<void *> handles_; | ||||
| bool is_init_ = false; | |||||
| std::mutex mutex_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -0,0 +1,130 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_TUNING_UTILS_H | |||||
| #define MAIN_TUNING_UTILS_H | |||||
| #include <fcntl.h> | |||||
| #include <sys/stat.h> | |||||
| #include <sys/types.h> | |||||
| #include <unistd.h> | |||||
| #include <algorithm> | |||||
| #include <cstring> | |||||
| #include <fstream> | |||||
| #include <iomanip> | |||||
| #include <queue> | |||||
| #include <mutex> | |||||
| #include <graph/anchor.h> | |||||
| #include <graph/detail/attributes_holder.h> | |||||
| #include <graph/ge_tensor.h> | |||||
| #include <graph/graph.h> | |||||
| #include <graph/model.h> | |||||
| #include <graph/node.h> | |||||
| #include <graph/utils/graph_utils.h> | |||||
| #include <graph/utils/type_utils.h> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "utils/attr_utils.h" | |||||
| #include "utils/node_utils.h" | |||||
| #include "external/ge/ge_api_types.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| namespace ge { | |||||
| // Configure build mode, default value is "normal" | |||||
| const char *const BUILD_MODE = "ge.buildMode"; | |||||
| const char *const BUILD_STEP = "ge.buildStep"; | |||||
| // Configure tuning path | |||||
| const char *const TUNING_PATH = "ge.tuningPath"; | |||||
| // for interface: aclgrphBuildModel | |||||
| const std::set<std::string> ir_builder_supported_options_for_lx_fusion = {BUILD_MODE, BUILD_STEP, TUNING_PATH}; | |||||
| // Build model | |||||
| const char *const BUILD_MODE_NORMAL = "normal"; | |||||
| const char *const BUILD_MODE_TUNING = "tuning"; | |||||
| const char *const BUILD_MODE_BASELINE = "baseline"; | |||||
| const std::set<std::string> build_mode_options = {BUILD_MODE_NORMAL, BUILD_MODE_TUNING, BUILD_MODE_BASELINE}; | |||||
| // Build step | |||||
| const char *const BUILD_STEP_BEFORE_UB_MATCH = "before_ub_match"; | |||||
| const char *const BUILD_STEP_AFTER_UB_MATCH = "after_ub_match"; | |||||
| const char *const BUILD_STEP_AFTER_BUILDER = "after_builder"; | |||||
| const char *const BUILD_STEP_AFTER_BUILDER_SUB = "after_builder_sub"; | |||||
| const char *const BUILD_STEP_AFTER_MERGE = "after_merge"; | |||||
| const std::set<std::string> build_step_options = {BUILD_STEP_BEFORE_UB_MATCH, BUILD_STEP_AFTER_UB_MATCH, | |||||
| BUILD_STEP_AFTER_BUILDER, BUILD_STEP_AFTER_BUILDER_SUB, | |||||
| BUILD_STEP_AFTER_MERGE}; | |||||
| using SubgraphCreateOutNode = std::unordered_map<ComputeGraphPtr, NodePtr>; | |||||
| using NodetoNodeMap = std::unordered_map<NodePtr, NodePtr>; | |||||
| using NodeSet = std::set<NodePtr>; | |||||
| using NodeNametoNodeNameMap = std::unordered_map<std::string, std::string>; | |||||
| using NodetoNodeNameMap = std::unordered_map<NodePtr, std::string>; | |||||
| class TuningUtils { | |||||
| public: | |||||
| TuningUtils() = default; | |||||
| ~TuningUtils() = default; | |||||
| // Dump all the subgraphs, converting them into executable subgraphs when exe_flag is true. | |||||
| // `path` and `user_path` specify where the dumped graphs are saved. | |||||
| static graphStatus ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs, | |||||
| std::vector<ComputeGraphPtr> non_tuning_subgraphs = {}, bool exe_flag = false, | |||||
| const std::string &path = "", const std::string &user_path = ""); | |||||
| // Recover `graph` from the graph dump files configured in options | |||||
| static graphStatus ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph); | |||||
| private: | |||||
| // part 1 | |||||
| struct HelpInfo { | |||||
| int64_t index; | |||||
| bool exe_flag; | |||||
| bool is_tuning_graph; | |||||
| const std::string &path; | |||||
| const std::string &user_path; | |||||
| }; | |||||
| static graphStatus MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info); | |||||
| static graphStatus HandlePld(NodePtr &node); | |||||
| static graphStatus HandleEnd(NodePtr &node); | |||||
| static graphStatus ChangePld2Data(NodePtr &node, NodePtr &data_node); | |||||
| static graphStatus ChangeEnd2NetOutput(NodePtr &node, NodePtr &out_node); | |||||
| static graphStatus LinkEnd2NetOutput(NodePtr &node, NodePtr &out_node); | |||||
| static graphStatus CreateDataNode(NodePtr &node, NodePtr &data_node); | |||||
| static graphStatus CreateNetOutput(NodePtr &node, NodePtr &out_node); | |||||
| static graphStatus AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node); | |||||
| static graphStatus AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node); | |||||
| static void DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path); | |||||
| static SubgraphCreateOutNode create_output_; | |||||
| // part 2 | |||||
| static graphStatus MergeAllSubGraph(std::vector<ComputeGraphPtr> &graphs, ComputeGraphPtr &graph); | |||||
| static graphStatus MergeSubGraph(ComputeGraphPtr &graph); | |||||
| // Deletes the data and netoutput nodes added by the `MakeExeGraph()` call in part 1 | |||||
| static graphStatus RemoveDataNetoutputEdge(ComputeGraphPtr &graph); | |||||
| static graphStatus GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, | |||||
| AnchorPtr &src_out_anchor); | |||||
| static NodeNametoNodeNameMap data_2_netoutput_; | |||||
| static NodetoNodeNameMap data_node_2_netoutput_; | |||||
| static NodetoNodeMap data_node_2_netoutput_node_; | |||||
| static NodeSet netoutput_nodes_; | |||||
| static NodeSet merged_graph_nodes_; | |||||
| static std::mutex mutex_; | |||||
| // for debug | |||||
| static std::string PrintCheckLog(); | |||||
| static std::string GetNodeNameByAnchor(const Anchor *anchor); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // MAIN_TUNING_UTILS_H | |||||
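A minimal sketch of the intended dump/merge flow; the paths are placeholders, and how BUILD_MODE/BUILD_STEP reach this code is an assumption based on the option constants above.

// Tuning build: dump every subgraph; tuning subgraphs are converted into executable graphs.
std::vector<ge::ComputeGraphPtr> tuning_subgraphs;      // filled by the graph partitioner
std::vector<ge::ComputeGraphPtr> non_tuning_subgraphs;
ge::graphStatus ret = ge::TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs,
                                                          true, "/tmp/tuning");
// Merge step: reload the dumped files, keyed by subgraph index, into a single graph.
std::map<int64_t, std::string> graph_files = {{0, "/tmp/tuning/subgraph_0.txt"}};
ge::Graph merged_graph("merged_graph");
ret = ge::TuningUtils::ConvertFileToGraph(graph_files, merged_graph);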
| @@ -36,8 +36,8 @@ | |||||
| do { \ | do { \ | ||||
| GraphUtils::DumpGEGraph(compute_graph, name); \ | GraphUtils::DumpGEGraph(compute_graph, name); \ | ||||
| GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ | GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \ | ||||
| uint64_t i = 0; \ | |||||
| for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ | for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \ | ||||
| static int8_t i = 0; \ | |||||
| auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ | auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \ | ||||
| GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ | GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \ | ||||
| GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ | GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \ | ||||
| @@ -203,10 +203,13 @@ class GraphUtils { | |||||
| static bool MatchDumpStr(const std::string &suffix); | static bool MatchDumpStr(const std::string &suffix); | ||||
| static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false); | |||||
| static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, | |||||
| const std::string &user_graph_name = ""); | |||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | ||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); | |||||
| static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | static void BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_nodes_infos); | ||||
| static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | ||||
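A hedged usage sketch: given a ge::ComputeGraphPtr named compute_graph, the dump can now carry a caller-chosen file name and be loaded back into a ComputeGraphPtr. File names are placeholders, and dumping still depends on the DUMP_GE_GRAPH environment settings.

ge::GraphUtils::DumpGEGraph(compute_graph, "after_my_pass", true, "my_graph_snapshot");
ge::ComputeGraphPtr reloaded;
if (ge::GraphUtils::LoadGEGraph("./my_graph_snapshot.txt", reloaded)) {
  // reloaded now holds a new ComputeGraph deserialized from the file
}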
| @@ -24,6 +24,7 @@ file(GLOB_RECURSE PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "../../proto/task.proto" | "../../proto/task.proto" | ||||
| "../../proto/fwk_adaper.proto" | "../../proto/fwk_adaper.proto" | ||||
| "../../proto/op_mapping_info.proto" | "../../proto/op_mapping_info.proto" | ||||
| "../proto/dump_task.proto" | |||||
| ) | ) | ||||
| file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | file(GLOB_RECURSE ONNX_PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | ||||
| @@ -36,6 +36,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t OUTPUT_PARAM_SIZE = 2; | const size_t OUTPUT_PARAM_SIZE = 2; | ||||
| const std::string alias_name_attr = "_aliasName"; | |||||
| bool IsUseBFS() { | bool IsUseBFS() { | ||||
| string run_mode; | string run_mode; | ||||
| const int base = 10; | const int base = 10; | ||||
| @@ -133,6 +134,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::FindNode(co | |||||
| if (node->GetName() == name) { | if (node->GetName() == name) { | ||||
| return node; | return node; | ||||
| } | } | ||||
| std::vector<string> out_alias_name; | |||||
| if (AttrUtils::GetListStr(node->GetOpDesc(), alias_name_attr, out_alias_name)) { | |||||
| for (const auto &alias_name : out_alias_name) { | |||||
| if (alias_name == name) { | |||||
| return node; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | } | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
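An illustrative consequence of the alias lookup, assuming node is a NodePtr already added to ComputeGraphPtr graph; the alias values are placeholders.

(void)ge::AttrUtils::SetListStr(node->GetOpDesc(), "_aliasName",
                                std::vector<std::string>{"conv1/fused", "conv1_old"});
ge::NodePtr found = graph->FindNode("conv1/fused");  // now returns `node` via its alias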
| @@ -258,6 +267,7 @@ NodePtr ComputeGraph::AddNodeFront(NodePtr node) { | |||||
| GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); | GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| node->SetHostNode(is_valid_flag_); | |||||
| node->GetOpDesc()->SetId(nodes_.size()); | node->GetOpDesc()->SetId(nodes_.size()); | ||||
| if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { | if (nodes_.size() > 0 && nodes_[0]->GetType() == DATA) { | ||||
| (void)nodes_.insert(nodes_.begin() + 1, node); | (void)nodes_.insert(nodes_.begin() + 1, node); | ||||
| @@ -284,6 +294,7 @@ NodePtr ComputeGraph::AddNodeAfter(NodePtr node, const NodePtr &pre_node) { | |||||
| GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); | GELOGE(GRAPH_FAILED, "The node ptr or op desc should not be null."); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| node->SetHostNode(is_valid_flag_); | |||||
| node->GetOpDesc()->SetId(nodes_.size()); | node->GetOpDesc()->SetId(nodes_.size()); | ||||
| auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); | auto node_iter = std::find(nodes_.begin(), nodes_.end(), pre_node); | ||||
| if (node_iter != nodes_.end()) { | if (node_iter != nodes_.end()) { | ||||
| @@ -313,6 +324,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr ComputeGraph::AddNode(Nod | |||||
| GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | GELOGE(GRAPH_FAILED, "The node ptr should not be null."); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| node->SetHostNode(is_valid_flag_); | |||||
| node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); | node->GetOpDesc()->SetId((int64_t)GetDirectNodesSize()); | ||||
| nodes_.push_back(node); | nodes_.push_back(node); | ||||
| return node; | return node; | ||||
| @@ -339,6 +351,7 @@ NodePtr ComputeGraph::AddNode(OpDescPtr op, int64_t id) { // for unserialize. | |||||
| NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this())); | NodePtr node = shared_ptr<Node>(new (std::nothrow) Node(op, shared_from_this())); | ||||
| GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); | GE_IF_BOOL_EXEC(node == nullptr, GELOGE(GRAPH_FAILED, "node_ptr is NULL!!!"); return nullptr); | ||||
| GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | GE_IF_BOOL_EXEC(node->Init() != GRAPH_SUCCESS, GELOGE(GRAPH_FAILED, "node init fail."); return nullptr); | ||||
| node->SetHostNode(is_valid_flag_); | |||||
| nodes_.push_back(node); | nodes_.push_back(node); | ||||
| return node; | return node; | ||||
| } | } | ||||
| @@ -355,7 +368,9 @@ NodePtr ComputeGraph::AddInputNode(NodePtr node) { | |||||
| return node; | return node; | ||||
| } | } | ||||
| NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | |||||
| NodePtr ComputeGraph::AddOutputNode(NodePtr node) { return AddOutputNodeByIndex(node, 0); } | |||||
| NodePtr ComputeGraph::AddOutputNodeByIndex(NodePtr node, int32_t index) { | |||||
| if (node == nullptr || node->GetOpDesc() == nullptr) { | if (node == nullptr || node->GetOpDesc() == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); | GELOGE(GRAPH_FAILED, "The node ptr or opdesc should not be null."); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -365,7 +380,7 @@ NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | |||||
| NodePtr result = node; | NodePtr result = node; | ||||
| // [output_nodes_info_ : should not be null] | // [output_nodes_info_ : should not be null] | ||||
| for (const auto &item : output_nodes_info_) { | for (const auto &item : output_nodes_info_) { | ||||
| if (item.first->GetName() == node->GetName()) { | |||||
| if (item.first->GetName() == node->GetName() && item.second == index) { | |||||
| already_have = true; | already_have = true; | ||||
| result = item.first; | result = item.first; | ||||
| break; | break; | ||||
| @@ -373,7 +388,8 @@ NodePtr ComputeGraph::AddOutputNode(NodePtr node) { | |||||
| } | } | ||||
| if (!already_have) { | if (!already_have) { | ||||
| output_nodes_info_.emplace_back(std::make_pair(node, 0)); | |||||
| output_nodes_info_.emplace_back(std::make_pair(node, index)); | |||||
| GELOGI("Push back node name:%s, index:%ld, into output_nodes_info_.", node->GetName().c_str(), index); | |||||
| } | } | ||||
| if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { | if (std::find(nodes_.begin(), nodes_.end(), node) == nodes_.end()) { | ||||
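A small sketch of the behavioural change, with graph and src_node as placeholders: the same node may now be recorded once per output index rather than being deduplicated on name alone.

ge::NodePtr ret0 = graph->AddOutputNodeByIndex(src_node, 0);
ge::NodePtr ret1 = graph->AddOutputNodeByIndex(src_node, 1);  // new entry; index 1 is distinct from index 0
ge::NodePtr ret2 = graph->AddOutputNode(src_node);            // same as index 0, returns the existing entry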
| @@ -32,6 +32,8 @@ GE_REGISTER_OPTYPE(STATELESSWHILE, "StatelessWhile"); | |||||
| GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | GE_REGISTER_OPTYPE(SQUEEZE, "Squeeze"); | ||||
| GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | GE_REGISTER_OPTYPE(EXPANDDIMS, "ExpandDims"); | ||||
| GE_REGISTER_OPTYPE(SWITCH, "Switch"); | GE_REGISTER_OPTYPE(SWITCH, "Switch"); | ||||
| GE_REGISTER_OPTYPE(REFSWITCH, "RefSwitch"); | |||||
| GE_REGISTER_OPTYPE(SWITCHN, "SwitchN"); | |||||
| GE_REGISTER_OPTYPE(MERGE, "Merge"); | GE_REGISTER_OPTYPE(MERGE, "Merge"); | ||||
| GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | GE_REGISTER_OPTYPE(STREAMMERGE, "StreamMerge"); | ||||
| GE_REGISTER_OPTYPE(ENTER, "Enter"); | GE_REGISTER_OPTYPE(ENTER, "Enter"); | ||||
| @@ -40,6 +42,7 @@ GE_REGISTER_OPTYPE(NEXTITERATION, "NextIteration"); | |||||
| GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | GE_REGISTER_OPTYPE(REFNEXTITERATION, "RefNextIteration"); | ||||
| GE_REGISTER_OPTYPE(CONSTANT, "Const"); | GE_REGISTER_OPTYPE(CONSTANT, "Const"); | ||||
| GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | GE_REGISTER_OPTYPE(PLACEHOLDER, "PlaceHolder"); | ||||
| GE_REGISTER_OPTYPE(END, "End"); | |||||
| GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | GE_REGISTER_OPTYPE(FRAMEWORKOP, "FrameworkOp"); | ||||
| GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | GE_REGISTER_OPTYPE(GETNEXT, "GetNext"); | ||||
| GE_REGISTER_OPTYPE(INITDATA, "InitData"); | GE_REGISTER_OPTYPE(INITDATA, "InitData"); | ||||
| @@ -43,7 +43,7 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | ||||
| const string kIsGraphInferred = "_is_graph_inferred"; | const string kIsGraphInferred = "_is_graph_inferred"; | ||||
| RefRelations reflection_builder; | |||||
| thread_local RefRelations reflection_builder; | |||||
| } // namespace | } // namespace | ||||
| graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | graphStatus ReflectionProcess(const std::unordered_set<RefCell, RefCellHash> &reflection, | ||||
| @@ -967,6 +967,13 @@ const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION = "_enable_l2_fusion"; | |||||
| const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; | const std::string ATTR_NAME_OP_INPUT_L1_FLAG = "_op_input_l1_flag"; | ||||
| const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; | const std::string ATTR_NAME_OP_INPUT_L1_ADDR = "_op_input_l1_addr"; | ||||
| const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; | const std::string ATTR_NAME_OP_INPUT_L1_VALID_SIZE = "_op_input_l1_valid_size"; | ||||
| const std::string ATTR_NAME_ENGINE_NAME_FOR_LX = "_lxfusion_engine_name"; | |||||
| const std::string ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX = "_lxfusion_op_kernel_lib_name"; | |||||
| const std::string ATTR_NAME_NEED_LX_FUSION = "_lx_fusion"; | |||||
| const std::string ATTR_NAME_OPTIMIZE_GROUP = "_optimize_group"; | |||||
| const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; | |||||
| const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; | |||||
| const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; | |||||
| // Op debug attrs | // Op debug attrs | ||||
| const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | ||||
| @@ -8,6 +8,7 @@ COMMON_LOCAL_SRC_FILES := \ | |||||
| ./proto/task.proto \ | ./proto/task.proto \ | ||||
| ./proto/fwk_adapter.proto \ | ./proto/fwk_adapter.proto \ | ||||
| ./proto/op_mapping_info.proto \ | ./proto/op_mapping_info.proto \ | ||||
| ./proto/dump_task.proto \ | |||||
| ./anchor.cc \ | ./anchor.cc \ | ||||
| ./ge_attr_value.cc \ | ./ge_attr_value.cc \ | ||||
| ./attr_value.cc \ | ./attr_value.cc \ | ||||
| @@ -29,6 +30,7 @@ COMMON_LOCAL_SRC_FILES := \ | |||||
| ./ge_tensor.cc \ | ./ge_tensor.cc \ | ||||
| ./detail/attributes_holder.cc \ | ./detail/attributes_holder.cc \ | ||||
| ./utils/anchor_utils.cc \ | ./utils/anchor_utils.cc \ | ||||
| ./utils/tuning_utils.cc \ | |||||
| ./utils/graph_utils.cc \ | ./utils/graph_utils.cc \ | ||||
| ./utils/ge_ir_utils.cc \ | ./utils/ge_ir_utils.cc \ | ||||
| ./utils/node_utils.cc \ | ./utils/node_utils.cc \ | ||||
| @@ -51,6 +53,7 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
| proto/task.proto \ | proto/task.proto \ | ||||
| proto/fwk_adapter.proto \ | proto/fwk_adapter.proto \ | ||||
| proto/op_mapping_info.proto \ | proto/op_mapping_info.proto \ | ||||
| proto/dump_task.proto \ | |||||
| inc \ | inc \ | ||||
| inc/external \ | inc/external \ | ||||
| inc/external/graph \ | inc/external/graph \ | ||||
| @@ -195,9 +195,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool ModelSerializeImp::Serialize | |||||
| } | } | ||||
| } | } | ||||
| // Outputs | // Outputs | ||||
| for (const auto &output : graph->GetOutputNodes()) { | |||||
| if (output != nullptr) { | |||||
| graph_proto->add_output(output->GetName() + ":0"); | |||||
| for (const auto &output : graph->GetGraphOutNodesInfo()) { | |||||
| if (output.first != nullptr) { | |||||
| graph_proto->add_output(output.first->GetName() + ":" + std::to_string(output.second)); | |||||
| GELOGI("Add output to graph proto, node name:%s, index:%ld", output.first->GetName().c_str(), output.second); | |||||
| } | } | ||||
| } | } | ||||
| if (graph->attrs_.GetProtoMsg() != nullptr) { | if (graph->attrs_.GetProtoMsg() != nullptr) { | ||||
| @@ -440,7 +441,8 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||||
| } | } | ||||
| GE_IF_BOOL_EXEC(item.graph == nullptr, continue); | GE_IF_BOOL_EXEC(item.graph == nullptr, continue); | ||||
| auto ret = item.graph->AddOutputNode(node_it->second); | |||||
| auto ret = item.graph->AddOutputNodeByIndex(node_it->second, item.index); | |||||
| GELOGI("node name:%s, item.index:%ld", node_it->second->GetName().c_str(), item.index); | |||||
| if (ret == nullptr) { | if (ret == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "AddOutputNode failed."); | GELOGE(GRAPH_FAILED, "AddOutputNode failed."); | ||||
| return false; | return false; | ||||
| @@ -219,6 +219,10 @@ graphStatus OpDesc::AddInputDesc(const string &name, const ge::GeTensorDesc &inp | |||||
| } | } | ||||
| inputs_desc_.push_back(in_desc); | inputs_desc_.push_back(in_desc); | ||||
| (void)input_name_idx_.insert(make_pair(name, index)); | (void)input_name_idx_.insert(make_pair(name, index)); | ||||
| if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { | |||||
| register_input_name_.push_back(name); | |||||
| } | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| } | } | ||||
| @@ -255,6 +259,38 @@ graphStatus OpDesc::AddInputDescMiddle(const string &name, const unsigned int nu | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| graphStatus OpDesc::AddOutputDescMiddle(const string &name, const unsigned int num, size_t index) { | |||||
| for (unsigned int i = 0; i < num; i++) { | |||||
| string output_name = name + std::to_string(i); | |||||
| GE_CHK_BOOL_RET_STATUS((output_name_idx_.find(output_name) == output_name_idx_.end()), GRAPH_FAILED, | |||||
| "Add input tensor_desc is existed. name[%s]", output_name.c_str()); | |||||
| std::shared_ptr<GeTensorDesc> out_desc = ComGraphMakeShared<GeTensorDesc>(GeTensorDesc()); | |||||
| if (out_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, malloc shared_ptr failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (index > outputs_desc_.size()) { | |||||
| GELOGE(GRAPH_FAILED, "AddInputDescMiddle failed, insert index should not more than inputs size."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| (void)outputs_desc_.insert(outputs_desc_.begin() + index + i, out_desc); | |||||
| // Update index in output_name_idx_ | |||||
| for (auto it = output_name_idx_.begin(); it != output_name_idx_.end(); ++it) { | |||||
| if (it->second >= (index + i)) { | |||||
| it->second += 1; | |||||
| } | |||||
| } | |||||
| (void)output_name_idx_.insert(make_pair(output_name, i + index)); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | graphStatus OpDesc::AddInputDescForward(const string &name, const unsigned int num) { | ||||
| for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
| string input_name = name + std::to_string(i); | string input_name = name + std::to_string(i); | ||||
| @@ -550,6 +586,9 @@ graphStatus OpDesc::AddOutputDesc(const string &name, const ge::GeTensorDesc &ou | |||||
| } | } | ||||
| outputs_desc_.push_back(tensor); | outputs_desc_.push_back(tensor); | ||||
| (void)output_name_idx_.insert(make_pair(name, index)); | (void)output_name_idx_.insert(make_pair(name, index)); | ||||
| if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { | |||||
| register_output_name_.push_back(name); | |||||
| } | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -655,6 +694,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ConstGeTensorDescPtr OpDesc::GetI | |||||
| return inputs_desc_[it->second]; | return inputs_desc_[it->second]; | ||||
| } | } | ||||
| graphStatus OpDesc::AddRegisterInputName(const std::string &name) { | |||||
| if (find(register_input_name_.begin(), register_input_name_.end(), name) == register_input_name_.end()) { | |||||
| register_input_name_.push_back(name); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| vector<string> OpDesc::GetRegisterInputName() const { return register_input_name_; } | |||||
| graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int num, bool is_push_back) { | ||||
| if (is_push_back) { | if (is_push_back) { | ||||
| for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
| @@ -663,6 +712,10 @@ graphStatus OpDesc::AddDynamicInputDesc(const string &name, const unsigned int n | |||||
| } else { | } else { | ||||
| if (AddInputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; | if (AddInputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; | ||||
| } | } | ||||
| if (AddRegisterInputName(name) != GRAPH_SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -673,6 +726,16 @@ graphStatus OpDesc::AddDynamicInputDescByIndex(const string &name, const unsigne | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| graphStatus OpDesc::AddRegisterOutputName(const string &name) { | |||||
| if (find(register_output_name_.begin(), register_output_name_.end(), name) == register_output_name_.end()) { | |||||
| register_output_name_.push_back(name); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| vector<string> OpDesc::GetRegisterOutputName() const { return register_output_name_; } | |||||
| graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { | graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int num, bool is_push_back) { | ||||
| if (is_push_back) { | if (is_push_back) { | ||||
| for (unsigned int i = 0; i < num; i++) { | for (unsigned int i = 0; i < num; i++) { | ||||
| @@ -681,6 +744,10 @@ graphStatus OpDesc::AddDynamicOutputDesc(const string &name, const unsigned int | |||||
| } else { | } else { | ||||
| if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; | if (AddOutputDescForward(name, num) != GRAPH_SUCCESS) return GRAPH_FAILED; | ||||
| } | } | ||||
| if (AddRegisterOutputName(name) != GRAPH_SUCCESS) { | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
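A hedged sketch of the new bookkeeping: the register name lists simply record which port names were declared, in declaration order and without duplicates; the op name and type below are placeholders.

ge::OpDescPtr op_desc = std::make_shared<ge::OpDesc>("sample", "SampleOp");
(void)op_desc->AddRegisterInputName("x");   // a second call with "x" would be a no-op
(void)op_desc->AddRegisterOutputName("y");
std::vector<std::string> in_names = op_desc->GetRegisterInputName();    // contains "x"
std::vector<std::string> out_names = op_desc->GetRegisterOutputName();  // contains "y"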
| @@ -31,6 +31,13 @@ OpsProtoManager *OpsProtoManager::Instance() { | |||||
| } | } | ||||
| bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) { | bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &options) { | ||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| if (is_init_) { | |||||
| GELOGI("OpsProtoManager is already initialized."); | |||||
| return true; | |||||
| } | |||||
| /*lint -e1561*/ | /*lint -e1561*/ | ||||
| auto proto_iter = options.find("ge.opsProtoLibPath"); | auto proto_iter = options.find("ge.opsProtoLibPath"); | ||||
| /*lint +e1561*/ | /*lint +e1561*/ | ||||
| @@ -42,10 +49,19 @@ bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &optio | |||||
| pluginPath_ = proto_iter->second; | pluginPath_ = proto_iter->second; | ||||
| LoadOpsProtoPluginSo(pluginPath_); | LoadOpsProtoPluginSo(pluginPath_); | ||||
| is_init_ = true; | |||||
| return true; | return true; | ||||
| } | } | ||||
| void OpsProtoManager::Finalize() { | void OpsProtoManager::Finalize() { | ||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| if (!is_init_) { | |||||
| GELOGI("OpsProtoManager is not initialized."); | |||||
| return; | |||||
| } | |||||
| for (auto handle : handles_) { | for (auto handle : handles_) { | ||||
| if (handle != nullptr) { | if (handle != nullptr) { | ||||
| if (dlclose(handle) != 0) { | if (dlclose(handle) != 0) { | ||||
| @@ -57,6 +73,8 @@ void OpsProtoManager::Finalize() { | |||||
| GELOGW("close opsprotomanager handler failure, handler is nullptr"); | GELOGW("close opsprotomanager handler failure, handler is nullptr"); | ||||
| } | } | ||||
| } | } | ||||
| is_init_ = false; | |||||
| } | } | ||||
| static std::vector<std::string> Split(const std::string &str, char delim) { | static std::vector<std::string> Split(const std::string &str, char delim) { | ||||
| @@ -601,7 +601,7 @@ InferenceContextPtr CreateInferenceContext(const std::unordered_map<NodePtr, Inf | |||||
| } | } | ||||
| namespace { | namespace { | ||||
| std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||||
| thread_local std::unordered_map<NodePtr, InferenceContextPtr> context_map; | |||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ShapeRefiner::ClearContextMap() { context_map.clear(); } | ||||
| @@ -645,6 +645,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ShapeRefiner::InferSh | |||||
| graphStatus status = InferShapeAndType(node, op, before_subgraph); | graphStatus status = InferShapeAndType(node, op, before_subgraph); | ||||
| if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | if (status == GRAPH_PARAM_INVALID || status == GRAPH_SUCCESS) { | ||||
| if (is_unknown_graph) { | if (is_unknown_graph) { | ||||
| PrintInOutTensorShape(node, "after_infershape when running"); | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iomanip> | #include <iomanip> | ||||
| #include <queue> | #include <queue> | ||||
| #include <atomic> | |||||
| #include "./ge_context.h" | #include "./ge_context.h" | ||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| @@ -57,6 +58,7 @@ namespace { | |||||
| const int32_t kBaseOfIntegerValue = 10; | const int32_t kBaseOfIntegerValue = 10; | ||||
| #ifdef FMK_SUPPORT_DUMP | #ifdef FMK_SUPPORT_DUMP | ||||
| const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; | const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; | ||||
| const int kDumpGraphIndexWidth = 5; | |||||
| #endif | #endif | ||||
| const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; | const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; | ||||
| const char *const kDumpStrBuild = "Build"; | const char *const kDumpStrBuild = "Build"; | ||||
| @@ -431,10 +433,15 @@ GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDat | |||||
| OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); | OutControlAnchorPtr src_out_ctrl_anchor = src_node->GetOutControlAnchor(); | ||||
| GE_CHECK_NOTNULL(src_out_ctrl_anchor); | GE_CHECK_NOTNULL(src_out_ctrl_anchor); | ||||
| bool ctrl_edge_flag = true; | |||||
| std::string type = NodeUtils::GetNodeType(src->GetOwnerNode()); | |||||
| if ((type == SWITCH) || (type == REFSWITCH) || (type == SWITCHN)) { | |||||
| ctrl_edge_flag = false; | |||||
| } | |||||
| for (auto &dst : dsts) { | for (auto &dst : dsts) { | ||||
| GE_CHECK_NOTNULL(dst); | GE_CHECK_NOTNULL(dst); | ||||
| NodePtr dst_node = dst->GetOwnerNode(); | NodePtr dst_node = dst->GetOwnerNode(); | ||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| GELOGI("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src_node->GetName().c_str(), | GELOGI("Insert node %s between %s->%s.", insert_node->GetName().c_str(), src_node->GetName().c_str(), | ||||
| dst_node->GetName().c_str()); | dst_node->GetName().c_str()); | ||||
| if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { | if (src_node->GetOwnerComputeGraph() != dst_node->GetOwnerComputeGraph()) { | ||||
| @@ -450,11 +457,12 @@ GraphUtils::InsertNodeAfter(const OutDataAnchorPtr &src, const std::vector<InDat | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| OutControlAnchorPtr new_out_ctrl_anchor = insert_node->GetOutControlAnchor(); | |||||
| GE_CHECK_NOTNULL(new_out_ctrl_anchor); | |||||
| if (!ctrl_edge_flag) { | |||||
| continue; | |||||
| } | |||||
| for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | for (const InControlAnchorPtr &peer_in_ctrl_anchor : src_out_ctrl_anchor->GetPeerInControlAnchors()) { | ||||
| if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || | if ((RemoveEdge(src_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS) || | ||||
| (AddEdge(new_out_ctrl_anchor, peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { | |||||
| (AddEdge(insert_node->GetOutControlAnchor(), peer_in_ctrl_anchor) != GRAPH_SUCCESS)) { | |||||
| GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | GELOGE(GRAPH_FAILED, "ReplaceEdge from %s->%s to %s->%s failed.", src_node->GetName().c_str(), | ||||
| peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), insert_node->GetName().c_str(), | peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str(), insert_node->GetName().c_str(), | ||||
| peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); | peer_in_ctrl_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| @@ -552,7 +560,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(con | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(const ge::ComputeGraphPtr &graph, | ||||
| const std::string &suffix, | const std::string &suffix, | ||||
| bool is_always_dump) { | |||||
| bool is_always_dump, | |||||
| const std::string &user_graph_name) { | |||||
| #ifdef FMK_SUPPORT_DUMP | #ifdef FMK_SUPPORT_DUMP | ||||
| char *dump_ge_graph = std::getenv(kDumpGeGraph); | char *dump_ge_graph = std::getenv(kDumpGeGraph); | ||||
| GE_IF_BOOL_EXEC(dump_ge_graph == nullptr && !is_always_dump, return;); | GE_IF_BOOL_EXEC(dump_ge_graph == nullptr && !is_always_dump, return;); | ||||
| @@ -563,32 +572,33 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||||
| } | } | ||||
| // file name | // file name | ||||
| static int file_idx = 0; | |||||
| const int dump_graph_index_width = 5; | |||||
| file_idx++; | |||||
| GELOGD("Start to dump om txt: %d", file_idx); | |||||
| static std::atomic_long atomic_file_index(0); | |||||
| auto file_index = atomic_file_index.fetch_add(1); | |||||
| GELOGD("Start to dump om txt: %ld", file_index); | |||||
| static int max_dumpfile_num = 0; | |||||
| if (max_dumpfile_num == 0) { | |||||
| thread_local long max_dump_file_num = 0; | |||||
| if (max_dump_file_num == 0) { | |||||
| string opt = "0"; | string opt = "0"; | ||||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | ||||
| max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| } | } | ||||
| if (max_dumpfile_num != 0 && file_idx > max_dumpfile_num) { | |||||
| GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%d.", max_dumpfile_num); | |||||
| if (max_dump_file_num != 0 && file_index > max_dump_file_num) { | |||||
| GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); | |||||
| return; | return; | ||||
| } | } | ||||
| std::stringstream stream_file_name; | std::stringstream stream_file_name; | ||||
| stream_file_name << "ge_proto_" << std::setw(dump_graph_index_width) << std::setfill('0') << file_idx; | |||||
| stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; | |||||
| stream_file_name << "_" << suffix << ".txt"; | stream_file_name << "_" << suffix << ".txt"; | ||||
| std::string proto_file = stream_file_name.str(); | |||||
| std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name; | |||||
| // Create buffer | // Create buffer | ||||
| ge::Model model("", ""); | ge::Model model("", ""); | ||||
| model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast<ComputeGraph>(graph))); | model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast<ComputeGraph>(graph))); | ||||
| Buffer buffer; | Buffer buffer; | ||||
| model.Save(buffer, true); | |||||
| const int64_t kDumpLevel = | |||||
| (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; | |||||
| model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); | |||||
| // Write file | // Write file | ||||
| ge::proto::ModelDef ge_proto; | ge::proto::ModelDef ge_proto; | ||||
| @@ -631,6 +641,35 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(cons | |||||
| } | } | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, | |||||
| ge::ComputeGraphPtr &compute_graph) { | |||||
| ge::proto::ModelDef model_def; | |||||
| // Get ModelDef object from file generated by DumpGEGraph() | |||||
| if (!ReadProtoFromTextFile(file, &model_def)) { | |||||
| GELOGE(GRAPH_FAILED, "Get ModelDef failed from file"); | |||||
| return false; | |||||
| } | |||||
| ge::Model model; | |||||
| // Get Model object from ModelDef by deserializing it | |||||
| if (model.Load(model_def) == GRAPH_SUCCESS) { | |||||
| GE_CHK_BOOL_EXEC(GraphUtils::GetComputeGraph(model.GetGraph()) != nullptr, return false, | |||||
| "Get computer graph is nullptr"); | |||||
| compute_graph = GraphUtils::GetComputeGraph(model.GetGraph()); | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| GELOGI("Node %s set owner graph", node->GetName().c_str()); | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->SetOwnerComputeGraph(compute_graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Node %s set owner graph failed", node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } else { | |||||
| GELOGE(GRAPH_FAILED, "Get Model failed from ModelDef"); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| // Printing protocol messages in text format is useful for debugging and human editing of messages. | // Printing protocol messages in text format is useful for debugging and human editing of messages. | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToTextFile( | ||||
| const google::protobuf::Message &proto, const char *real_path) { | const google::protobuf::Message &proto, const char *real_path) { | ||||
| @@ -666,16 +705,16 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText | |||||
| return; | return; | ||||
| } | } | ||||
| if (fseek(file, 0L, SEEK_END) == 0) { | if (fseek(file, 0L, SEEK_END) == 0) { | ||||
| int64_t fileSize = ftell(file); | |||||
| static int64_t maxDumpFileSize = 0; | |||||
| if (maxDumpFileSize == 0) { | |||||
| long fileSize = ftell(file); | |||||
| thread_local long max_dump_file_size = 0; | |||||
| if (max_dump_file_size == 0) { | |||||
| string opt = "0"; | string opt = "0"; | ||||
| // Can not check return value | // Can not check return value | ||||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); | (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_SIZE, opt); | ||||
| maxDumpFileSize = atol(opt.c_str()); | |||||
| max_dump_file_size = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| } | } | ||||
| if (maxDumpFileSize != 0 && fileSize != -1 && fileSize > maxDumpFileSize) { | |||||
| GELOGW("dump graph file size > maxDumpFileSize, maxDumpFileSize=%ld.", maxDumpFileSize); | |||||
| if (max_dump_file_size != 0 && fileSize != -1 && fileSize > max_dump_file_size) { | |||||
| GELOGW("dump graph file size > maxDumpFileSize, maxDumpFileSize=%ld.", max_dump_file_size); | |||||
| GE_IF_BOOL_EXEC(std::remove(real_path) != 0, GELOGW("remove %s failed", real_path)); | GE_IF_BOOL_EXEC(std::remove(real_path) != 0, GELOGW("remove %s failed", real_path)); | ||||
| GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose %s failed", real_path); | GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose %s failed", real_path); | ||||
| return; | return; | ||||
| @@ -734,25 +773,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||||
| } | } | ||||
| // 2.Set file name | // 2.Set file name | ||||
| static int file_index = 0; | |||||
| file_index++; | |||||
| GELOGD("Start to dump ge onnx file: %d", file_index); | |||||
| static std::atomic_long atomic_file_index(0); | |||||
| auto file_index = atomic_file_index.fetch_add(1); | |||||
| GELOGD("Start to dump ge onnx file: %ld", file_index); | |||||
| static int max_dumpfile_num = 0; | |||||
| if (max_dumpfile_num == 0) { | |||||
| thread_local long max_dump_file_num = 0; | |||||
| if (max_dump_file_num == 0) { | |||||
| string opt = "0"; | string opt = "0"; | ||||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | ||||
| max_dumpfile_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| } | } | ||||
| if (max_dumpfile_num != 0 && file_index > max_dumpfile_num) { | |||||
| GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%d.", max_dumpfile_num); | |||||
| if (max_dump_file_num != 0 && file_index > max_dump_file_num) { | |||||
| GELOGW("dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); | |||||
| return; | return; | ||||
| } | } | ||||
| /// 99999 graphs can be dumped at most at one time | |||||
| /// setw(5) is for formatted sort | |||||
| std::stringstream stream_file_name; | std::stringstream stream_file_name; | ||||
| stream_file_name << "ge_onnx_" << std::setw(5) << std::setfill('0') << file_index; | |||||
| stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; | |||||
| stream_file_name << "_graph_" << compute_graph.GetGraphID(); | stream_file_name << "_graph_" << compute_graph.GetGraphID(); | ||||
| stream_file_name << "_" << suffix << ".pbtxt"; | stream_file_name << "_" << suffix << ".pbtxt"; | ||||
| std::string proto_file = stream_file_name.str(); | std::string proto_file = stream_file_name.str(); | ||||
| @@ -1363,6 +1400,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::FindR | |||||
| /// Make a copy of ComputeGraph. | /// Make a copy of ComputeGraph. | ||||
| /// @param graph: original graph. | /// @param graph: original graph. | ||||
| /// @param prefix: node name prefix of new graph. | /// @param prefix: node name prefix of new graph. | ||||
| /// @param output_nodes: output nodes of new graph. | |||||
| /// @return ComputeGraphPtr | /// @return ComputeGraphPtr | ||||
| /// | /// | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr | ||||
| @@ -1399,6 +1437,14 @@ GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, | |||||
| } | } | ||||
| } | } | ||||
| std::string session_graph_id; | |||||
| if (AttrUtils::GetStr(*graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id)) { | |||||
| bool ret = AttrUtils::SetStr(*new_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id); | |||||
| if (!ret) { | |||||
| GELOGE(GRAPH_FAILED, "Set attr ATTR_NAME_SESSION_GRAPH_ID failed."); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return new_graph; | return new_graph; | ||||
| } | } | ||||
| @@ -479,6 +479,19 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<GeTensorPtr> OpDescUtils:: | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (node.GetType() == DATA) { | |||||
| auto parent = NodeUtils::GetParentInput(node); | |||||
| if ((parent != nullptr) && NodeUtils::IsConst(*parent)) { | |||||
| auto weight = MutableWeights(parent->GetOpDesc()); | |||||
| if (weight == nullptr) { | |||||
| GELOGI("const op has no weight, op name:%s", parent->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| ret.push_back(weight); | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| // Other operators, get weights from connected constop | // Other operators, get weights from connected constop | ||||
| auto input_nodes = GetConstInputs(node); | auto input_nodes = GetConstInputs(node); | ||||
| for (const auto &input_node : input_nodes) { | for (const auto &input_node : input_nodes) { | ||||
| @@ -560,11 +573,9 @@ OpDescPtr OpDescUtils::CreateConstOp(const GeTensorPtr &tensor_ptr) { | |||||
| const_opdesc->SetType(CONSTANT); | const_opdesc->SetType(CONSTANT); | ||||
| static int const_count = 0; | |||||
| const_opdesc->SetName("dynamic_const_" + std::to_string(const_count)); | |||||
| thread_local int64_t const_count = 0; | |||||
| const_opdesc->SetName("dynamic_const_" + std::to_string(GetTid()) + "_" + std::to_string(const_count)); | |||||
| GELOGI("add const op: %s", const_opdesc->GetName().c_str()); | GELOGI("add const op: %s", const_opdesc->GetName().c_str()); | ||||
| ++const_count; | ++const_count; | ||||
| (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); | (void)const_opdesc->AddOutputDesc(tensor_ptr->GetTensorDesc()); | ||||
| @@ -0,0 +1,684 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/tuning_utils.h" | |||||
| #include "../debug/ge_util.h" | |||||
| #include "../debug/ge_op_types.h" | |||||
| namespace ge { | |||||
| const std::string peer_node_name_attr = "_peerNodeName"; | |||||
| const std::string parent_node_name_attr = "_parentNodeName"; | |||||
| const std::string alias_name_attr = "_aliasName"; | |||||
| const std::string parent_node_attr = "parentNode"; | |||||
| const std::string parent_node_anchor_index_attr = "_parentNodeAnchorIndex"; | |||||
| const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; | |||||
| const std::string non_tuning_subgraph_prefix = "/subgraph_"; | |||||
| const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END}; | |||||
| const std::set<std::string> kExeTypes = {DATA, NETOUTPUT}; | |||||
| NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; | |||||
| NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; | |||||
| NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; | |||||
| NodeSet TuningUtils::netoutput_nodes_; | |||||
| NodeSet TuningUtils::merged_graph_nodes_; | |||||
| SubgraphCreateOutNode TuningUtils::create_output_; | |||||
| std::mutex TuningUtils::mutex_; | |||||
| std::string TuningUtils::PrintCheckLog() { | |||||
| std::stringstream ss; | |||||
| ss << "d2n:{"; | |||||
| for (const auto &pair : data_2_netoutput_) { | |||||
| ss << "data:" << pair.first << "-" | |||||
| << "netoutput:" << pair.second; | |||||
| ss << " | "; | |||||
| } | |||||
| ss << "}"; | |||||
| ss << "netoutputs:{"; | |||||
| for (const auto &node : netoutput_nodes_) { | |||||
| ss << "netoutput:" << node->GetName(); | |||||
| ss << " | "; | |||||
| } | |||||
| ss << "}"; | |||||
| return ss.str(); | |||||
| } | |||||
| std::string TuningUtils::GetNodeNameByAnchor(const Anchor *anchor) { | |||||
| if (anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Anchor is nullptr"); | |||||
| return "Null"; | |||||
| } | |||||
| auto node = anchor->GetOwnerNode(); | |||||
| return node == nullptr ? "Null" : node->GetName(); | |||||
| } | |||||
| // part 1 | |||||
| graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_subgraphs, | |||||
| std::vector<ComputeGraphPtr> non_tuning_subgraphs, bool exe_flag, | |||||
| const std::string &path, const std::string &user_path) { | |||||
| int64_t i = 0; | |||||
| int64_t j = 0; | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
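| // rewrite each tuning subgraph into an executable graph (when exe_flag is set) and dump it; i is the index used in the dump file name | |||||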
| for (auto &subgraph : tuning_subgraphs) { | |||||
| create_output_.emplace(subgraph, nullptr); | |||||
| auto help_info = HelpInfo{i, exe_flag, true, path, user_path}; | |||||
| if (MakeExeGraph(subgraph, help_info) != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "TUU:subgraph %zu generate exe graph failed", i); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| i++; | |||||
| } | |||||
| for (auto &subgraph : non_tuning_subgraphs) { | |||||
| create_output_.emplace(subgraph, nullptr); | |||||
| auto help_info = HelpInfo{j, true, false, path, user_path}; | |||||
| if (MakeExeGraph(subgraph, help_info) != SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "TUU:non tuning_subgraph %zu generate exe graph failed", j); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| j++; | |||||
| } | |||||
| create_output_.clear(); | |||||
| return SUCCESS; | |||||
| } | |||||
| // +---------------+ | |||||
| // | pld pld | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | end | | |||||
| // +---------------+ | |||||
| // | | |||||
| // | | |||||
| // V | |||||
| // +---------------+ | |||||
| // | data data | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | netoutput | | |||||
| // +---------------+ | |||||
| graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { | |||||
| GE_CHECK_NOTNULL(exe_graph); | |||||
| // if the exe graph is not required, just dump the original graph and return | |||||
| if (!help_info.exe_flag) { | |||||
| DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); | |||||
| GELOGI("TUU:just return, dump original sub_graph[%s]index[%d]", exe_graph->GetName().c_str(), help_info.index); | |||||
| return SUCCESS; | |||||
| } | |||||
| // modify sub graph | |||||
| for (NodePtr &node : exe_graph->GetDirectNode()) { | |||||
| // 1.handle pld | |||||
| if (node->GetType() == PLACEHOLDER) { | |||||
| if (HandlePld(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), | |||||
| exe_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| // 2.handle end | |||||
| if (node->GetType() == END) { | |||||
| if (HandleEnd(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), | |||||
| exe_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| graphStatus ret = exe_graph->TopologicalSorting(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", exe_graph->GetName().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| // dump the subgraphs modified above | |||||
| if (help_info.user_path.empty()) { | |||||
| DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); | |||||
| } else { | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, help_info.user_path); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void TuningUtils::DumpGraphToPath(ComputeGraphPtr &exe_graph, int64_t index, bool is_tuning_graph, std::string path) { | |||||
| if (path.empty()) { | |||||
| path = "./"; | |||||
| } | |||||
| const std::string &prefix = is_tuning_graph ? tuning_subgraph_prefix : non_tuning_subgraph_prefix; | |||||
| GraphUtils::DumpGEGraph(exe_graph, "", true, path + prefix + std::to_string(index) + ".txt"); | |||||
| } | |||||
| graphStatus TuningUtils::CreateDataNode(NodePtr &node, NodePtr &data_node) { | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| auto data_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), DATA); | |||||
| GE_CHECK_NOTNULL(data_op_desc); | |||||
| auto pld_op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(pld_op_desc); | |||||
| auto output_desc = pld_op_desc->GetOutputDesc(0); // only one output for pld and data | |||||
| // the data node's input desc and output desc are set to the same tensor desc | |||||
| if (data_op_desc->AddInputDesc(output_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (data_op_desc->AddOutputDesc(output_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:data node %s AddOutputDesc failed", data_op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| data_node = graph->AddNode(data_op_desc); | |||||
| GE_CHECK_NOTNULL(data_node); | |||||
| if (data_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::AddAttrToDataNodeForMergeGraph(const NodePtr &pld, NodePtr &data_node) { | |||||
| auto op_desc = data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| auto pld_desc = pld->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(pld_desc); | |||||
| // inherit | |||||
| // a. set `end's input node type` as attr | |||||
| std::string parent_op_type; | |||||
| if (!AttrUtils::GetStr(pld_desc, "parentOpType", parent_op_type)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get parentOpType failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetStr(op_desc, "parentOpType", parent_op_type); | |||||
| // b. set `end's input node name` as attr | |||||
| std::string parent_op_name; | |||||
| if (!AttrUtils::GetStr(pld_desc, parent_node_name_attr, parent_op_name)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get _parentNodeName failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetStr(op_desc, parent_node_name_attr, parent_op_name); | |||||
| // c. set `end's input node's out anchor index` as attr | |||||
| int parent_node_anchor_index; | |||||
| if (!AttrUtils::GetInt(pld_desc, "anchorIndex", parent_node_anchor_index)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get anchorIndex failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index); | |||||
| GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), | |||||
| data_node->GetName().c_str(), data_node->GetType().c_str()); | |||||
| // d. set `end node name` as attr | |||||
| std::string peer_end_name; | |||||
| if (!AttrUtils::GetStr(pld_desc, peer_node_name_attr, peer_end_name)) { | |||||
| GELOGE(FAILED, "TUU:pld %s get _peerNodeName failed", pld_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void)AttrUtils::SetStr(op_desc, peer_node_name_attr, peer_end_name); | |||||
| GELOGD("TUU:from node %s(%s) to add attr to node %s(%s) success", pld->GetName().c_str(), pld->GetType().c_str(), | |||||
| data_node->GetName().c_str(), data_node->GetType().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::ChangePld2Data(NodePtr &node, NodePtr &data_node) { | |||||
| auto type_pld = node->GetType(); | |||||
| auto type_data = data_node->GetType(); | |||||
| if (type_pld != PLACEHOLDER || type_data != DATA) { | |||||
| GELOGE(FAILED, "TUU:Failed to change node %s from type %s to type %s", node->GetName().c_str(), type_pld.c_str(), | |||||
| type_data.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
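| // map each pld output anchor to the data node's output anchor with the same index | |||||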
| std::vector<int> output_map(node->GetAllOutDataAnchorsSize()); | |||||
| for (size_t i = 0; i < node->GetAllOutDataAnchorsSize(); ++i) { | |||||
| output_map[i] = static_cast<int>(i); | |||||
| } | |||||
| auto ret = GraphUtils::ReplaceNodeAnchors(data_node, node, {}, output_map); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to replace node %s by node %s error node %u", node->GetName().c_str(), | |||||
| data_node->GetName().c_str(), ret); | |||||
| return FAILED; | |||||
| } | |||||
| NodeUtils::UnlinkAll(*node); | |||||
| ret = GraphUtils::RemoveNodeWithoutRelink(graph, node); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:Remove node %s(%s) by the ChangePld2Data process, replace it with node %s(%s)", node->GetName().c_str(), | |||||
| node->GetType().c_str(), data_node->GetName().c_str(), data_node->GetType().c_str()); | |||||
| return ret; | |||||
| } | |||||
| graphStatus TuningUtils::HandlePld(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| NodePtr data_node = nullptr; | |||||
| // 1. create data node | |||||
| if (CreateDataNode(node, data_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 2. add necessary info to data_node for recovering the whole graph | |||||
| if (AddAttrToDataNodeForMergeGraph(node, data_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. replace pld node by data node created before | |||||
| if (ChangePld2Data(node, data_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:pld[%s] handle success", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::CreateNetOutput(NodePtr &node, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| auto search = create_output_.find(graph); | |||||
| if (search == create_output_.end()) { | |||||
| GELOGE(FAILED, "TUU:node %s's owner sub graph %s not exist in create_output map", node->GetName().c_str(), | |||||
| graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (search->second != nullptr) { | |||||
| out_node = search->second; | |||||
| GELOGD("TUU:sub graph %s has created output node, just return", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto out_op_desc = ComGraphMakeShared<OpDesc>(node->GetName(), NETOUTPUT); | |||||
| GE_CHECK_NOTNULL(out_op_desc); | |||||
| out_node = graph->AddNode(out_op_desc); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| if (out_node->SetOwnerComputeGraph(graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:SetOwnerComputeGraph failed"); | |||||
| return FAILED; | |||||
| } | |||||
| create_output_[graph] = out_node; | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::AddAttrToNetOutputForMergeGraph(const NodePtr &end, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(end); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| auto op_desc = out_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::vector<std::string> alias_names = {}; | |||||
| (void)AttrUtils::GetListStr(op_desc, alias_name_attr, alias_names); | |||||
| alias_names.push_back(end->GetName()); | |||||
| (void)AttrUtils::SetListStr(op_desc, alias_name_attr, alias_names); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(end_node); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| // check whether the end node's input comes from a control anchor or a data anchor | |||||
| AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) | |||||
| ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor()) | |||||
| : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0)); | |||||
| auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 | |||||
| if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | |||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(end_in_anchor.get()).c_str(), end_in_anchor->GetIdx(), end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // add edge between `end in node` and `out_node` | |||||
| if (src_anchor->IsTypeOf<OutDataAnchor>()) { | |||||
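| // the NetOutput node is created without input anchors, so append a new in data anchor for this edge | |||||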
| std::shared_ptr<InDataAnchor> anchor = | |||||
| ComGraphMakeShared<InDataAnchor>(out_node, out_node->GetAllInDataAnchors().size()); | |||||
| GE_CHECK_NOTNULL(anchor); | |||||
| out_node->in_data_anchors_.push_back(anchor); | |||||
| if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | |||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto end_op_desc = end_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(end_op_desc); | |||||
| auto out_node_op_desc = out_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(out_node_op_desc); | |||||
| // end node always has one input | |||||
| if (out_node_op_desc->AddInputDesc(end_op_desc->GetInputDesc(0)) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:node %s add input desc failed.", out_node_op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else if (src_anchor->IsTypeOf<OutControlAnchor>()) { | |||||
| auto anchor = out_node->GetInControlAnchor(); | |||||
| if (GraphUtils::AddEdge(src_anchor, anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | |||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(anchor.get()).c_str(), anchor->GetIdx(), end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } else { | |||||
| GELOGE(FAILED, "TUU: node_name:%s, graph_name:%s handled failed", end_node->GetName().c_str(), | |||||
| end_node->GetOwnerComputeGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::ChangeEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) { | |||||
| GE_CHECK_NOTNULL(end_node); | |||||
| GE_CHECK_NOTNULL(out_node); | |||||
| auto type_end = end_node->GetType(); | |||||
| auto type_out = out_node->GetType(); | |||||
| if (type_end != END || type_out != NETOUTPUT) { | |||||
| GELOGE(FAILED, "TUU:Failed to change end_node %s from type %s to type %s", end_node->GetName().c_str(), | |||||
| type_end.c_str(), type_out.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // link all `end nodes' in nodes` to this out_node | |||||
| if (LinkEnd2NetOutput(end_node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:end_node [%s] LinkEnd2NetOutput failed.", end_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // remove `end node` | |||||
| NodeUtils::UnlinkAll(*end_node); | |||||
| auto graph = end_node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| if (GraphUtils::RemoveNodeWithoutRelink(graph, end_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:end node [%s] RemoveNodeWithoutRelink failed.", end_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::HandleEnd(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto graph = node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| NodePtr out_node = nullptr; | |||||
| // 1. create net_output node; only one NetOutput node is added to each subgraph | |||||
| if (CreateNetOutput(node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 2. add necessary info to out_node for recovering the whole graph | |||||
| if (AddAttrToNetOutputForMergeGraph(node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. replace all end nodes by one output node created before | |||||
| if (ChangeEnd2NetOutput(node, out_node) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to handle node %s from graph %s", node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:end[%s] handle success", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| // part 2 | |||||
| graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) { | |||||
| // 1. get all subgraph object | |||||
| std::vector<ComputeGraphPtr> graphs; | |||||
| // options format like {index:"subgraph_path"} | |||||
| for (const auto &pair : options) { | |||||
| ComputeGraphPtr compute_graph = ComGraphMakeShared<ComputeGraph>(std::to_string(pair.first)); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| if (!ge::GraphUtils::LoadGEGraph(pair.second.c_str(), *compute_graph)) { | |||||
| GELOGE(FAILED, "TUU:load graph from file failed"); | |||||
| return FAILED; | |||||
| } | |||||
| graphs.push_back(compute_graph); | |||||
| } | |||||
| // 2. merge graph | |||||
| ComputeGraphPtr merged_graph = ComGraphMakeShared<ComputeGraph>("whole_graph_after_tune"); | |||||
| GE_CHECK_NOTNULL(merged_graph); | |||||
| if (MergeAllSubGraph(graphs, merged_graph) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:MergeGraph failed"); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. set parent graph | |||||
| for (const auto &node : merged_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->SetOwnerComputeGraph(merged_graph) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:node %s set owner graph failed", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| graph = GraphUtils::CreateGraphFromComputeGraph(merged_graph); | |||||
| return SUCCESS; | |||||
| } | |||||
| // +----------------------------------+ | |||||
| // | const const | | |||||
| // | \ / | | |||||
| // | netoutput(end,end) | | |||||
| // +----------------------------------+ | |||||
| // + | |||||
| // +----------------------------------+ | |||||
| // | data(pld) data(pld) | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | netoutput(end) | | |||||
| // +----------------------------------+ | |||||
| // + | |||||
| // +----------------------------------+ | |||||
| // | data(pld) | | |||||
| // | / | | |||||
| // | netoutput | | |||||
| // +----------------------------------+ | |||||
| // | | |||||
| // | | |||||
| // V | |||||
| // +----------------------------------+ | |||||
| // | const const | | |||||
| // | \ / | | |||||
| // | relu relu | | |||||
| // | \ / | | |||||
| // | \ / | | |||||
| // | add | | |||||
| // | | | | |||||
| // | netoutput | | |||||
| // +----------------------------------+ | |||||
| graphStatus TuningUtils::MergeAllSubGraph(std::vector<ComputeGraphPtr> &subgraphs, | |||||
| ComputeGraphPtr &output_merged_compute_graph) { | |||||
| GE_CHECK_NOTNULL(output_merged_compute_graph); | |||||
| // 1. handle all subgraphs | |||||
| for (auto &subgraph : subgraphs) { | |||||
| Status ret_status = MergeSubGraph(subgraph); | |||||
| if (ret_status != SUCCESS) { | |||||
| GELOGE(ret_status, "TUU:subgraph %s merge failed", subgraph->GetName().c_str()); | |||||
| return ret_status; | |||||
| } | |||||
| } | |||||
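| // add every node collected from the handled subgraphs into the merged graph | |||||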
| for (const auto &node : merged_graph_nodes_) { | |||||
| (void)output_merged_compute_graph->AddNode(node); | |||||
| GELOGD("TUU:graph %s add node %s success", output_merged_compute_graph->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| // 2. remove data and output node added by us | |||||
| if (RemoveDataNetoutputEdge(output_merged_compute_graph) != SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to merge graph %s", output_merged_compute_graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| graphStatus ret = output_merged_compute_graph->TopologicalSorting(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Graph[%s] topological sort failed, ret:%d.", output_merged_compute_graph->GetName().c_str(), ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("TUU:Print-%s", PrintCheckLog().c_str()); | |||||
| GELOGI("TUU:output_merged_compute_graph %s success", output_merged_compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::MergeSubGraph(ComputeGraphPtr &subgraph) { | |||||
| for (auto &node : subgraph->GetDirectNode()) { | |||||
| if (kPartitionOpTypes.count(node->GetType()) > 0) { | |||||
| GELOGE(FAILED, "TUU:subgraph passed in should not contain nodes of end or pld type"); | |||||
| return FAILED; | |||||
| } | |||||
| // handle data converted from pld node | |||||
| if (node->GetType() == DATA) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::string peer_out_name; | |||||
| bool has_valid_str = (AttrUtils::GetStr(op_desc, peer_node_name_attr, peer_out_name)) && (!peer_out_name.empty()); | |||||
| if (has_valid_str) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| data_2_netoutput_.emplace(op_desc->GetName(), peer_out_name); | |||||
| data_node_2_netoutput_.emplace(node, peer_out_name); | |||||
| continue; | |||||
| } | |||||
| } | |||||
| // handle netoutput converted from end node | |||||
| if (node->GetType() == NETOUTPUT) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::vector<string> out_alias_name; | |||||
| bool has_valid_str = | |||||
| (AttrUtils::GetListStr(op_desc, alias_name_attr, out_alias_name)) && (!out_alias_name.empty()); | |||||
| if (has_valid_str) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| netoutput_nodes_.insert(node); | |||||
| } | |||||
| } | |||||
| { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| merged_graph_nodes_.emplace(node); | |||||
| } | |||||
| GELOGD("TUU:subgraph %s add node %s success", subgraph->GetName().c_str(), node->GetName().c_str()); | |||||
| } | |||||
| GELOGI("TUU:merge subgraph %s success", subgraph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::RemoveDataNetoutputEdge(ComputeGraphPtr &graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| // 1. traverse | |||||
| for (auto &pair : data_node_2_netoutput_) { | |||||
| auto data_node = pair.first; | |||||
| GE_CHECK_NOTNULL(data_node); | |||||
| auto netoutput_name = pair.second; | |||||
| auto netoutput_node = graph->FindNode(netoutput_name); | |||||
| GE_CHECK_NOTNULL(netoutput_node); | |||||
| data_node_2_netoutput_node_.emplace(data_node, netoutput_node); | |||||
| // 2. get `data out anchor` and `net output in anchor` and `net output in node's out anchor` | |||||
| AnchorPtr data_out_anchor = (data_node->GetOutDataAnchor(0)->GetFirstPeerAnchor() == nullptr) | |||||
| ? Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutControlAnchor()) | |||||
| : Anchor::DynamicAnchorCast<Anchor>(data_node->GetOutDataAnchor(0)); | |||||
| AnchorPtr net_output_in_anchor = nullptr; | |||||
| AnchorPtr src_out_anchor = nullptr; | |||||
| if (GetInAndOutAnchorPair(data_node, netoutput_node, net_output_in_anchor, src_out_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:get out node:%s 's in anchor related with data node:%s failed", | |||||
| netoutput_node->GetName().c_str(), data_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. relink | |||||
| if (GraphUtils::RemoveEdge(src_out_anchor, net_output_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", | |||||
| GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(net_output_in_anchor.get()).c_str(), net_output_in_anchor->GetIdx(), | |||||
| data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHECK_NOTNULL(data_out_anchor); | |||||
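| // relink every consumer of the data node directly to the source anchor recorded for the original end node | |||||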
| for (const auto &peer_in_anchor : data_out_anchor->GetPeerAnchors()) { | |||||
| if (GraphUtils::RemoveEdge(data_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:remove edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", | |||||
| GetNodeNameByAnchor(data_out_anchor.get()).c_str(), data_out_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), | |||||
| data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (GraphUtils::AddEdge(src_out_anchor, peer_in_anchor) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:add edge from %s(%d) to %s(%d) failed. node_name:(data:%s;netoutput:%s), graph_name:%s", | |||||
| GetNodeNameByAnchor(src_out_anchor.get()).c_str(), src_out_anchor->GetIdx(), | |||||
| GetNodeNameByAnchor(peer_in_anchor.get()).c_str(), peer_in_anchor->GetIdx(), | |||||
| data_node->GetName().c_str(), netoutput_node->GetName().c_str(), graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| // 4. remove out nodes added by us | |||||
| for (auto &node : netoutput_nodes_) { | |||||
| NodeUtils::UnlinkAll(*node); | |||||
| if (GraphUtils::RemoveNodeWithoutRelink(graph, node) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "TUU:Failed to remove node %s from graph", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGD("TUU:Remove node %s by the RemoveDataNetoutputEdge process success", node->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_node, AnchorPtr &dest_in_anchor, | |||||
| AnchorPtr &src_out_anchor) { | |||||
| // 1. get `data parent node name`, i.e. `netoutput input node name` | |||||
| std::string netoutput_input_name; | |||||
| auto op_desc = data_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (!AttrUtils::GetStr(op_desc, parent_node_name_attr, netoutput_input_name)) { | |||||
| GELOGE(FAILED, "TUU:Failed to get parent node attr from node %s", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 2. find index | |||||
| int parent_node_anchor_index; | |||||
| if (!AttrUtils::GetInt(op_desc, parent_node_anchor_index_attr, parent_node_anchor_index)) { | |||||
| GELOGE(FAILED, "TUU:Failed to get parent node anchor index attr from node %s", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| // 3. find the in data or ctrl anchor using the results of steps 1 & 2 | |||||
| for (auto &in_anchor : out_node->GetAllInAnchors()) { | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| for (auto &src_anchor : in_anchor->GetPeerAnchors()) { // get all peer anchors for ctrl | |||||
| GE_CHECK_NOTNULL(src_anchor); | |||||
| auto src_node = src_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) { | |||||
| dest_in_anchor = in_anchor; | |||||
| src_out_anchor = src_anchor; | |||||
| GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", | |||||
| out_node->GetName().c_str(), dest_in_anchor->GetIdx(), netoutput_input_name.c_str(), | |||||
| parent_node_anchor_index, data_node->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_CHECK_NOTNULL(dest_in_anchor); | |||||
| GE_CHECK_NOTNULL(src_out_anchor); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -31,6 +31,7 @@ file(GLOB PROTO_HEADER_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "../proto/ge_ir.proto" | "../proto/ge_ir.proto" | ||||
| "../proto/fwk_adapter.proto" | "../proto/fwk_adapter.proto" | ||||
| "../proto/op_mapping_info.proto" | "../proto/op_mapping_info.proto" | ||||
| "../proto/dump_task.proto" | |||||
| ) | ) | ||||
| ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ge_protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | ||||
| ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | ge_protobuf_generate(ge PROTO_CLIENT_SRCS PROTO_CLIENT_HDRS ${PROTO_CLIENT_LIST}) | ||||
| @@ -39,6 +40,7 @@ ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST} | |||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | include_directories(${CMAKE_CURRENT_LIST_DIR}) | ||||
| include_directories(${GE_SOURCE_DIR}) | include_directories(${GE_SOURCE_DIR}) | ||||
| include_directories(${GE_SOURCE_DIR}/src) | include_directories(${GE_SOURCE_DIR}/src) | ||||
| include_directories(${GE_SOURCE_DIR}/src/ge/analyzer) | |||||
| include_directories(${GE_SOURCE_DIR}/inc) | include_directories(${GE_SOURCE_DIR}/inc) | ||||
| include_directories(${GE_SOURCE_DIR}/inc/common/util) | include_directories(${GE_SOURCE_DIR}/inc/common/util) | ||||
| include_directories(${GE_SOURCE_DIR}/inc/external) | include_directories(${GE_SOURCE_DIR}/inc/external) | ||||
| @@ -55,6 +57,7 @@ include_directories(${CMAKE_BINARY_DIR}/proto/ge) | |||||
| ######### libge_runner.so ############# | ######### libge_runner.so ############# | ||||
| # need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
| file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | ||||
| "analyzer/analyzer.cc" | |||||
| "client/ge_api.cc" | "client/ge_api.cc" | ||||
| "common/dump/dump_manager.cc" | "common/dump/dump_manager.cc" | ||||
| "common/dump/dump_properties.cc" | "common/dump/dump_properties.cc" | ||||
| @@ -105,12 +108,12 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/manager/graph_caching_allocator.cc" | "graph/manager/graph_caching_allocator.cc" | ||||
| "graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| "graph/manager/rdma_pool_allocator.cc" | |||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| "graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
| "graph/manager/util/hcom_util.cc" | "graph/manager/util/hcom_util.cc" | ||||
| "graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
| "graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
| "graph/manager/model_manager/event_manager.cc" | |||||
| "graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
| "graph/manager/util/hcom_util.cc" | "graph/manager/util/hcom_util.cc" | ||||
| "graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
| @@ -228,6 +231,7 @@ target_link_libraries(ge_runner | |||||
| ######### libge_compiler.so ############# | ######### libge_compiler.so ############# | ||||
| # need to remove dependencies on pb files later | # need to remove dependencies on pb files later | ||||
| file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | ||||
| "analyzer/analyzer.cc" | |||||
| "common/dump/dump_properties.cc" | "common/dump/dump_properties.cc" | ||||
| "common/dump/dump_manager.cc" | "common/dump/dump_manager.cc" | ||||
| "common/dump/dump_op.cc" | "common/dump/dump_op.cc" | ||||
| @@ -276,6 +280,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| "graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| "graph/manager/rdma_pool_allocator.cc" | |||||
| "graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
| "graph/manager/util/rt_context_util.cc" | "graph/manager/util/rt_context_util.cc" | ||||
| "graph/manager/util/variable_accelerate_ctrl.cc" | "graph/manager/util/variable_accelerate_ctrl.cc" | ||||
| @@ -0,0 +1,304 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "analyzer.h" | |||||
| #include <cstdlib> | |||||
| #include <cstdio> | |||||
| #include <iostream> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| namespace ge { | |||||
| using json = nlohmann::json; | |||||
| using Status = ge::Status; | |||||
| using ComputeGraph = ge::ComputeGraph; | |||||
| using namespace analyzer; | |||||
| namespace { | |||||
| constexpr int kFileAuthority = 0640; | |||||
| constexpr int kJsonDumpLevel = 4; | |||||
| const std::string kFilePath = "./"; | |||||
| const std::string kAnalyzeFile = "ge_check_op.json"; | |||||
| const std::string kUnknownShape = "unknownshape"; | |||||
| const std::string kUnsupport = "unsupport"; | |||||
| const std::string kSessionId = "session_id"; | |||||
| const std::string kGraphId = "graph_id"; | |||||
| const std::string kOpInfo = "op_info"; | |||||
| const std::string kErrorType = "error_type"; | |||||
| const std::string kOpName = "name"; | |||||
| const std::string kOpType = "type"; | |||||
| const std::string kReason = "reason"; | |||||
| const std::string kInput = "input"; | |||||
| const std::string kOutput = "output"; | |||||
| const std::string kShape = "shape"; | |||||
| const std::string kDataType = "data_type"; | |||||
| const std::string kLayout = "layout"; | |||||
| const std::string kResult = "result"; | |||||
| const std::string kOp = "op"; | |||||
| std::map<analyzer::AnalyzeType, std::string> errors_map{{PARSER, "paser_error"}, | |||||
| {INFER_SHAPE, "infer_shape_error"}, | |||||
| {CHECKSUPPORT, "check_support_error"}, | |||||
| {GRAPH_OPTIMIZE, "graph_optimize_error"}, | |||||
| {GRAPH_PARTION, "graph_partion_error"}, | |||||
| {GRAPH_BUILDER, "graph_builder_error"}}; | |||||
| } // namespace | |||||
| Analyzer *Analyzer::GetInstance() { | |||||
| static Analyzer instance; | |||||
| return &instance; | |||||
| } | |||||
| Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| GELOGD("Start to build map. SessionId:%lu GraphId:%lu", session_id, graph_id); | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| auto p = new (std::nothrow) GraphInfo(); | |||||
| GE_CHECK_NOTNULL(p); | |||||
| std::shared_ptr<GraphInfo> graph_info(p); | |||||
| std::map<uint64_t, std::shared_ptr<GraphInfo>> graph_map; | |||||
| graph_map[graph_id] = graph_info; | |||||
| graph_info->session_id = session_id; | |||||
| graph_info->graph_id = graph_id; | |||||
| graph_infos_.insert({session_id, graph_map}); | |||||
| } else { | |||||
| auto iter1 = (iter->second).find(graph_id); | |||||
| if (iter1 == (iter->second).end()) { | |||||
| auto p = new (std::nothrow) GraphInfo(); | |||||
| GE_CHECK_NOTNULL(p); | |||||
| std::shared_ptr<GraphInfo> graph_info(p); | |||||
| graph_info->session_id = session_id; | |||||
| graph_info->graph_id = graph_id; | |||||
| (iter->second).insert({graph_id, graph_info}); | |||||
| } else { | |||||
| GELOGI("session_id:%lu graph_id:%lu already existed json object", session_id, graph_id); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| ge::Status Analyzer::Initialize() { | |||||
| ClearHistoryFile(); | |||||
| return CreateAnalyzerFile(); | |||||
| } | |||||
| void Analyzer::Finalize() { | |||||
| GELOGD("Analyzer start to finalize!"); | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| for (auto &session_resource : graph_infos_) { | |||||
| session_resource.second.clear(); | |||||
| } | |||||
| graph_infos_.clear(); | |||||
| std::lock_guard<std::mutex> lk(file_mutex_); | |||||
| if (json_file_.is_open()) { | |||||
| json_file_.close(); | |||||
| } | |||||
| } | |||||
| void Analyzer::DestroySessionJsonObject(uint64_t session_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| GELOGW("can not find the stored object by session_id[%lu].Do nothing", session_id); | |||||
| } else { | |||||
| graph_infos_.erase(iter); | |||||
| } | |||||
| } | |||||
| void Analyzer::DestroyGraphJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| GELOGW("can not find the stored object by session_id[%lu].Do nothing", session_id); | |||||
| } else { | |||||
| auto iter1 = (iter->second).find(graph_id); | |||||
| if (iter1 == (iter->second).end()) { | |||||
| GELOGW("can not find the graph json object by session_id[%lu] and graph_id[%lu].Do nothing", session_id, | |||||
| graph_id); | |||||
| } | |||||
| (iter->second).erase(iter1); | |||||
| } | |||||
| } | |||||
| std::shared_ptr<GraphInfo> Analyzer::GetJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto iter = graph_infos_.find(session_id); | |||||
| if (iter == graph_infos_.end()) { | |||||
| GELOGE(PARAM_INVALID, "session_id:%lu does not exist!", session_id); | |||||
| return nullptr; | |||||
| } else { | |||||
| auto iter1 = (iter->second).find(graph_id); | |||||
| if (iter1 == (iter->second).end()) { | |||||
| GELOGE(PARAM_INVALID, "graph_id:%lu does not exist!", graph_id); | |||||
| return nullptr; | |||||
| } | |||||
| GELOGI("GetJsonObject Success!session_id:%lu graph_id:%lu", session_id, graph_id); | |||||
| return iter1->second; | |||||
| } | |||||
| } | |||||
| void Analyzer::ClearHistoryFile() { | |||||
| GELOGD("Analyzer start to clear history file!"); | |||||
| // Remove history files | |||||
| int res = remove(json_file_name_.c_str()); | |||||
| GELOGD("remove file %s, result:%d", json_file_name_.c_str(), res); | |||||
| } | |||||
| ge::Status Analyzer::CreateAnalyzerFile() { | |||||
| GELOGD("start to create analyzer file!"); | |||||
| // Create the analyzer json file; truncate it if it already exists. | |||||
| string real_path = RealPath(kFilePath.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(FAILED, "File path is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| string file = real_path + "/" + kAnalyzeFile; | |||||
| GELOGD("Created analyzer file:[%s]", file.c_str()); | |||||
| int fd = open(file.c_str(), O_WRONLY | O_CREAT | O_TRUNC, kFileAuthority); | |||||
| if (fd < 0) { | |||||
| GELOGE(INTERNAL_ERROR, "Fail to open the file: %s.", file.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (close(fd) != 0) { | |||||
| GELOGE(INTERNAL_ERROR, "Fail to close the file: %s.", file.c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| json_file_name_ = file; | |||||
| GELOGD("success to create analyzer file[%s]!", json_file_name_.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| ge::Status Analyzer::SaveAnalyzerDataToFile() { | |||||
| GELOGD("start to save analyze file!"); | |||||
| std::lock_guard<std::mutex> lg(file_mutex_); | |||||
| json_file_.open(json_file_name_, std::ios::out); | |||||
| if (!json_file_.is_open()) { | |||||
| GELOGE(FAILED, "analyzer file does not exist[%s]", json_file_name_.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| std::lock_guard<std::recursive_mutex> lk(mutex_); | |||||
| for (auto &ele : graph_infos_) { | |||||
| for (auto &ele2 : ele.second) { | |||||
| json jsn; | |||||
| GraphInfoToJson(jsn, *(ele2.second)); | |||||
| json_file_ << jsn.dump(kJsonDumpLevel) << std::endl; | |||||
| } | |||||
| } | |||||
| json_file_.close(); | |||||
| return SUCCESS; | |||||
| } | |||||
| ge::Status Analyzer::DoAnalyze(DataInfo &data_info) { | |||||
| GELOGD("start to do analyzer!"); | |||||
| auto pnode = data_info.node_ptr; | |||||
| GE_CHECK_NOTNULL(pnode); | |||||
| auto desc = pnode->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(desc); | |||||
| // buff analyze data | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | |||||
| auto graph_info = GetJsonObject(data_info.session_id, data_info.graph_id); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| auto status = SaveOpInfo(desc, data_info, graph_info); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "save op info failed!"); | |||||
| return FAILED; | |||||
| } | |||||
| // save data to file | |||||
| return SaveAnalyzerDataToFile(); | |||||
| } | |||||
| ge::Status Analyzer::SaveOpInfo(ge::OpDescPtr desc, DataInfo &data_info, | |||||
| std::shared_ptr<analyzer::GraphInfo> graph_info) { | |||||
| auto iter = errors_map.find(data_info.analyze_type); | |||||
| if (iter == errors_map.end()) { | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| OpInfo op_info; | |||||
| op_info.error_type = iter->second; | |||||
| op_info.op_name = desc->GetName(); | |||||
| op_info.op_type = desc->GetType(); | |||||
| op_info.reason = data_info.reason; | |||||
| for (const auto &ptr : desc->GetAllInputsDescPtr()) { | |||||
| TensorInfo tensor_info; | |||||
| tensor_info.shape = ptr->GetShape().GetDims(); | |||||
| tensor_info.d_type = ge::TypeUtils::DataTypeToSerialString(ptr->GetDataType()); | |||||
| tensor_info.layout = ge::TypeUtils::FormatToSerialString(ptr->GetFormat()); | |||||
| op_info.input_info.emplace_back(tensor_info); | |||||
| } | |||||
| for (const auto &ptr : desc->GetAllOutputsDescPtr()) { | |||||
| TensorInfo tensor_info; | |||||
| tensor_info.shape = ptr->GetShape().GetDims(); | |||||
| tensor_info.d_type = ge::TypeUtils::DataTypeToSerialString(ptr->GetDataType()); | |||||
| tensor_info.layout = ge::TypeUtils::FormatToSerialString(ptr->GetFormat()); | |||||
| op_info.output_info.emplace_back(tensor_info); | |||||
| } | |||||
| graph_info->op_info.emplace_back(op_info); | |||||
| return SUCCESS; | |||||
| } | |||||
| void Analyzer::TensorInfoToJson(json &j, const TensorInfo &tensor_info) { | |||||
| j[kShape] = tensor_info.shape; | |||||
| j[kDataType] = tensor_info.d_type; | |||||
| j[kLayout] = tensor_info.layout; | |||||
| } | |||||
| void Analyzer::OpInfoToJson(json &j, const OpInfo &op_info) { | |||||
| j[kErrorType] = op_info.error_type; | |||||
| j[kOpName] = op_info.op_name; | |||||
| j[kOpType] = op_info.op_type; | |||||
| j[kReason] = op_info.reason; | |||||
| for (size_t i = 0; i < op_info.input_info.size(); i++) { | |||||
| json json_tensor_info; | |||||
| TensorInfoToJson(json_tensor_info, op_info.input_info.at(i)); | |||||
| j[kInput + std::to_string(i)] = json_tensor_info; | |||||
| } | |||||
| for (size_t i = 0; i < op_info.output_info.size(); i++) { | |||||
| json json_tensor_info; | |||||
| TensorInfoToJson(json_tensor_info, op_info.output_info.at(i)); | |||||
| j[kOutput + std::to_string(i)] = json_tensor_info; | |||||
| } | |||||
| } | |||||
| void Analyzer::GraphInfoToJson(json &j, const GraphInfo &graph_info) { | |||||
| GELOGD("start to buff graph info!"); | |||||
| j[kSessionId] = graph_info.session_id; | |||||
| j[kGraphId] = graph_info.graph_id; | |||||
| std::vector<json> json_op_infos; | |||||
| for (size_t i = 0; i < graph_info.op_info.size(); i++) { | |||||
| json json_op_info; | |||||
| OpInfoToJson(json_op_info, graph_info.op_info.at(i)); | |||||
| json_op_infos.emplace_back(json_op_info); | |||||
| } | |||||
| j[kOp] = json_op_infos; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,186 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef DOMI_ANALYZER_ANANLYZER_H_ | |||||
| #define DOMI_ANALYZER_ANANLYZER_H_ | |||||
| #include "nlohmann/json.hpp" | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <mutex> | |||||
| #include <memory> | |||||
| #include <fstream> | |||||
| #include "external/ge/ge_api_types.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| namespace analyzer { | |||||
| enum AnalyzeType { | |||||
| PARSER = 0, | |||||
| INFER_SHAPE = 1, | |||||
| CHECKSUPPORT = 2, | |||||
| GRAPH_OPTIMIZE = 3, | |||||
| GRAPH_PARTION = 4, | |||||
| GRAPH_BUILDER = 5, | |||||
| }; | |||||
| struct TensorInfo { | |||||
| vector<int64_t> shape; | |||||
| string d_type; | |||||
| string layout; | |||||
| }; | |||||
| struct OpInfo { | |||||
| string error_type; | |||||
| string op_name; | |||||
| string op_type; | |||||
| std::vector<TensorInfo> input_info; | |||||
| std::vector<TensorInfo> output_info; | |||||
| string reason; | |||||
| }; | |||||
| struct GraphInfo { | |||||
| uint64_t session_id = 0; | |||||
| uint64_t graph_id = 0; | |||||
| std::vector<OpInfo> op_info; | |||||
| }; | |||||
| struct DataInfo { | |||||
| DataInfo() = default; | |||||
| ~DataInfo() = default; | |||||
| DataInfo(uint64_t sess, uint64_t graph, AnalyzeType type, ge::NodePtr node, std::string error_info) { | |||||
| session_id = sess; | |||||
| graph_id = graph; | |||||
| analyze_type = type; | |||||
| node_ptr = node; | |||||
| reason = error_info; | |||||
| } | |||||
| uint64_t session_id; | |||||
| uint64_t graph_id; | |||||
| AnalyzeType analyze_type; | |||||
| ge::NodePtr node_ptr{nullptr}; | |||||
| std::string reason; | |||||
| }; | |||||
| } // namespace analyzer | |||||
| class Analyzer { | |||||
| public: | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: get analyzer instance. | |||||
| * @param [in]: None | |||||
| * @return: Analyzer instance ptr | |||||
| */ | |||||
| static Analyzer *GetInstance(); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: check whether the env var ENABLE_NETWORK_ANALYSIS_DEBUG is set. | |||||
| * When it is set, the adapter's sink GEOP graph is kept even if the build fails. | |||||
| * @param [in]: None | |||||
| * @return: true: env enabled false: env disabled | |||||
| */ | |||||
| bool IsEnableNetAnalyzeDebug() { return std::getenv("ENABLE_NETWORK_ANALYSIS_DEBUG") != nullptr; } | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: build the buffered graph info object by session id and graph id. | |||||
| * @param [in]: session id & graph id | |||||
| * @return: 0: success other: failed | |||||
| */ | |||||
| ge::Status BuildJsonObject(uint64_t session_id, uint64_t graph_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: get the buffered graph info object by session id and graph id. | |||||
| * @param [in]: session id & graph id | |||||
| * @return: nullptr if failed | |||||
| */ | |||||
| std::shared_ptr<analyzer::GraphInfo> GetJsonObject(uint64_t session_id, uint64_t graph_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: analyzer global init method. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| ge::Status Initialize(); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: Deconstruct method. Releases all resources used by the analyzer. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| void Finalize(); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: Deconstruct method. Only releases resources associated with the given session id. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| void DestroySessionJsonObject(uint64_t session_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: Deconstruct method. Only releases resources associated with the given session id and graph id. | |||||
| * @param [in]: None | |||||
| * @return: None | |||||
| */ | |||||
| void DestroyGraphJsonObject(uint64_t session_id, uint64_t graph_id); | |||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: main process method. Buffers analyzed data and writes it to the json file. | |||||
| * @param [in]: DataInfo Object | |||||
| * @return: 0: SUCCESS other: FAILED | |||||
| */ | |||||
| ge::Status DoAnalyze(analyzer::DataInfo &data_info); | |||||
| Analyzer(const Analyzer &) = delete; | |||||
| Analyzer &operator=(const Analyzer &) = delete; | |||||
| Analyzer(Analyzer &&) = delete; | |||||
| Analyzer &operator=(Analyzer &&) = delete; | |||||
| private: | |||||
| void TensorInfoToJson(nlohmann::json &j, const analyzer::TensorInfo &tensor_info); | |||||
| void OpInfoToJson(nlohmann::json &j, const analyzer::OpInfo &op_info); | |||||
| void GraphInfoToJson(nlohmann::json &j, const analyzer::GraphInfo &graph_info); | |||||
| ge::Status SaveAnalyzerDataToFile(); | |||||
| ge::Status SaveOpInfo(ge::OpDescPtr desc, analyzer::DataInfo &data_info, | |||||
| std::shared_ptr<analyzer::GraphInfo> graph_info); | |||||
| void ClearHistoryFile(); | |||||
| ge::Status CreateAnalyzerFile(); | |||||
| explicit Analyzer(){}; | |||||
| ~Analyzer() = default; | |||||
| private: | |||||
| std::map<uint64_t, std::map<uint64_t, std::shared_ptr<analyzer::GraphInfo>>> graph_infos_; | |||||
| std::recursive_mutex mutex_; // protect graph_infos_ | |||||
| std::mutex file_mutex_; // protect json_file_ | |||||
| std::ofstream json_file_; | |||||
| std::string json_file_name_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // DOMI_ANALYZER_ANANLYZER_H_ | |||||
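The header above documents the Analyzer singleton's lifecycle (BuildJsonObject, then DoAnalyze, then Finalize). A minimal usage sketch follows, assuming the caller already holds the session id, graph id and the failing node; the helper name and error text are illustrative and not part of the patch:

```cpp
#include "analyzer.h"

// Hypothetical helper (illustrative only): report an op that failed CheckSupport.
ge::Status ReportUnsupportedOp(uint64_t session_id, uint64_t graph_id, const ge::NodePtr &node) {
  auto *analyzer = ge::Analyzer::GetInstance();
  // Ensure the per-(session, graph) buffer exists before appending op info.
  ge::Status ret = analyzer->BuildJsonObject(session_id, graph_id);
  if (ret != ge::SUCCESS) {
    return ret;
  }
  ge::analyzer::DataInfo data_info(session_id, graph_id, ge::analyzer::CHECKSUPPORT, node,
                                   "op is not supported by the selected engine");
  // DoAnalyze buffers the op info and flushes all buffered graphs to ge_check_op.json.
  return analyzer->DoAnalyze(data_info);
}
```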
| @@ -32,7 +32,6 @@ | |||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "common/ge/tbe_plugin_manager.h" | #include "common/ge/tbe_plugin_manager.h" | ||||
| using domi::GetContext; | |||||
| using domi::OpRegistry; | using domi::OpRegistry; | ||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| @@ -25,6 +25,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "../model/ge_model.cc" | "../model/ge_model.cc" | ||||
| "auth/file_saver.cc" | "auth/file_saver.cc" | ||||
| "context/ctx.cc" | "context/ctx.cc" | ||||
| "cust_aicpu_kernel_store.cc" | |||||
| "debug/memory_dumper.cc" | "debug/memory_dumper.cc" | ||||
| "fmk_error_codes.cc" | "fmk_error_codes.cc" | ||||
| "formats/format_transfers/datatype_transfer.cc" | "formats/format_transfers/datatype_transfer.cc" | ||||
| @@ -52,6 +53,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "ge_format_util.cc" | "ge_format_util.cc" | ||||
| "helper/model_helper.cc" | "helper/model_helper.cc" | ||||
| "helper/om_file_helper.cc" | "helper/om_file_helper.cc" | ||||
| "kernel_store.cc" | |||||
| "math/fp16_math.cc" | "math/fp16_math.cc" | ||||
| "model_parser/base.cc" | "model_parser/base.cc" | ||||
| "model_saver.cc" | "model_saver.cc" | ||||
| @@ -0,0 +1,119 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_BASE64_H_ | |||||
| #define GE_COMMON_BASE64_H_ | |||||
| #include <algorithm> | |||||
| #include <string> | |||||
| #include "debug/ge_log.h" | |||||
| #include "ge_error_codes.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| const char *kBase64Chars = | |||||
| "ABCDEFGHIJKLMNOPQRSTUVWXYZ" | |||||
| "abcdefghijklmnopqrstuvwxyz" | |||||
| "0123456789+/"; | |||||
| const char kEqualSymbol = '='; | |||||
| const size_t kBase64CharsNum = 64; | |||||
| const size_t kThreeByteOneGroup = 3; | |||||
| const size_t kFourByteOneGroup = 4; | |||||
| } // namespace | |||||
| namespace base64 { | |||||
| static inline bool IsBase64Char(const char &c) { return (isalnum(c) || (c == '+') || (c == '/')); } | |||||
| static std::string EncodeToBase64(const std::string &raw_data) { | |||||
| size_t encode_length = raw_data.size() / kThreeByteOneGroup * kFourByteOneGroup; | |||||
| encode_length += raw_data.size() % kThreeByteOneGroup == 0 ? 0 : kFourByteOneGroup; | |||||
| size_t raw_data_index = 0; | |||||
| size_t encode_data_index = 0; | |||||
| std::string encode_data; | |||||
| encode_data.resize(encode_length); | |||||
| for (; raw_data_index + kThreeByteOneGroup <= raw_data.size(); raw_data_index += kThreeByteOneGroup) { | |||||
| auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); | |||||
| auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + 1]); | |||||
| auto char_3 = static_cast<uint8_t>(raw_data[raw_data_index + 2]); | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[((char_2 << 2u) & 0x3c) | (char_3 >> 6u)]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_3 & 0x3f]; | |||||
| } | |||||
| if (raw_data_index < raw_data.size()) { | |||||
| auto tail = raw_data.size() - raw_data_index; | |||||
| auto char_1 = static_cast<uint8_t>(raw_data[raw_data_index]); | |||||
| if (tail == 1) { | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[(char_1 << 4u) & 0x30]; | |||||
| encode_data[encode_data_index++] = kEqualSymbol; | |||||
| encode_data[encode_data_index++] = kEqualSymbol; | |||||
| } else { | |||||
| auto char_2 = static_cast<uint8_t>(raw_data[raw_data_index + 1]); | |||||
| encode_data[encode_data_index++] = kBase64Chars[char_1 >> 2u]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[((char_1 << 4u) & 0x30) | (char_2 >> 4u)]; | |||||
| encode_data[encode_data_index++] = kBase64Chars[(char_2 << 2u) & 0x3c]; | |||||
| encode_data[encode_data_index++] = kEqualSymbol; | |||||
| } | |||||
| } | |||||
| return encode_data; | |||||
| } | |||||
| #pragma GCC diagnostic push | |||||
| #pragma GCC diagnostic ignored "-Wunused-function" | |||||
| static Status DecodeFromBase64(const std::string &base64_data, std::string &decode_data) { | |||||
| if (base64_data.size() % kFourByteOneGroup != 0) { | |||||
| GELOGE(PARAM_INVALID, "base64 data size must can be divided by 4, but given data size is %zu", base64_data.size()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| decode_data.clear(); | |||||
| size_t base64_data_len = base64_data.size(); | |||||
| uint8_t byte_4[kFourByteOneGroup]; | |||||
| auto FindCharInBase64Chars = [&](const char &raw_char) -> uint8_t { | |||||
| auto char_pos = std::find(kBase64Chars, kBase64Chars + kBase64CharsNum, raw_char); | |||||
| return static_cast<uint8_t>(std::distance(kBase64Chars, char_pos)) & 0xff; | |||||
| }; | |||||
| for (std::size_t input_data_index = 0; input_data_index < base64_data_len; input_data_index += 4) { | |||||
| for (size_t i = 0; i < kFourByteOneGroup; ++i) { | |||||
| if (base64_data[input_data_index + i] == kEqualSymbol && input_data_index >= base64_data_len - 4 && i > 1) { | |||||
| byte_4[i] = kBase64CharsNum; | |||||
| } else if (IsBase64Char(base64_data[input_data_index + i])) { | |||||
| byte_4[i] = FindCharInBase64Chars(base64_data[input_data_index + i]); | |||||
| } else { | |||||
| GELOGE(PARAM_INVALID, "given base64 data is illegal"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| decode_data += static_cast<char>((byte_4[0] << 2u) + ((byte_4[1] & 0x30) >> 4u)); | |||||
| if (byte_4[2] >= kBase64CharsNum) { | |||||
| break; | |||||
| } else if (byte_4[3] >= kBase64CharsNum) { | |||||
| decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u)); | |||||
| break; | |||||
| } | |||||
| decode_data += static_cast<char>(((byte_4[1] & 0x0f) << 4u) + ((byte_4[2] & 0x3c) >> 2u)); | |||||
| decode_data += static_cast<char>(((byte_4[2] & 0x03) << 6u) + byte_4[3]); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| #pragma GCC diagnostic pop | |||||
| } // namespace base64 | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_BASE64_H_ | |||||
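A small round-trip sketch for the helpers above (illustrative only; the include path is an assumption). Standard Base64 of "GraphEngine" is "R3JhcGhFbmdpbmU=", which is what EncodeToBase64 should produce:

```cpp
#include <cassert>
#include <string>
#include "common/base64.h"  // assumed include path for the header above

void Base64RoundTripExample() {
  const std::string raw = "GraphEngine";
  std::string encoded = ge::base64::EncodeToBase64(raw);  // expected: "R3JhcGhFbmdpbmU="
  std::string decoded;
  // DecodeFromBase64 rejects input whose length is not a multiple of 4.
  if (ge::base64::DecodeFromBase64(encoded, decoded) == ge::SUCCESS) {
    assert(decoded == raw);
  }
}
```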
| @@ -0,0 +1,37 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/cust_aicpu_kernel_store.h" | |||||
| namespace ge { | |||||
| CustAICPUKernelStore::CustAICPUKernelStore() {} | |||||
| void CustAICPUKernelStore::AddCustAICPUKernel(const CustAICPUKernelPtr &kernel) { AddKernel(kernel); } | |||||
| void CustAICPUKernelStore::LoadCustAICPUKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const { | |||||
| GELOGI("LoadCustAICPUKernelBinToOpDesc in"); | |||||
| if (op_desc != nullptr) { | |||||
| auto kernel_bin = FindKernel(op_desc->GetName()); | |||||
| if (kernel_bin != nullptr) { | |||||
| GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(ge::OP_EXTATTR_CUSTAICPU_KERNEL, kernel_bin), | |||||
| GELOGW("LoadKernelCustAICPUBinToOpDesc: SetExtAttr for kernel_bin failed");) | |||||
| GELOGI("Load cust aicpu kernel:%s, %zu", kernel_bin->GetName().c_str(), kernel_bin->GetBinDataSize()); | |||||
| } | |||||
| } | |||||
| GELOGI("LoadCustAICPUKernelBinToOpDesc success"); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,35 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_CUST_AICPU_KERNEL_STORE_H_ | |||||
| #define GE_COMMON_CUST_AICPU_KERNEL_STORE_H_ | |||||
| #include "common/kernel_store.h" | |||||
| namespace ge { | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY CustAICPUKernelStore : public KernelStore { | |||||
| public: | |||||
| CustAICPUKernelStore(); | |||||
| ~CustAICPUKernelStore() {} | |||||
| void AddCustAICPUKernel(const CustAICPUKernelPtr &kernel); | |||||
| void LoadCustAICPUKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_CUST_AICPU_KERNEL_STORE_H_ | |||||
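A hedged sketch of the intended flow for this store, mirroring how ModelHelper serializes and then reloads kernel partitions; the helper function below is hypothetical:

```cpp
#include <utility>
#include <vector>
#include "common/cust_aicpu_kernel_store.h"

// Illustrative only: register a custom AICPU kernel binary for an op, serialize the
// store as the model would, then re-attach the kernel to the OpDesc.
void CustAicpuKernelFlowExample(const ge::OpDescPtr &op_desc, std::vector<char> kernel_bin) {
  ge::CustAICPUKernelStore store;
  // Kernels are keyed by op name so they can be matched back to the op at load time.
  auto kernel = ge::MakeShared<ge::OpKernelBin>(op_desc->GetName(), std::move(kernel_bin));
  store.AddCustAICPUKernel(kernel);
  if (!store.Build()) {  // serialize all registered kernels into the internal buffer
    return;
  }
  // Data()/DataSize() would be written into the CUST_AICPU_KERNELS model partition;
  // Load() rebuilds the name -> kernel map from that buffer, as ModelHelper does.
  (void)store.Load(store.Data(), store.DataSize());
  // Attaches the kernel to the op as the OP_EXTATTR_CUSTAICPU_KERNEL external attribute.
  store.LoadCustAICPUKernelBinToOpDesc(op_desc);
}
```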
| @@ -157,7 +157,7 @@ int MemoryDumper::OpenFile(const char *filename) { | |||||
| // Using the O_EXCL, if the file already exists,return failed to avoid privilege escalation vulnerability. | // Using the O_EXCL, if the file already exists,return failed to avoid privilege escalation vulnerability. | ||||
| mode_t mode = S_IRUSR | S_IWUSR; | mode_t mode = S_IRUSR | S_IWUSR; | ||||
| int32_t fd = mmOpen2(real_path.c_str(), O_WRONLY | O_CREAT | O_TRUNC, mode); | |||||
| int32_t fd = mmOpen2(real_path.c_str(), O_RDWR | O_CREAT | O_APPEND, mode); | |||||
| if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { | if (fd == EN_ERROR || fd == EN_INVALID_PARAM) { | ||||
| GELOGE(kInvalidFd, "open file failed. errno = %d, %s", fd, strerror(errno)); | GELOGE(kInvalidFd, "open file failed. errno = %d, %s", fd, strerror(errno)); | ||||
| return kInvalidFd; | return kInvalidFd; | ||||
| @@ -44,6 +44,9 @@ enum DataTypeTransMode { | |||||
| kTransferWithDatatypeInt8ToFloat, | kTransferWithDatatypeInt8ToFloat, | ||||
| kTransferWithDatatypeInt8ToInt32, | kTransferWithDatatypeInt8ToInt32, | ||||
| kTransferWithDatatypeInt64ToInt32, | kTransferWithDatatypeInt64ToInt32, | ||||
| kTransferWithDatatypeInt32ToInt64, | |||||
| kTransferWithDatatypeInt32ToDouble, | |||||
| kTransferWithDatatypeDoubleToInt32, | |||||
| }; | }; | ||||
| std::map<std::pair<DataType, DataType>, DataTypeTransMode> trans_mode_map{ | std::map<std::pair<DataType, DataType>, DataTypeTransMode> trans_mode_map{ | ||||
| @@ -59,7 +62,11 @@ std::map<std::pair<DataType, DataType>, DataTypeTransMode> trans_mode_map{ | |||||
| {std::pair<DataType, DataType>(DT_UINT8, DT_INT32), kTransferWithDatatypeUint8ToInt32}, | {std::pair<DataType, DataType>(DT_UINT8, DT_INT32), kTransferWithDatatypeUint8ToInt32}, | ||||
| {std::pair<DataType, DataType>(DT_INT8, DT_FLOAT), kTransferWithDatatypeInt8ToFloat}, | {std::pair<DataType, DataType>(DT_INT8, DT_FLOAT), kTransferWithDatatypeInt8ToFloat}, | ||||
| {std::pair<DataType, DataType>(DT_INT8, DT_INT32), kTransferWithDatatypeInt8ToInt32}, | {std::pair<DataType, DataType>(DT_INT8, DT_INT32), kTransferWithDatatypeInt8ToInt32}, | ||||
| {std::pair<DataType, DataType>(DT_INT64, DT_INT32), kTransferWithDatatypeInt64ToInt32}}; | |||||
| {std::pair<DataType, DataType>(DT_INT64, DT_INT32), kTransferWithDatatypeInt64ToInt32}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_INT64), kTransferWithDatatypeInt32ToInt64}, | |||||
| {std::pair<DataType, DataType>(DT_INT32, DT_DOUBLE), kTransferWithDatatypeInt32ToDouble}, | |||||
| {std::pair<DataType, DataType>(DT_DOUBLE, DT_INT32), kTransferWithDatatypeDoubleToInt32}, | |||||
| }; | |||||
| template <typename SrcT, typename DstT> | template <typename SrcT, typename DstT> | ||||
| Status TransDataSrc2Dst(const CastArgs &args, uint8_t *dst, const size_t data_size) { | Status TransDataSrc2Dst(const CastArgs &args, uint8_t *dst, const size_t data_size) { | ||||
| @@ -82,38 +89,30 @@ Status TransDataSrc2Fp16(const CastArgs &args, uint8_t *dst, const size_t data_s | |||||
| } | } | ||||
| Status CastKernel(const CastArgs &args, uint8_t *dst, const size_t data_size, const DataTypeTransMode trans_mode) { | Status CastKernel(const CastArgs &args, uint8_t *dst, const size_t data_size, const DataTypeTransMode trans_mode) { | ||||
| switch (trans_mode) { | |||||
| case kTransferWithDatatypeFloatToFloat16: | |||||
| return TransDataSrc2Fp16<float>(args, dst, data_size); | |||||
| case kTransferWithDatatypeFloatToInt32: | |||||
| return TransDataSrc2Dst<float, int32_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeFloat16ToFloat: | |||||
| return TransDataSrc2Dst<fp16_t, float>(args, dst, data_size); | |||||
| case kTransferWithDatatypeFloat16ToInt32: | |||||
| return TransDataSrc2Dst<fp16_t, int32_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt32ToFloat: | |||||
| return TransDataSrc2Dst<int32_t, float>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt32ToFloat16: | |||||
| return TransDataSrc2Fp16<int32_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt32ToUint8: | |||||
| return TransDataSrc2Dst<int32_t, uint8_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt32ToInt8: | |||||
| return TransDataSrc2Dst<int32_t, int8_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeUint8ToFloat: | |||||
| return TransDataSrc2Dst<uint8_t, float>(args, dst, data_size); | |||||
| case kTransferWithDatatypeUint8ToInt32: | |||||
| return TransDataSrc2Dst<uint8_t, int32_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt8ToFloat: | |||||
| return TransDataSrc2Dst<int8_t, float>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt8ToInt32: | |||||
| return TransDataSrc2Dst<int8_t, int32_t>(args, dst, data_size); | |||||
| case kTransferWithDatatypeInt64ToInt32: | |||||
| return TransDataSrc2Dst<int64_t, int32_t>(args, dst, data_size); | |||||
| default: | |||||
| GELOGE(PARAM_INVALID, "Trans data type from %s to %s is not supported.", | |||||
| TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | |||||
| return UNSUPPORTED; | |||||
| static std::map<DataTypeTransMode, std::function<Status(const CastArgs &, uint8_t *, const size_t)>> transfer_handle = | |||||
| { | |||||
| {kTransferWithDatatypeFloatToFloat16, TransDataSrc2Fp16<float>}, | |||||
| {kTransferWithDatatypeFloatToInt32, TransDataSrc2Dst<float, int32_t>}, | |||||
| {kTransferWithDatatypeFloat16ToFloat, TransDataSrc2Dst<fp16_t, float>}, | |||||
| {kTransferWithDatatypeFloat16ToInt32, TransDataSrc2Dst<fp16_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt32ToFloat, TransDataSrc2Dst<int32_t, float>}, | |||||
| {kTransferWithDatatypeInt32ToFloat16, TransDataSrc2Fp16<int32_t>}, | |||||
| {kTransferWithDatatypeInt32ToUint8, TransDataSrc2Dst<int32_t, uint8_t>}, | |||||
| {kTransferWithDatatypeInt32ToInt8, TransDataSrc2Dst<int32_t, int8_t>}, | |||||
| {kTransferWithDatatypeUint8ToFloat, TransDataSrc2Dst<uint8_t, float>}, | |||||
| {kTransferWithDatatypeUint8ToInt32, TransDataSrc2Dst<uint8_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt8ToFloat, TransDataSrc2Dst<int8_t, float>}, | |||||
| {kTransferWithDatatypeInt8ToInt32, TransDataSrc2Dst<int8_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt64ToInt32, TransDataSrc2Dst<int64_t, int32_t>}, | |||||
| {kTransferWithDatatypeInt32ToInt64, TransDataSrc2Dst<int32_t, int64_t>}, | |||||
| {kTransferWithDatatypeInt32ToDouble, TransDataSrc2Dst<int32_t, double>}, | |||||
| {kTransferWithDatatypeDoubleToInt32, TransDataSrc2Dst<double, int32_t>}, | |||||
| }; | |||||
| auto it = transfer_handle.find(trans_mode); | |||||
| if (it == transfer_handle.end()) { | |||||
| return UNSUPPORTED; | |||||
| } else { | |||||
| return (it->second)(args, dst, data_size); | |||||
| } | } | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
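Replacing the switch with a static lookup table means each new conversion pair, such as the int32 <-> int64/double modes added here, is essentially one table entry plus an enum value. A standalone sketch of the same table-driven dispatch pattern (the types and names below are illustrative, not GE APIs):

```cpp
#include <cstdint>
#include <functional>
#include <map>
#include <utility>
#include <vector>

// Illustrative stand-ins; the real code keys its table on ge::DataType pairs.
enum class Dtype { kInt32, kInt64, kDouble };

// Element-wise conversion of a raw buffer, in the spirit of TransDataSrc2Dst.
template <typename SrcT, typename DstT>
std::vector<uint8_t> CastBuffer(const std::vector<uint8_t> &src) {
  const auto *in = reinterpret_cast<const SrcT *>(src.data());
  const size_t count = src.size() / sizeof(SrcT);
  std::vector<uint8_t> out(count * sizeof(DstT));
  auto *dst = reinterpret_cast<DstT *>(out.data());
  for (size_t i = 0; i < count; ++i) {
    dst[i] = static_cast<DstT>(in[i]);
  }
  return out;
}

using CastFn = std::function<std::vector<uint8_t>(const std::vector<uint8_t> &)>;

// Adding a conversion is one table entry instead of another switch case.
static const std::map<std::pair<Dtype, Dtype>, CastFn> kCastTable = {
    {{Dtype::kInt32, Dtype::kInt64}, CastBuffer<int32_t, int64_t>},
    {{Dtype::kInt32, Dtype::kDouble}, CastBuffer<int32_t, double>},
    {{Dtype::kDouble, Dtype::kInt32}, CastBuffer<double, int32_t>},
};
```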
| @@ -36,7 +36,9 @@ GE_COMMON_LOCAL_SRC_FILES := \ | |||||
| properties_manager.cc \ | properties_manager.cc \ | ||||
| types.cc\ | types.cc\ | ||||
| model_parser/base.cc \ | model_parser/base.cc \ | ||||
| kernel_store.cc \ | |||||
| tbe_kernel_store.cc \ | tbe_kernel_store.cc \ | ||||
| cust_aicpu_kernel_store.cc \ | |||||
| op/attr_value_util.cc \ | op/attr_value_util.cc \ | ||||
| op/ge_op_utils.cc \ | op/ge_op_utils.cc \ | ||||
| thread_pool.cc \ | thread_pool.cc \ | ||||
| @@ -310,7 +310,7 @@ Status ModelCacheHelper::GetNodesNeedRecompile(ComputeGraphPtr &graph, vector<No | |||||
| string kernel_lib_name = op_desc->GetOpKernelLibName(); | string kernel_lib_name = op_desc->GetOpKernelLibName(); | ||||
| if (kernel_lib_name.empty()) { | if (kernel_lib_name.empty()) { | ||||
| // reset op kernel lib | // reset op kernel lib | ||||
| (void)instance->DNNEngineManagerObj().GetDNNEngineName(op_desc); | |||||
| (void)instance->DNNEngineManagerObj().GetDNNEngineName(node); | |||||
| kernel_lib_name = op_desc->GetOpKernelLibName(); | kernel_lib_name = op_desc->GetOpKernelLibName(); | ||||
| if (kernel_lib_name.empty()) { | if (kernel_lib_name.empty()) { | ||||
| GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str()); | GELOGW("Get node:%s, type:%s kernel lib failed.", node->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| @@ -41,6 +41,7 @@ Status ModelHelper::SaveModelPartition(std::shared_ptr<OmFileSaveHelper> &om_fil | |||||
| const uint8_t *data, size_t size) { | const uint8_t *data, size_t size) { | ||||
| if (size < 1 || size > UINT32_MAX) { | if (size < 1 || size > UINT32_MAX) { | ||||
| GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size); | GELOGE(PARAM_INVALID, "Add model partition failed, partition size %zu invalid", size); | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19022"); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| @@ -101,16 +102,22 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::SaveToOmMod | |||||
| TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | TBEKernelStore tbe_kernel_store = ge_model->GetTBEKernelStore(); | ||||
| GELOGI("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize()); | GELOGI("TBE_KERNELS size is %zu", tbe_kernel_store.DataSize()); | ||||
| if (tbe_kernel_store.DataSize() > 0) { | if (tbe_kernel_store.DataSize() > 0) { | ||||
| if (SaveModelPartition(om_file_save_helper, ModelPartitionType::TBE_KERNELS, tbe_kernel_store.Data(), | |||||
| tbe_kernel_store.DataSize()) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Add tbe kernel partition failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::TBE_KERNELS, tbe_kernel_store.Data(), | |||||
| tbe_kernel_store.DataSize()), | |||||
| "Add tbe kernel partition failed"); | |||||
| } | } | ||||
| // no need to check value, DATA->NetOutput | // no need to check value, DATA->NetOutput | ||||
| (void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize()); | (void)tbe_kernel_store.Load(tbe_kernel_store.Data(), tbe_kernel_store.DataSize()); | ||||
| CustAICPUKernelStore cust_aicpu_kernel_store = ge_model->GetCustAICPUKernelStore(); | |||||
| GELOGI("cust aicpu kernels size is %zu", cust_aicpu_kernel_store.DataSize()); | |||||
| if (cust_aicpu_kernel_store.DataSize() > 0) { | |||||
| GE_CHK_STATUS_RET(SaveModelPartition(om_file_save_helper, ModelPartitionType::CUST_AICPU_KERNELS, | |||||
| cust_aicpu_kernel_store.Data(), cust_aicpu_kernel_store.DataSize()), | |||||
| "Add cust aicpu kernel partition failed"); | |||||
| } | |||||
| std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); | std::shared_ptr<ModelTaskDef> model_task_def = ge_model->GetModelTaskDefPtr(); | ||||
| if (model_task_def == nullptr) { | if (model_task_def == nullptr) { | ||||
| GELOGE(MEMALLOC_FAILED, "Create model task def ptr failed"); | GELOGE(MEMALLOC_FAILED, "Create model task def ptr failed"); | ||||
| @@ -308,6 +315,10 @@ Status ModelHelper::GenerateGeModel(OmFileLoadHelper &om_load_helper) { | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; | return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; | ||||
| } | } | ||||
| ret = LoadCustAICPUKernelStore(om_load_helper); | |||||
| if (ret != SUCCESS) { | |||||
| return GE_EXEC_LOAD_KERNEL_PARTITION_FAILED; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -384,6 +395,22 @@ Status ModelHelper::LoadTBEKernelStore(OmFileLoadHelper &om_load_helper) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ModelHelper::LoadCustAICPUKernelStore(OmFileLoadHelper &om_load_helper) { | |||||
| // Load cust aicpu kernels | |||||
| ModelPartition partition_kernel_def; | |||||
| CustAICPUKernelStore kernel_store; | |||||
| if (om_load_helper.GetModelPartition(ModelPartitionType::CUST_AICPU_KERNELS, partition_kernel_def) == SUCCESS) { | |||||
| GELOGI("Kernels partition size:%u", partition_kernel_def.size); | |||||
| if (kernel_store.Load(partition_kernel_def.data, partition_kernel_def.size)) { | |||||
| GELOGI("Load cust aicpu kernels success"); | |||||
| } else { | |||||
| GELOGW("Load cust aicpu kernels failed"); | |||||
| } | |||||
| } | |||||
| model_->SetCustAICPUKernelStore(kernel_store); | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY GeModelPtr ModelHelper::GetGeModel() { | ||||
| if (model_ != nullptr) { | if (model_ != nullptr) { | ||||
| return model_; | return model_; | ||||
| @@ -27,6 +27,9 @@ | |||||
| using std::string; | using std::string; | ||||
| namespace { | |||||
| const int32_t kOptionalNum = 2; | |||||
| } | |||||
| namespace ge { | namespace ge { | ||||
| // For Load | // For Load | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(const ge::ModelData &model) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::Init(const ge::ModelData &model) { | ||||
| @@ -67,7 +70,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status OmFileLoadHelper::GetMod | |||||
| } | } | ||||
| if (!found) { | if (!found) { | ||||
| if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA) { | |||||
| if (type != ModelPartitionType::TBE_KERNELS && type != ModelPartitionType::WEIGHTS_DATA && | |||||
| type != ModelPartitionType::CUST_AICPU_KERNELS) { | |||||
| GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type)); | GELOGE(FAILED, "GetModelPartition:type:%d is not in partition_datas!", static_cast<int>(type)); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -114,7 +118,7 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, const uint | |||||
| // Davinci model partition includes graph-info, weight-info, task-info and tbe-kernel: | // Davinci model partition includes graph-info, weight-info, task-info and tbe-kernel: | ||||
| // Original model partition includes graph-info only | // Original model partition includes graph-info only | ||||
| if ((partition_table->num != PARTITION_SIZE) && (partition_table->num != (PARTITION_SIZE - 1)) && | if ((partition_table->num != PARTITION_SIZE) && (partition_table->num != (PARTITION_SIZE - 1)) && | ||||
| (partition_table->num != 1)) { | |||||
| (partition_table->num != (PARTITION_SIZE - kOptionalNum)) && (partition_table->num != 1)) { | |||||
| GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "Invalid partition_table->num:%u", partition_table->num); | GELOGE(GE_EXEC_MODEL_PARTITION_NUM_INVALID, "Invalid partition_table->num:%u", partition_table->num); | ||||
| return GE_EXEC_MODEL_PARTITION_NUM_INVALID; | return GE_EXEC_MODEL_PARTITION_NUM_INVALID; | ||||
| } | } | ||||
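A hedged restatement of the relaxed partition-count check: with kOptionalNum == 2 the loader now also accepts a model that omits both optional kernel partitions (presumably TBE_KERNELS and CUST_AICPU_KERNELS), while the single-partition case still covers the original graph-only model. A standalone sketch:

```cpp
#include <cstdint>

// Illustrative restatement of the partition-count check above (kOptionalNum == 2).
inline bool IsPartitionNumValid(uint32_t num, uint32_t partition_size) {
  return num == partition_size ||      // full Davinci model
         num == partition_size - 1 ||  // one optional kernel partition missing
         num == partition_size - 2 ||  // both optional kernel partitions missing
         num == 1;                     // original model: graph info only
}
```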
| @@ -0,0 +1,118 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "common/kernel_store.h" | |||||
| namespace ge { | |||||
| void KernelStore::AddKernel(const KernelBinPtr &kernel) { | |||||
| if (kernel != nullptr) { | |||||
| kernels_[kernel->GetName()] = kernel; | |||||
| } | |||||
| } | |||||
| bool KernelStore::Build() { | |||||
| buffer_.clear(); | |||||
| size_t total_len = 0; | |||||
| for (const auto &item : kernels_) { | |||||
| auto kernel = item.second; | |||||
| total_len += sizeof(KernelStoreItemHead); | |||||
| total_len += kernel->GetName().length(); | |||||
| total_len += kernel->GetBinDataSize(); | |||||
| } | |||||
| try { | |||||
| buffer_.resize(total_len); | |||||
| } catch (std::bad_alloc &e) { | |||||
| GELOGE(ge::MEMALLOC_FAILED, "All build memory failed, memory size %zu", total_len); | |||||
| return false; | |||||
| } | |||||
| uint8_t *next_buffer = buffer_.data(); | |||||
| size_t remain_len = total_len; | |||||
| errno_t mem_ret; | |||||
| for (const auto &item : kernels_) { | |||||
| auto kernel = item.second; | |||||
| KernelStoreItemHead kernel_head{}; | |||||
| kernel_head.magic = kKernelItemMagic; | |||||
| kernel_head.name_len = static_cast<uint32_t>(kernel->GetName().length()); | |||||
| kernel_head.bin_len = static_cast<uint32_t>(kernel->GetBinDataSize()); | |||||
| GELOGI("get kernel bin name %s, addr %p, size %u", kernel->GetName().c_str(), kernel->GetBinData(), | |||||
| kernel->GetBinDataSize()); | |||||
| mem_ret = memcpy_s(next_buffer, remain_len, &kernel_head, sizeof(kernel_head)); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(mem_ret == EOK, return false); | |||||
| next_buffer += sizeof(kernel_head); | |||||
| mem_ret = memcpy_s(next_buffer, remain_len - sizeof(kernel_head), kernel->GetName().data(), kernel_head.name_len); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(mem_ret == EOK, return false); | |||||
| next_buffer += kernel_head.name_len; | |||||
| mem_ret = memcpy_s(next_buffer, remain_len - sizeof(kernel_head) - kernel_head.name_len, kernel->GetBinData(), | |||||
| kernel_head.bin_len); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(mem_ret == EOK, return false); | |||||
| next_buffer += kernel_head.bin_len; | |||||
| remain_len = remain_len - sizeof(kernel_head) - kernel_head.name_len - kernel_head.bin_len; | |||||
| } | |||||
| kernels_.clear(); | |||||
| return true; | |||||
| } | |||||
| const uint8_t *KernelStore::Data() const { return buffer_.data(); } | |||||
| size_t KernelStore::DataSize() const { return buffer_.size(); } | |||||
| bool KernelStore::Load(const uint8_t *data, const size_t &len) { | |||||
| if (data == nullptr || len == 0) { | |||||
| return false; | |||||
| } | |||||
| size_t buffer_len = len; | |||||
| while (buffer_len > sizeof(KernelStoreItemHead)) { | |||||
| const char *next_buffer = reinterpret_cast<const char *>(data) + (len - buffer_len); | |||||
| const auto *kernel_head = reinterpret_cast<const KernelStoreItemHead *>(next_buffer); | |||||
| if (buffer_len < kernel_head->name_len + kernel_head->bin_len + sizeof(KernelStoreItemHead)) { | |||||
| GELOGW("Invalid kernel block remain buffer len %zu, name len %u, bin len %u", buffer_len, kernel_head->name_len, | |||||
| kernel_head->bin_len); | |||||
| break; | |||||
| } | |||||
| next_buffer += sizeof(KernelStoreItemHead); | |||||
| std::string name(next_buffer, kernel_head->name_len); | |||||
| next_buffer += kernel_head->name_len; | |||||
| GELOGI("Load kernel from om:%s,%u,%u", name.c_str(), kernel_head->name_len, kernel_head->bin_len); | |||||
| std::vector<char> kernel_bin(next_buffer, next_buffer + kernel_head->bin_len); | |||||
| KernelBinPtr teb_kernel_ptr = ge::MakeShared<KernelBin>(name, std::move(kernel_bin)); | |||||
| if (teb_kernel_ptr != nullptr) { | |||||
| kernels_.emplace(name, teb_kernel_ptr); | |||||
| } | |||||
| buffer_len -= sizeof(KernelStoreItemHead) + kernel_head->name_len + kernel_head->bin_len; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| KernelBinPtr KernelStore::FindKernel(const std::string &name) const { | |||||
| auto it = kernels_.find(name); | |||||
| if (it != kernels_.end()) { | |||||
| return it->second; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,70 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_KERNEL_STORE_H_ | |||||
| #define GE_COMMON_KERNEL_STORE_H_ | |||||
| #include <cstdint> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include <securec.h> | |||||
| #include <utility> | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| #include "framework/common/fmk_types.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/op_kernel_bin.h" | |||||
| namespace ge { | |||||
| using KernelBin = ge::OpKernelBin; | |||||
| using KernelBinPtr = std::shared_ptr<ge::OpKernelBin>; | |||||
| using CustAICPUKernel = ge::OpKernelBin; | |||||
| using CustAICPUKernelPtr = std::shared_ptr<ge::OpKernelBin>; | |||||
| using TBEKernel = ge::OpKernelBin; | |||||
| using TBEKernelPtr = std::shared_ptr<ge::OpKernelBin>; | |||||
| const uint32_t kKernelItemMagic = 0x5d776efd; | |||||
| struct KernelStoreItemHead { | |||||
| uint32_t magic; | |||||
| uint32_t name_len; | |||||
| uint32_t bin_len; | |||||
| }; | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY KernelStore { | |||||
| public: | |||||
| KernelStore() = default; | |||||
| virtual ~KernelStore() = default; | |||||
| virtual bool Build(); | |||||
| virtual bool Load(const uint8_t *data, const size_t &len); | |||||
| virtual const uint8_t *Data() const; | |||||
| virtual size_t DataSize() const; | |||||
| virtual void AddKernel(const KernelBinPtr &kernel); | |||||
| virtual KernelBinPtr FindKernel(const std::string &name) const; | |||||
| private: | |||||
| std::unordered_map<std::string, KernelBinPtr> kernels_; | |||||
| std::vector<uint8_t> buffer_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_KERNEL_STORE_H_ | |||||
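For reference, a sketch of the record layout that KernelStore::Build() emits and Load() parses, one contiguous record per kernel; the size helper below is illustrative:

```cpp
#include <cstddef>
#include <cstdint>
#include "common/kernel_store.h"

// Per-kernel record in the serialized buffer:
//   KernelStoreItemHead { magic = kKernelItemMagic, name_len, bin_len }
//   name_len bytes of the kernel name (no terminating NUL)
//   bin_len  bytes of the kernel binary
// Illustrative helper: the size one record occupies in the serialized store.
inline size_t KernelRecordSize(uint32_t name_len, uint32_t bin_len) {
  return sizeof(ge::KernelStoreItemHead) + name_len + bin_len;
}
```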
| @@ -612,295 +612,268 @@ inline Status CheckInt32DivOverflow(int32_t a, int32_t b) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| #define FMK_INT_ADDCHECK(a, b) \ | |||||
| if (ge::CheckIntAddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "Int %d and %d addition can result in overflow!", static_cast<int>(a), \ | |||||
| static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | |||||
| #define FMK_INT8_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt8AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "Int8 %d and %d addition can result in overflow!", static_cast<int8_t>(a), \ | |||||
| static_cast<int8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | |||||
| #define FMK_INT16_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt16AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "Int16 %d and %d addition can result in overflow!", static_cast<int16_t>(a), \ | |||||
| static_cast<int16_t>(b)); \ | |||||
| #define FMK_INT_ADDCHECK(a, b) \ | |||||
| if (ge::CheckIntAddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int %d and %d addition can result in overflow!", static_cast<int>(a), static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | return INTERNAL_ERROR; \ | ||||
| } | } | ||||
| #define FMK_INT32_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt32AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "Int32 %d and %d addition can result in overflow!", static_cast<int32_t>(a), \ | |||||
| static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT8_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt8AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int8 %d and %d addition can result in overflow!", static_cast<int8_t>(a), static_cast<int8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT64_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt64AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "Int64 %ld and %ld addition can result in overflow!", static_cast<int64_t>(a), \ | |||||
| static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT16_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt16AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int16 %d and %d addition can result in overflow!", static_cast<int16_t>(a), static_cast<int16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT8_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint8AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT8 %u and %u addition can result in overflow!", static_cast<uint8_t>(a), \ | |||||
| static_cast<uint8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT32_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt32AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int32 %d and %d addition can result in overflow!", static_cast<int32_t>(a), static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT16_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint16AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT16 %u and %u addition can result in overflow!", static_cast<uint16_t>(a), \ | |||||
| static_cast<uint16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT64_ADDCHECK(a, b) \ | |||||
| if (ge::CheckInt64AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int64 %ld and %ld addition can result in overflow!", static_cast<int64_t>(a), static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT32_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint32AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT32 %u and %u addition can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT8_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint8AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint8 %u and %u addition can result in overflow!", static_cast<uint8_t>(a), static_cast<uint8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT64_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint64AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT64 %lu and %lu addition can result in overflow!", static_cast<uint64_t>(a), \ | |||||
| static_cast<uint64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT16_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint16AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("UINT16 %u and %u addition can result in overflow!", static_cast<uint16_t>(a), static_cast<uint16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FP16_ADDCHECK(a, b) \ | |||||
| if (ge::CheckFp16AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "fp16 %f and %f addition can result in overflow!", static_cast<float>(a), \ | |||||
| static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT32_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint32AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint32 %u and %u addition can result in overflow!", static_cast<uint32_t>(a), static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FLOAT_ADDCHECK(a, b) \ | |||||
| if (ge::CheckFloatAddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "float %f and %f addition can result in overflow!", static_cast<float>(a), \ | |||||
| static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT64_ADDCHECK(a, b) \ | |||||
| if (ge::CheckUint64AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint64 %lu and %lu addition can result in overflow!", static_cast<uint64_t>(a), static_cast<uint64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_DOUBLE_ADDCHECK(a, b) \ | |||||
| if (ge::CheckDoubleAddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "double %lf and %lf addition can result in overflow!", static_cast<double>(a), \ | |||||
| static_cast<double>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_FP16_ADDCHECK(a, b) \ | |||||
| if (ge::CheckFp16AddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Fp16 %f and %f addition can result in overflow!", static_cast<float>(a), static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT_SUBCHECK(a, b) \ | |||||
| if (ge::CheckIntSubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT %d and %d subtraction can result in overflow!", static_cast<int>(a), \ | |||||
| static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_FLOAT_ADDCHECK(a, b) \ | |||||
| if (ge::CheckFloatAddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Float %f and %f addition can result in overflow!", static_cast<float>(a), static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT8_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt8SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT8 %d and %d subtraction can result in overflow!", static_cast<int8_t>(a), \ | |||||
| static_cast<int8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_DOUBLE_ADDCHECK(a, b) \ | |||||
| if (ge::CheckDoubleAddOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Double %lf and %lf addition can result in overflow!", static_cast<double>(a), static_cast<double>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT16_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt16SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT16 %d and %d subtraction can result in overflow!", static_cast<int16_t>(a), \ | |||||
| static_cast<int16_t>(b)); \ | |||||
| #define FMK_INT_SUBCHECK(a, b) \ | |||||
| if (ge::CheckIntSubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int %d and %d subtraction can result in overflow!", static_cast<int>(a), static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | return INTERNAL_ERROR; \ | ||||
| } | } | ||||
| #define FMK_INT32_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt32SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT32 %d and %d subtraction can result in overflow!", static_cast<int32_t>(a), \ | |||||
| static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT8_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt8SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int8 %d and %d subtraction can result in overflow!", static_cast<int8_t>(a), static_cast<int8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT64_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt64SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT64 %ld and %ld subtraction can result in overflow!", static_cast<int64_t>(a), \ | |||||
| static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT16_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt16SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int16 %d and %d subtraction can result in overflow!", static_cast<int16_t>(a), static_cast<int16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT8_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint8SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT8 %u and %u subtraction can result in overflow!", static_cast<uint8_t>(a), \ | |||||
| static_cast<uint8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT32_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt32SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int32 %d and %d subtraction can result in overflow!", static_cast<int32_t>(a), static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT16_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint16SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT16 %u and %u subtraction can result in overflow!", static_cast<uint16_t>(a), \ | |||||
| static_cast<uint16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT64_SUBCHECK(a, b) \ | |||||
| if (ge::CheckInt64SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int64 %ld and %ld subtraction can result in overflow!", static_cast<int64_t>(a), static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT32_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint32SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT32 %u and %u subtraction can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT8_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint8SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint8 %u and %u subtraction can result in overflow!", static_cast<uint8_t>(a), static_cast<uint8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT64_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint64SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT64 %lu and %lu subtraction can result in overflow!", static_cast<uint64_t>(a), \ | |||||
| static_cast<uint64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT16_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint16SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint16 %u and %u subtraction can result in overflow!", static_cast<uint16_t>(a), \ | |||||
| static_cast<uint16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FP16_SUBCHECK(a, b) \ | |||||
| if (ge::CheckFp16SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "fp16 %f and %f subtraction can result in overflow!", static_cast<float>(a), \ | |||||
| static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT32_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint32SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint32 %u and %u subtraction can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FLOAT_SUBCHECK(a, b) \ | |||||
| if (ge::CheckFloatSubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "float %f and %f subtraction can result in overflow!", static_cast<float>(a), \ | |||||
| static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT64_SUBCHECK(a, b) \ | |||||
| if (ge::CheckUint64SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint64 %lu and %lu subtraction can result in overflow!", static_cast<uint64_t>(a), \ | |||||
| static_cast<uint64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_DOUBLE_SUBCHECK(a, b) \ | |||||
| if (ge::CheckDoubleSubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "double %lf and %lf subtraction can result in overflow!", static_cast<double>(a), \ | |||||
| static_cast<double>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_FP16_SUBCHECK(a, b) \ | |||||
| if (ge::CheckFp16SubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Fp16 %f and %f subtraction can result in overflow!", static_cast<float>(a), static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT_MULCHECK(a, b) \ | |||||
| if (ge::CheckIntMulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT %d and %d multiplication can result in overflow!", static_cast<int>(a), \ | |||||
| static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_FLOAT_SUBCHECK(a, b) \ | |||||
| if (ge::CheckFloatSubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Float %f and %f subtraction can result in overflow!", static_cast<float>(a), static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT8_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt8MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT8 %d and %d multiplication can result in overflow!", static_cast<int8_t>(a), \ | |||||
| static_cast<int8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_DOUBLE_SUBCHECK(a, b) \ | |||||
| if (ge::CheckDoubleSubOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Double %lf and %lf subtraction can result in overflow!", static_cast<double>(a), static_cast<double>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT16_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt16MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT16 %d and %d multiplication can result in overflow!", static_cast<int16_t>(a), \ | |||||
| static_cast<int16_t>(b)); \ | |||||
| #define FMK_INT_MULCHECK(a, b) \ | |||||
| if (ge::CheckIntMulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int %d and %d multiplication can result in overflow!", static_cast<int>(a), static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | return INTERNAL_ERROR; \ | ||||
| } | } | ||||
| #define FMK_INT32_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT32 %d and %d multiplication can result in overflow!", static_cast<int32_t>(a), \ | |||||
| static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT8_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt8MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int8 %d and %d multiplication can result in overflow!", static_cast<int8_t>(a), static_cast<int8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT64_MULCHECK(a, b) \ | |||||
| if (ge::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \ | |||||
| static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT16_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt16MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int16 %d and %d multiplication can result in overflow!", static_cast<int16_t>(a), \ | |||||
| static_cast<int16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT8_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint8MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT8 %u and %u multiplication can result in overflow!", static_cast<uint8_t>(a), \ | |||||
| static_cast<uint8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT32_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int32 %d and %d multiplication can result in overflow!", static_cast<int32_t>(a), \ | |||||
| static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT16_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint16MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT16 %u and %u multiplication can result in overflow!", static_cast<uint16_t>(a), \ | |||||
| static_cast<uint16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT64_MULCHECK(a, b) \ | |||||
| if (ge::Int64MulCheckOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int64 %ld and %ld multiplication can result in overflow!", static_cast<int64_t>(a), \ | |||||
| static_cast<int64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT32_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT32 %u and %u multiplication can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT8_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint8MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint8 %u and %u multiplication can result in overflow!", static_cast<uint8_t>(a), \ | |||||
| static_cast<uint8_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_UINT64_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint64MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "UINT64 %lu and %lu multiplication can result in overflow!", static_cast<uint64_t>(a), \ | |||||
| static_cast<uint64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT16_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint16MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint16 %u and %u multiplication can result in overflow!", static_cast<uint16_t>(a), \ | |||||
| static_cast<uint16_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FP16_MULCHECK(a, b) \ | |||||
| if (ge::CheckFp16MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "fp16 %f and %f multiplication can result in overflow!", static_cast<float>(a), \ | |||||
| static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT32_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint32 %u and %u multiplication can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FLOAT_MULCHECK(a, b) \ | |||||
| if (ge::CheckFloatMulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "float %f and %f multiplication can result in overflow!", static_cast<float>(a), \ | |||||
| static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_UINT64_MULCHECK(a, b) \ | |||||
| if (ge::CheckUint64MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Uint64 %lu and %lu multiplication can result in overflow!", static_cast<uint64_t>(a), \ | |||||
| static_cast<uint64_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_DOUBLE_MULCHECK(a, b) \ | |||||
| if (ge::CheckDoubleMulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "double %lf and %lf multiplication can result in overflow!", static_cast<double>(a), \ | |||||
| static_cast<double>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_FP16_MULCHECK(a, b) \ | |||||
| if (ge::CheckFp16MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Fp16 %f and %f multiplication can result in overflow!", static_cast<float>(a), static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT_DIVCHECK(a, b) \ | |||||
| if (CheckIntDivOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT %d and %d division can result in overflow!", static_cast<int>(a), \ | |||||
| static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_FLOAT_MULCHECK(a, b) \ | |||||
| if (ge::CheckFloatMulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Float %f and %f multiplication can result in overflow!", static_cast<float>(a), static_cast<float>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT32_DIVCHECK(a, b) \ | |||||
| if (CheckInt32DivOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT32 %d and %d division can result in overflow!", static_cast<int32_t>(a), \ | |||||
| static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_DOUBLE_MULCHECK(a, b) \ | |||||
| if (ge::CheckDoubleMulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Double %lf and %lf multiplication can result in overflow!", static_cast<double>(a), \ | |||||
| static_cast<double>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_INT64_UINT32_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGE(INTERNAL_ERROR, "INT64 %ld and UINT32 %u multiplication can result in overflow!", static_cast<uint32_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT_DIVCHECK(a, b) \ | |||||
| if (CheckIntDivOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int %d and %d division can result in overflow!", static_cast<int>(a), static_cast<int>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FP16_ZEROCHECK(a) \ | |||||
| if (fabs(a) < DBL_EPSILON) { \ | |||||
| GELOGE(INTERNAL_ERROR, "fp16 %f can not be zero !", a); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT32_DIVCHECK(a, b) \ | |||||
| if (CheckInt32DivOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int32 %d and %d division can result in overflow!", static_cast<int32_t>(a), static_cast<int32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_FLOAT_ZEROCHECK(a) \ | |||||
| if (fabs(a) < FLT_EPSILON) { \ | |||||
| GELOGE(INTERNAL_ERROR, "float %f can not be zero !", a); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| #define FMK_INT64_UINT32_MULCHECK(a, b) \ | |||||
| if (ge::CheckInt64Uint32MulOverflow((a), (b)) != SUCCESS) { \ | |||||
| GELOGW("Int64 %ld and UINT32 %u multiplication can result in overflow!", static_cast<int64_t>(a), \ | |||||
| static_cast<uint32_t>(b)); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | } | ||||
| #define FMK_DOUBLE_ZEROCHECK(a) \ | |||||
| if (fabs(a) < DBL_EPSILON) { \ | |||||
| GELOGE(INTERNAL_ERROR, "double %lf can not be zero !", a); \ | |||||
| #define FMK_FP16_ZEROCHECK(a) \ | |||||
| if (fabs(a) < DBL_EPSILON || a < 0) { \ | |||||
| GELOGW("Fp16 %f cannot be less than or equal to zero!", a); \ | |||||
| return INTERNAL_ERROR; \ | return INTERNAL_ERROR; \ | ||||
| } | } | ||||
| #define FMK_FLOAT_ZEROCHECK(a) \ | |||||
| if (fabs(a) < FLT_EPSILON || a < 0) { \ | |||||
| GELOGW("Float %f cannot be less than or equal to zero!", a); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | |||||
| #define FMK_DOUBLE_ZEROCHECK(a) \ | |||||
| if (fabs(a) < DBL_EPSILON || a < 0) { \ | |||||
| GELOGW("Double %lf cannot be less than or equal to zero!", a); \ | |||||
| return INTERNAL_ERROR; \ | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_COMMON_MATH_MATH_UTIL_H_ | #endif // GE_COMMON_MATH_MATH_UTIL_H_ | ||||
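The FMK_*CHECK macros above expand to an `if (...) { log; return INTERNAL_ERROR; }` statement, so they can only appear inside functions whose return type accepts INTERNAL_ERROR; this revision also downgrades most of the messages from GELOGE to GELOGW. A minimal usage sketch (the helper function, its arguments, and the include path are hypothetical):

```cpp
#include "common/math/math_util.h"  // header shown above; include path assumed

namespace ge {
// Hypothetical helper: guard the multiplication before doing the arithmetic.
// On overflow the macro logs a warning and makes this function return INTERNAL_ERROR.
Status ComputeBufferSize(int64_t elem_count, int64_t elem_size, int64_t &total) {
  FMK_INT64_MULCHECK(elem_count, elem_size);
  total = elem_count * elem_size;
  return SUCCESS;
}
}  // namespace ge
```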
| @@ -16,126 +16,19 @@ | |||||
| #include "common/tbe_kernel_store.h" | #include "common/tbe_kernel_store.h" | ||||
| #include <securec.h> | |||||
| #include <utility> | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/debug/log.h" | |||||
| namespace ge { | namespace ge { | ||||
| const uint32_t kKernelItemMagic = 0x5d776efd; | |||||
| struct KernelStoreItemHead { | |||||
| uint32_t magic; | |||||
| uint32_t name_len; | |||||
| uint32_t bin_len; | |||||
| }; | |||||
| TBEKernelStore::TBEKernelStore() {} | TBEKernelStore::TBEKernelStore() {} | ||||
| void TBEKernelStore::AddTBEKernel(const TBEKernelPtr &kernel) { | |||||
| if (kernel != nullptr) { | |||||
| kernels_[kernel->GetName()] = kernel; | |||||
| } | |||||
| } | |||||
| bool TBEKernelStore::Build() { | |||||
| buffer_.clear(); | |||||
| size_t total_len = 0; | |||||
| for (const auto &item : kernels_) { | |||||
| auto kernel = item.second; | |||||
| total_len += sizeof(KernelStoreItemHead); | |||||
| total_len += kernel->GetName().length(); | |||||
| total_len += kernel->GetBinDataSize(); | |||||
| } | |||||
| try { | |||||
| buffer_.resize(total_len); | |||||
| } catch (std::bad_alloc &e) { | |||||
| GELOGE(ge::MEMALLOC_FAILED, "Allocate build memory failed, memory size %zu", total_len); | |||||
| return false; | |||||
| } | |||||
| uint8_t *next_buffer = buffer_.data(); | |||||
| size_t remain_len = total_len; | |||||
| errno_t mem_ret; | |||||
| for (const auto &item : kernels_) { | |||||
| auto kernel = item.second; | |||||
| KernelStoreItemHead kernel_head{}; | |||||
| kernel_head.magic = kKernelItemMagic; | |||||
| kernel_head.name_len = static_cast<uint32_t>(kernel->GetName().length()); | |||||
| kernel_head.bin_len = static_cast<uint32_t>(kernel->GetBinDataSize()); | |||||
| mem_ret = memcpy_s(next_buffer, remain_len, &kernel_head, sizeof(kernel_head)); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(mem_ret == EOK, return false); | |||||
| next_buffer += sizeof(kernel_head); | |||||
| mem_ret = memcpy_s(next_buffer, remain_len - sizeof(kernel_head), kernel->GetName().data(), kernel_head.name_len); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(mem_ret == EOK, return false); | |||||
| next_buffer += kernel_head.name_len; | |||||
| mem_ret = memcpy_s(next_buffer, remain_len - sizeof(kernel_head) - kernel_head.name_len, kernel->GetBinData(), | |||||
| kernel_head.bin_len); | |||||
| GE_CHK_BOOL_EXEC_NOLOG(mem_ret == EOK, return false); | |||||
| next_buffer += kernel_head.bin_len; | |||||
| remain_len = remain_len - sizeof(kernel_head) - kernel_head.name_len - kernel_head.bin_len; | |||||
| } | |||||
| kernels_.clear(); | |||||
| return true; | |||||
| } | |||||
| const uint8_t *TBEKernelStore::Data() const { return buffer_.data(); } | |||||
| size_t TBEKernelStore::DataSize() const { return buffer_.size(); } | |||||
| bool TBEKernelStore::Load(const uint8_t *data, const size_t &len) { | |||||
| if (data == nullptr || len == 0) { | |||||
| return false; | |||||
| } | |||||
| size_t buffer_len = len; | |||||
| while (buffer_len > sizeof(KernelStoreItemHead)) { | |||||
| const char *next_buffer = reinterpret_cast<const char *>(data) + (len - buffer_len); | |||||
| const auto *kernel_head = reinterpret_cast<const KernelStoreItemHead *>(next_buffer); | |||||
| if (buffer_len < kernel_head->name_len + kernel_head->bin_len + sizeof(KernelStoreItemHead)) { | |||||
| GELOGW("Invalid kernel block remain buffer len %zu, name len %u, bin len %u", buffer_len, kernel_head->name_len, | |||||
| kernel_head->bin_len); | |||||
| break; | |||||
| } | |||||
| next_buffer += sizeof(KernelStoreItemHead); | |||||
| std::string name(next_buffer, kernel_head->name_len); | |||||
| next_buffer += kernel_head->name_len; | |||||
| GELOGI("Load kernel from om:%s,%u,%u", name.c_str(), kernel_head->name_len, kernel_head->bin_len); | |||||
| std::vector<char> kernel_bin(next_buffer, next_buffer + kernel_head->bin_len); | |||||
| TBEKernelPtr tbe_kernel_ptr = ge::MakeShared<TBEKernel>(name, std::move(kernel_bin)); | |||||
| if (tbe_kernel_ptr != nullptr) { | |||||
| kernels_.emplace(name, tbe_kernel_ptr); | |||||
| } | |||||
| buffer_len -= sizeof(KernelStoreItemHead) + kernel_head->name_len + kernel_head->bin_len; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| TBEKernelPtr TBEKernelStore::FindTBEKernel(const std::string &name) const { | |||||
| auto it = kernels_.find(name); | |||||
| if (it != kernels_.end()) { | |||||
| return it->second; | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| void TBEKernelStore::AddTBEKernel(const TBEKernelPtr &kernel) { AddKernel(kernel); } | |||||
| void TBEKernelStore::LoadTBEKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const { | void TBEKernelStore::LoadTBEKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const { | ||||
| if (op_desc != nullptr) { | if (op_desc != nullptr) { | ||||
| auto tbe_kernel = FindTBEKernel(op_desc->GetName()); | |||||
| if (tbe_kernel != nullptr) { | |||||
| GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel), | |||||
| GELOGW("LoadTBEKernelBinToOpDesc: SetExtAttr for tbe_kernel failed");) | |||||
| GELOGI("Load tbe kernel:%s, %zu", tbe_kernel->GetName().c_str(), tbe_kernel->GetBinDataSize()); | |||||
| auto kernel_bin = FindKernel(op_desc->GetName()); | |||||
| if (kernel_bin != nullptr) { | |||||
| GE_IF_BOOL_EXEC(!op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, kernel_bin), | |||||
| GELOGW("LoadKernelTBEBinToOpDesc: SetExtAttr for kernel_bin failed");) | |||||
| GELOGI("Load tbe kernel:%s, %zu", kernel_bin->GetName().c_str(), kernel_bin->GetBinDataSize()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -17,38 +17,17 @@ | |||||
| #ifndef GE_COMMON_TBE_KERNEL_STORE_H_ | #ifndef GE_COMMON_TBE_KERNEL_STORE_H_ | ||||
| #define GE_COMMON_TBE_KERNEL_STORE_H_ | #define GE_COMMON_TBE_KERNEL_STORE_H_ | ||||
| #include <cstdint> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include <unordered_map> | |||||
| #include <vector> | |||||
| #include "framework/common/fmk_types.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/op_kernel_bin.h" | |||||
| #include "common/kernel_store.h" | |||||
| namespace ge { | namespace ge { | ||||
| using TBEKernel = ge::OpKernelBin; | |||||
| using TBEKernelPtr = std::shared_ptr<ge::OpKernelBin>; | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEKernelStore { | |||||
| class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY TBEKernelStore : public KernelStore { | |||||
| public: | public: | ||||
| TBEKernelStore(); | TBEKernelStore(); | ||||
| ~TBEKernelStore() = default; | |||||
| ~TBEKernelStore() {} | |||||
| void AddTBEKernel(const TBEKernelPtr &kernel); | void AddTBEKernel(const TBEKernelPtr &kernel); | ||||
| bool Build(); | |||||
| bool Load(const uint8_t *data, const size_t &len); | |||||
| TBEKernelPtr FindTBEKernel(const std::string &name) const; | |||||
| void LoadTBEKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const; | void LoadTBEKernelBinToOpDesc(const std::shared_ptr<ge::OpDesc> &op_desc) const; | ||||
| const uint8_t *Data() const; | |||||
| size_t DataSize() const; | |||||
| private: | |||||
| std::unordered_map<std::string, TBEKernelPtr> kernels_; | |||||
| std::vector<uint8_t> buffer_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
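After this refactor TBEKernelStore keeps only its TBE-specific surface (AddTBEKernel, LoadTBEKernelBinToOpDesc) and inherits serialization from KernelStore: AddTBEKernel forwards to AddKernel, and the kernel bin is looked up via FindKernel. A usage sketch, assuming the KernelStore base class still exposes the Build()/Data()/DataSize() members that were removed from this class:

```cpp
// Sketch only; Build(), Data() and DataSize() are assumed to live on the
// KernelStore base class now, mirroring the members removed above.
void PackTbeKernel(ge::TBEKernelStore &store, const ge::OpDescPtr &op_desc,
                   std::vector<char> kernel_bin) {
  auto kernel = std::make_shared<ge::TBEKernel>(op_desc->GetName(), std::move(kernel_bin));
  store.AddTBEKernel(kernel);               // forwards to KernelStore::AddKernel
  if (store.Build()) {                      // serialize every added kernel into one buffer
    // store.Data() / store.DataSize() would then be written into the offline model
  }
  store.LoadTBEKernelBinToOpDesc(op_desc);  // attaches the bin as OP_EXTATTR_NAME_TBE_KERNEL
}
```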
| @@ -26,7 +26,10 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "analyzer/analyzer.h" | |||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| namespace { | namespace { | ||||
| @@ -164,11 +167,22 @@ bool DNNEngineManager::IsEngineRegistered(const std::string &name) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| void DNNEngineManager::InitPerformanceStaistic() { checksupport_cost_.clear(); } | |||||
| void DNNEngineManager::InitPerformanceStaistic() { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| checksupport_cost_.clear(); | |||||
| } | |||||
| const map<string, uint64_t> &DNNEngineManager::GetCheckSupportCost() const { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| return checksupport_cost_; | |||||
| } | |||||
| const map<string, uint64_t> &DNNEngineManager::GetCheckSupportCost() const { return checksupport_cost_; } | |||||
| std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | |||||
| std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||||
| GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: node_ptr is nullptr"); | |||||
| return ""); | |||||
| auto op_desc = node_ptr->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); | GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(GE_CLI_GE_NOT_INITIALIZED, "DNNEngineManager: op_desc is nullptr"); | ||||
| return ""); | return ""); | ||||
| // Use the OpsKernelManager in GELib to get the opInfos for this opCode | // Use the OpsKernelManager in GELib to get the opInfos for this opCode | ||||
| @@ -190,6 +204,7 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||||
| std::string exclude_core_Type = (ge_core_type == kVectorCore) ? kAIcoreEngine : kVectorEngine; | std::string exclude_core_Type = (ge_core_type == kVectorCore) ? kAIcoreEngine : kVectorEngine; | ||||
| GELOGD("engine type will exclude: %s", exclude_core_Type.c_str()); | GELOGD("engine type will exclude: %s", exclude_core_Type.c_str()); | ||||
| auto root_graph = ge::GraphUtils::FindRootGraph(node_ptr->GetOwnerComputeGraph()); | |||||
| std::map<std::string, std::string> unsupported_reasons; | std::map<std::string, std::string> unsupported_reasons; | ||||
| for (const auto &it : op_infos) { | for (const auto &it : op_infos) { | ||||
| if (it.engine == exclude_core_Type) { | if (it.engine == exclude_core_Type) { | ||||
| @@ -206,6 +221,9 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||||
| checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; | checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; | ||||
| op_desc->SetOpEngineName(it.engine); | op_desc->SetOpEngineName(it.engine); | ||||
| op_desc->SetOpKernelLibName(kernel_name); | op_desc->SetOpKernelLibName(kernel_name); | ||||
| // set attrs so the engine name and kernel lib name are kept when the graph is reloaded from a text dump | |||||
| (void)AttrUtils::SetStr(op_desc, ATTR_NAME_ENGINE_NAME_FOR_LX, it.engine); | |||||
| (void)AttrUtils::SetStr(op_desc, ATTR_NAME_KKERNEL_LIB_NAME_FOR_LX, kernel_name); | |||||
| GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s to op_desc %s", kernel_name.c_str(), | GELOGD("DNNEngineManager:Set OpKernelLibName %s and engine name %s to op_desc %s", kernel_name.c_str(), | ||||
| it.engine.c_str(), op_desc->GetName().c_str()); | it.engine.c_str(), op_desc->GetName().c_str()); | ||||
| return it.engine; | return it.engine; | ||||
| @@ -219,6 +237,9 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||||
| "The custom operator registered by the user does not support the logic function delivered by this " | "The custom operator registered by the user does not support the logic function delivered by this " | ||||
| "network. Check support failed, kernel_name is %s, op type is %s, op name is %s", | "network. Check support failed, kernel_name is %s, op type is %s, op name is %s", | ||||
| kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); | kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); | ||||
| std::string error_info = | |||||
| "The custom operator registered by the user does not support the logic function " | |||||
| "delivered by this network"; | |||||
| return ""; | return ""; | ||||
| } | } | ||||
| unsupported_reasons.emplace(kernel_name, unsupported_reason); | unsupported_reasons.emplace(kernel_name, unsupported_reason); | ||||
| @@ -235,12 +256,22 @@ std::string DNNEngineManager::GetDNNEngineName(const OpDescPtr &op_desc) { | |||||
| kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); | kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); | ||||
| } | } | ||||
| } | } | ||||
| // concatenate the unsupported reasons so they can be handed to the analyzer | |||||
| string reason; | |||||
| for (const auto &it : unsupported_reasons) { | for (const auto &it : unsupported_reasons) { | ||||
| reason += it.first + ":" + it.second + ";"; | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"}, | ErrorManager::GetInstance().ATCReportErrMessage("E13002", {"optype", "opskernel", "reason"}, | ||||
| {op_desc->GetType(), it.first, it.second}); | {op_desc->GetType(), it.first, it.second}); | ||||
| GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s", | GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "GetDNNEngineName:Op type %s of ops kernel %s is unsupported, reason:%s", | ||||
| op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); | op_desc->GetType().c_str(), it.first.c_str(), it.second.c_str()); | ||||
| } | } | ||||
| analyzer::DataInfo analyze_info{root_graph->GetSessionID(), root_graph->GetGraphID(), analyzer::CHECKSUPPORT, | |||||
| node_ptr, reason}; | |||||
| // do not change original process | |||||
| (void)Analyzer::GetInstance()->DoAnalyze(analyze_info); | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"}, | ErrorManager::GetInstance().ATCReportErrMessage("E13003", {"opname", "optype"}, | ||||
| {op_desc->GetName(), op_desc->GetType()}); | {op_desc->GetName(), op_desc->GetType()}); | ||||
| GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s", | GELOGE(GE_GRAPH_ASSIGN_ENGINE_FAILED, "Can't find any supported ops kernel and engine of %s, type is %s", | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <mutex> | |||||
| #include "nlohmann/json.hpp" | #include "nlohmann/json.hpp" | ||||
| @@ -29,6 +30,7 @@ | |||||
| #include "common/opskernel/ops_kernel_info_types.h" | #include "common/opskernel/ops_kernel_info_types.h" | ||||
| #include "engine/dnnengine.h" | #include "engine/dnnengine.h" | ||||
| #include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
| #include "graph/node.h" | |||||
| using JsonHandle = void *; | using JsonHandle = void *; | ||||
| namespace ge { | namespace ge { | ||||
| @@ -61,7 +63,7 @@ class DNNEngineManager { | |||||
| std::shared_ptr<ge::DNNEngine> GetEngine(const std::string &name) const; | std::shared_ptr<ge::DNNEngine> GetEngine(const std::string &name) const; | ||||
| bool IsEngineRegistered(const std::string &name); | bool IsEngineRegistered(const std::string &name); | ||||
| // If can't find appropriate engine name, return "", report error | // If can't find appropriate engine name, return "", report error | ||||
| string GetDNNEngineName(const OpDescPtr &op_desc); | |||||
| string GetDNNEngineName(const ge::NodePtr &node_ptr); | |||||
| const map<string, SchedulerConf> &GetSchedulers() const; | const map<string, SchedulerConf> &GetSchedulers() const; | ||||
| const map<string, uint64_t> &GetCheckSupportCost() const; | const map<string, uint64_t> &GetCheckSupportCost() const; | ||||
| void InitPerformanceStaistic(); | void InitPerformanceStaistic(); | ||||
| @@ -83,6 +85,7 @@ class DNNEngineManager { | |||||
| std::map<string, SchedulerConf> schedulers_; | std::map<string, SchedulerConf> schedulers_; | ||||
| std::map<string, uint64_t> checksupport_cost_; | std::map<string, uint64_t> checksupport_cost_; | ||||
| bool init_flag_; | bool init_flag_; | ||||
| mutable std::mutex mutex_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
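GetDNNEngineName now takes the whole node instead of its OpDesc because the failure path needs the owner graph (session and graph IDs) for the analyzer record. A call-site sketch after the signature change (`engine_manager` and `node` are hypothetical locals):

```cpp
// Callers now hand over the NodePtr; the OpDesc is fetched inside the manager.
std::string engine_name = engine_manager.GetDNNEngineName(node);
if (engine_name.empty()) {
  GELOGE(ge::GE_GRAPH_ASSIGN_ENGINE_FAILED, "No supported engine found for node %s",
         node->GetName().c_str());
}
```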
| @@ -22,6 +22,7 @@ file(GLOB PROTO_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "../../proto/insert_op.proto" | "../../proto/insert_op.proto" | ||||
| "../../proto/op_mapping_info.proto" | "../../proto/op_mapping_info.proto" | ||||
| "../../proto/ge_ir.proto" | "../../proto/ge_ir.proto" | ||||
| "../proto/dump_task.proto" | |||||
| ) | ) | ||||
| file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | ||||
| @@ -68,6 +69,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "../graph/manager/graph_manager_utils.cc" | "../graph/manager/graph_manager_utils.cc" | ||||
| "../graph/manager/graph_mem_allocator.cc" | "../graph/manager/graph_mem_allocator.cc" | ||||
| "../graph/manager/graph_var_manager.cc" | "../graph/manager/graph_var_manager.cc" | ||||
| "../graph/manager/rdma_pool_allocator.cc" | |||||
| "../graph/manager/trans_var_data_utils.cc" | "../graph/manager/trans_var_data_utils.cc" | ||||
| "../graph/manager/util/debug.cc" | "../graph/manager/util/debug.cc" | ||||
| "../hybrid/hybrid_davinci_model_stub.cc" | "../hybrid/hybrid_davinci_model_stub.cc" | ||||
| @@ -344,47 +344,19 @@ Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, u | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| Status ret = GraphExecutor::SetDynamicSize(model_id, dynamic_dims, static_cast<int32_t>(DYNAMIC_DIMS)); | |||||
| vector<uint64_t> cur_dynamic_dims; | |||||
| Status ret = GetCurDynamicDims(model_id, dynamic_dims, cur_dynamic_dims); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(FAILED, "Set dynamic size failed"); | |||||
| GELOGE(FAILED, "Set cur gear dynamic dims failed"); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| vector<uint64_t> cur_dynamic_dims; | |||||
| std::vector<ge::TensorDesc> input_desc; | |||||
| std::vector<ge::TensorDesc> output_desc; | |||||
| ret = GetModelDescInfo(model_id, input_desc, output_desc); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(FAILED, "GetModelDescInfo failed."); | |||||
| return FAILED; | |||||
| } | |||||
| vector<string> user_designate_shape_order; | |||||
| vector<int64_t> all_data_dims; | |||||
| ret = GetUserDesignateShapeOrder(model_id, user_designate_shape_order); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(FAILED, "GetUserDesignateShapeOrder failed."); | |||||
| return FAILED; | |||||
| } | |||||
| for (auto &data_name : user_designate_shape_order) { | |||||
| for (size_t j = 0; j < input_desc.size(); ++j) { | |||||
| if (input_desc.at(j).GetName() == data_name) { | |||||
| for (auto dim : input_desc.at(j).GetShape().GetDims()) { | |||||
| all_data_dims.push_back(dim); | |||||
| } | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (dynamic_dims.size() != all_data_dims.size()) { | |||||
| GELOGE(FAILED, "Dynamic input size [%lu] is not equal with all data dims size [%lu]!", dynamic_dims.size(), | |||||
| all_data_dims.size()); | |||||
| ret = GraphExecutor::SetDynamicSize(model_id, cur_dynamic_dims, static_cast<int32_t>(DYNAMIC_DIMS)); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Set dynamic size failed"); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| for (std::size_t i = 0; i < all_data_dims.size(); ++i) { | |||||
| if (all_data_dims[i] < 0) { | |||||
| cur_dynamic_dims.push_back(dynamic_dims[i]); | |||||
| } | |||||
| } | |||||
| size_t dynamic_dim_num = cur_dynamic_dims.size(); | size_t dynamic_dim_num = cur_dynamic_dims.size(); | ||||
| uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint64_t)); | uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint64_t)); | ||||
| if (length < dynamic_input_size) { | if (length < dynamic_input_size) { | ||||
| @@ -403,58 +375,43 @@ Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, u | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector<uint64_t> &combined_dims, | |||||
| Status GeExecutor::GetCurDynamicDims(uint32_t model_id, const vector<uint64_t> &dynamic_dims, | |||||
| vector<uint64_t> &cur_dynamic_dims) { | vector<uint64_t> &cur_dynamic_dims) { | ||||
| vector<vector<int64_t>> combined_batch; | |||||
| if (GraphExecutor::GetCombinedDynamicDims(model_id, combined_batch) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get combined dynamic dims info failed."); | |||||
| return FAILED; | |||||
| } | |||||
| if (combined_batch.empty()) { | |||||
| GELOGE(FAILED, "Combined dynamic dims is empty."); | |||||
| cur_dynamic_dims.clear(); | |||||
| vector<ge::TensorDesc> input_desc; | |||||
| vector<ge::TensorDesc> output_desc; | |||||
| auto ret = GetModelDescInfo(model_id, input_desc, output_desc); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(FAILED, "GetModelDescInfo failed."); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (combined_dims.size() != combined_batch[0].size()) { | |||||
| GELOGE(FAILED, "Input dynamic dims's dimension size[%zu] is different from model[%zu].", combined_dims.size(), | |||||
| combined_batch[0].size()); | |||||
| vector<string> user_designate_shape_order; | |||||
| vector<int64_t> all_data_dims; | |||||
| ret = GetUserDesignateShapeOrder(model_id, user_designate_shape_order); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(FAILED, "GetUserDesignateShapeOrder failed."); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| bool matched = false; | |||||
| size_t idx = 0; | |||||
| for (size_t i = 0; i < combined_batch.size(); i++) { | |||||
| bool is_match = true; | |||||
| for (size_t j = 0; j < combined_dims.size(); j++) { | |||||
| if (combined_dims[j] != static_cast<uint64_t>(combined_batch[i][j])) { | |||||
| is_match = false; | |||||
| for (auto &data_name : user_designate_shape_order) { | |||||
| for (auto &desc : input_desc) { | |||||
| if (desc.GetName() == data_name) { | |||||
| for (auto dim : desc.GetShape().GetDims()) { | |||||
| all_data_dims.push_back(dim); | |||||
| } | |||||
| break; | break; | ||||
| } | } | ||||
| } | } | ||||
| if (is_match) { | |||||
| idx = i; | |||||
| matched = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!matched) { | |||||
| GELOGE(FAILED, "Input dynamic dims can not match model."); | |||||
| return FAILED; | |||||
| } | } | ||||
| // batch_info save the dynamic info of combined_dims | |||||
| vector<vector<int64_t>> batch_info; | |||||
| int32_t dynamic_type = static_cast<int32_t>(FIXED); | |||||
| if (GraphExecutor::GetDynamicBatchInfo(model_id, batch_info, dynamic_type) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get dynamic input info failed."); | |||||
| if (dynamic_dims.size() != all_data_dims.size()) { | |||||
| GELOGE(FAILED, "Dynamic input size [%lu] is not equal with all data dims size [%lu]!", dynamic_dims.size(), | |||||
| all_data_dims.size()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| cur_dynamic_dims.clear(); | |||||
| for (size_t i = 0; i < batch_info[idx].size(); i++) { | |||||
| cur_dynamic_dims.emplace_back(static_cast<uint64_t>(batch_info[idx][i])); | |||||
| for (std::size_t i = 0; i < all_data_dims.size(); ++i) { | |||||
| if (all_data_dims[i] < 0) { | |||||
| cur_dynamic_dims.push_back(dynamic_dims[i]); | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
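The reworked GetCurDynamicDims no longer matches the input against the model's combined gear table; it simply keeps the entries of `dynamic_dims` whose positions are dynamic (-1) in the concatenated user-designated data shapes. A small worked example with illustrative values:

```cpp
// Illustrative only (values are hypothetical):
//   all_data_dims    = {-1, 3, -1, 224}   // concatenated dims of the designated inputs
//   dynamic_dims     = { 8, 3, 16, 224}   // dims supplied by the caller
// Positions 0 and 2 are dynamic, so
//   cur_dynamic_dims = { 8, 16 }
// and SetDynamicDims then copies 2 * sizeof(uint64_t) bytes to dynamic_input_addr.
```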
| @@ -924,13 +881,6 @@ Status GeExecutor::ExecModel(uint32_t model_id, void *stream, const ge::RunModel | |||||
| GELOGE(ret, "Get dynamic input info failed."); | GELOGE(ret, "Get dynamic input info failed."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (dynamic_type == static_cast<int32_t>(DYNAMIC_DIMS)) { | |||||
| ret = GraphExecutor::GetCombinedDynamicDims(model_id, batch_info); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(FAILED, "Get dynamic input info failed."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| if (!batch_info.empty()) { | if (!batch_info.empty()) { | ||||
| SetDynamicInputDataFlag(run_input_data, batch_info, input_data); | SetDynamicInputDataFlag(run_input_data, batch_info, input_data); | ||||
| } | } | ||||
| @@ -13,6 +13,7 @@ local_ge_executor_src_files := \ | |||||
| ../omm/csa_interact.cc \ | ../omm/csa_interact.cc \ | ||||
| ../graph/manager/graph_manager_utils.cc \ | ../graph/manager/graph_manager_utils.cc \ | ||||
| ../graph/manager/graph_var_manager.cc \ | ../graph/manager/graph_var_manager.cc \ | ||||
| ../graph/manager/rdma_pool_allocator.cc \ | |||||
| ../graph/manager/graph_mem_allocator.cc \ | ../graph/manager/graph_mem_allocator.cc \ | ||||
| ../graph/manager/graph_caching_allocator.cc \ | ../graph/manager/graph_caching_allocator.cc \ | ||||
| ../graph/manager/trans_var_data_utils.cc \ | ../graph/manager/trans_var_data_utils.cc \ | ||||
| @@ -63,6 +64,7 @@ local_ge_executor_src_files := \ | |||||
| local_ge_executor_c_include := \ | local_ge_executor_c_include := \ | ||||
| proto/insert_op.proto \ | proto/insert_op.proto \ | ||||
| proto/op_mapping_info.proto \ | proto/op_mapping_info.proto \ | ||||
| proto/dump_task.proto \ | |||||
| proto/ge_ir.proto \ | proto/ge_ir.proto \ | ||||
| proto/task.proto \ | proto/task.proto \ | ||||
| proto/om.proto \ | proto/om.proto \ | ||||
| @@ -59,6 +59,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||||
| generator/ge_generator.cc \ | generator/ge_generator.cc \ | ||||
| generator/generator_api.cc \ | generator/generator_api.cc \ | ||||
| graph/manager/graph_var_manager.cc \ | graph/manager/graph_var_manager.cc \ | ||||
| graph/manager/rdma_pool_allocator.cc \ | |||||
| graph/manager/graph_mem_allocator.cc \ | graph/manager/graph_mem_allocator.cc \ | ||||
| graph/manager/graph_caching_allocator.cc \ | graph/manager/graph_caching_allocator.cc \ | ||||
| @@ -66,6 +67,9 @@ BUILER_SRC_FILES := \ | |||||
| ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
| ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
| ANALYZER_SRC_FILES:= \ | |||||
| analyzer/analyzer.cc \ | |||||
| OMG_HOST_SRC_FILES := \ | OMG_HOST_SRC_FILES := \ | ||||
| model/ge_model.cc \ | model/ge_model.cc \ | ||||
| model/ge_root_model.cc \ | model/ge_root_model.cc \ | ||||
| @@ -103,6 +107,7 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
| graph/common/omg_util.cc \ | graph/common/omg_util.cc \ | ||||
| graph/common/bcast.cc \ | graph/common/bcast.cc \ | ||||
| graph/common/local_context.cc \ | |||||
| graph/passes/dimension_compute_pass.cc \ | graph/passes/dimension_compute_pass.cc \ | ||||
| graph/passes/dimension_adjust_pass.cc \ | graph/passes/dimension_adjust_pass.cc \ | ||||
| graph/passes/get_original_format_pass.cc \ | graph/passes/get_original_format_pass.cc \ | ||||
| @@ -260,6 +265,7 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
| proto/ge_ir.proto \ | proto/ge_ir.proto \ | ||||
| proto/fwk_adapter.proto \ | proto/fwk_adapter.proto \ | ||||
| proto/op_mapping_info.proto \ | proto/op_mapping_info.proto \ | ||||
| proto/dump_task.proto \ | |||||
| proto/tensorflow/attr_value.proto \ | proto/tensorflow/attr_value.proto \ | ||||
| proto/tensorflow/function.proto \ | proto/tensorflow/function.proto \ | ||||
| proto/tensorflow/graph.proto \ | proto/tensorflow/graph.proto \ | ||||
| @@ -284,6 +290,9 @@ COMMON_LOCAL_C_INCLUDES := \ | |||||
| third_party/protobuf/include \ | third_party/protobuf/include \ | ||||
| third_party/opencv/include \ | third_party/opencv/include \ | ||||
| ANALYZER_LOCAL_INCLUDES := \ | |||||
| $(TOPDIR)framework/domi/analyzer \ | |||||
| NEW_OMG_HOST_SRC_FILES := \ | NEW_OMG_HOST_SRC_FILES := \ | ||||
| graph/preprocess/insert_op/util_insert_aipp_op.cc \ | graph/preprocess/insert_op/util_insert_aipp_op.cc \ | ||||
| graph/preprocess/insert_op/ge_aipp_op.cc \ | graph/preprocess/insert_op/ge_aipp_op.cc \ | ||||
| @@ -348,6 +357,7 @@ LOCAL_CFLAGS += -g -O0 | |||||
| endif | endif | ||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | ||||
| LOCAL_C_INCLUDES += $(ANALYZER_LOCAL_INCLUDES) | |||||
| LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | LOCAL_SRC_FILES := $(COMMON_LOCAL_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(GRAPH_MANAGER_LOCAL_SRC_FILES) | LOCAL_SRC_FILES += $(GRAPH_MANAGER_LOCAL_SRC_FILES) | ||||
| @@ -355,6 +365,7 @@ LOCAL_SRC_FILES += $(OMG_HOST_SRC_FILES) | |||||
| LOCAL_SRC_FILES += $(OME_HOST_SRC_FILES) | LOCAL_SRC_FILES += $(OME_HOST_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(NEW_OME_DEVICE_SRC_FILES) | LOCAL_SRC_FILES += $(NEW_OME_DEVICE_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(BUILER_SRC_FILES) | LOCAL_SRC_FILES += $(BUILER_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(ANALYZER_SRC_FILES) | |||||
| LOCAL_STATIC_LIBRARIES := libge_memory \ | LOCAL_STATIC_LIBRARIES := libge_memory \ | ||||
| @@ -414,9 +425,11 @@ LOCAL_SRC_FILES += $(GRAPH_MANAGER_LOCAL_SRC_FILES) | |||||
| LOCAL_SRC_FILES += $(OMG_DEVICE_SRC_FILES) | LOCAL_SRC_FILES += $(OMG_DEVICE_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(OME_DEVICE_SRC_FILES) | LOCAL_SRC_FILES += $(OME_DEVICE_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(BUILER_SRC_FILES) | LOCAL_SRC_FILES += $(BUILER_SRC_FILES) | ||||
| LOCAL_SRC_FILES += $(ANALYZER_SRC_FILES) | |||||
| LOCAL_C_INCLUDES := $(DEVICE_LOCAL_C_INCLUDES) | LOCAL_C_INCLUDES := $(DEVICE_LOCAL_C_INCLUDES) | ||||
| LOCAL_C_INCLUDES += $(ANALYZER_LOCAL_INCLUDES) | |||||
| LOCAL_STATIC_LIBRARIES := libge_memory \ | LOCAL_STATIC_LIBRARIES := libge_memory \ | ||||
| @@ -19,10 +19,48 @@ | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "mmpa/mmpa_api.h" | |||||
| #include "register/op_kernel_registry.h" | #include "register/op_kernel_registry.h" | ||||
| #include "register/host_cpu_context.h" | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/ge/plugin_manager.h" | #include "common/ge/plugin_manager.h" | ||||
| #include "graph/utils/type_utils.h" | |||||
| #include "common/fp16_t.h" | |||||
| namespace { | |||||
| #define CREATE_OUTPUT_CASE(DTYPE, TYPE) \ | |||||
| case (DTYPE): { \ | |||||
| GeTensorPtr ge_tensor = nullptr; \ | |||||
| if (need_create_flag) { \ | |||||
| int64_t data_num = out_desc.GetShape().IsScalar() ? 1 : out_desc.GetShape().GetShapeSize(); \ | |||||
| std::unique_ptr<TYPE[]> buf(new (std::nothrow) TYPE[data_num]()); \ | |||||
| if (buf == nullptr) { \ | |||||
| GELOGE(MEMALLOC_FAILED, "New sizeof(T) * data_num(%zu) memory failed", \ | |||||
| static_cast<size_t>(sizeof(TYPE) * data_num)); \ | |||||
| return MEMALLOC_FAILED; \ | |||||
| } \ | |||||
| ge_tensor = MakeShared<GeTensor>(out_desc); \ | |||||
| GE_CHECK_NOTNULL(ge_tensor); \ | |||||
| GELOGI("node:%s allocate output %zu, size=%lld", op_desc->GetName().c_str(), i, data_num * sizeof(TYPE)); \ | |||||
| ge_tensor->SetData(reinterpret_cast<uint8_t *>(buf.get()), data_num * sizeof(TYPE)); \ | |||||
| ge_tensor->MutableTensorDesc().SetDataType(out_desc.GetDataType()); \ | |||||
| ge_tensor->MutableTensorDesc().SetShape(out_desc.GetShape()); \ | |||||
| outputs.emplace_back(ge_tensor); \ | |||||
| } else { \ | |||||
| ge_tensor = outputs[i]; \ | |||||
| GE_CHECK_NOTNULL(ge_tensor); \ | |||||
| GELOGI("node:%s existed output %zu, addr=%p, size=%lld", op_desc->GetName().c_str(), i, \ | |||||
| reinterpret_cast<const uint8_t *>(ge_tensor->GetData().data()), ge_tensor->GetData().size()); \ | |||||
| } \ | |||||
| auto tensor = TensorAdapter::AsTensor(*ge_tensor); \ | |||||
| auto tensor_name = op_desc->GetOutputNameByIndex(i); \ | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", \ | |||||
| op_desc->GetName().c_str(), i); \ | |||||
| GELOGD("Successfully inserted output tensor. node = %s, index = %zu, output name = %s, addr = %p, size = %zu", \ | |||||
| op_desc->GetName().c_str(), i, tensor_name.c_str(), tensor.GetData(), tensor.GetSize()); \ | |||||
| named_outputs.emplace(tensor_name, tensor); \ | |||||
| break; \ | |||||
| } | |||||
| } // namespace | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| @@ -105,17 +143,32 @@ Status HostCpuEngine::PrepareInputs(const ge::ConstOpDescPtr &op_desc, const vec | |||||
| Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc, vector<GeTensorPtr> &outputs, | Status HostCpuEngine::PrepareOutputs(const ge::ConstOpDescPtr &op_desc, vector<GeTensorPtr> &outputs, | ||||
| map<std::string, Tensor> &named_outputs) { | map<std::string, Tensor> &named_outputs) { | ||||
| if (!outputs.empty() && (outputs.size() != op_desc->GetOutputsSize())) { | |||||
| GELOGW("size of outputs does not match, size of outputs = %zu, expected output_num = %zu.", outputs.size(), | |||||
| op_desc->GetOutputsSize()); | |||||
| outputs.clear(); | |||||
| } | |||||
| bool need_create_flag = (outputs.size() != op_desc->GetOutputsSize()); | |||||
| for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | ||||
| auto ge_tensor = MakeShared<GeTensor>(op_desc->GetOutputDesc(i)); | |||||
| GE_CHECK_NOTNULL(ge_tensor); | |||||
| outputs.emplace_back(ge_tensor); | |||||
| auto tensor = TensorAdapter::AsTensor(*ge_tensor); | |||||
| auto tensor_name = op_desc->GetOutputNameByIndex(i); | |||||
| GE_RETURN_WITH_LOG_IF_TRUE(tensor_name.empty(), "Failed to get output name. node = %s, index = %zu", | |||||
| op_desc->GetName().c_str(), i); | |||||
| GELOGD("Successfully inserted output tensor. node = %s, index = %zu, output name = %s", op_desc->GetName().c_str(), | |||||
| i, tensor_name.c_str()); | |||||
| named_outputs.emplace(tensor_name, tensor); | |||||
| const auto &out_desc = op_desc->GetOutputDesc(i); | |||||
| switch (out_desc.GetDataType()) { | |||||
| CREATE_OUTPUT_CASE(DT_BOOL, bool) | |||||
| CREATE_OUTPUT_CASE(DT_INT8, int8_t) | |||||
| CREATE_OUTPUT_CASE(DT_INT16, int16_t) | |||||
| CREATE_OUTPUT_CASE(DT_INT32, int32_t) | |||||
| CREATE_OUTPUT_CASE(DT_INT64, int64_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT8, uint8_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT16, uint16_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT32, uint32_t) | |||||
| CREATE_OUTPUT_CASE(DT_UINT64, uint64_t) | |||||
| CREATE_OUTPUT_CASE(DT_FLOAT16, fp16_t) | |||||
| CREATE_OUTPUT_CASE(DT_FLOAT, float) | |||||
| CREATE_OUTPUT_CASE(DT_DOUBLE, double) | |||||
| default: | |||||
| GELOGE(PARAM_INVALID, "data type %s not support.", | |||||
| TypeUtils::DataTypeToSerialString(out_desc.GetDataType()).c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
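PrepareOutputs now dispatches on each output's data type: when the caller did not pass a matching set of pre-allocated outputs, CREATE_OUTPUT_CASE allocates a zero-initialized buffer of the right element type; otherwise the existing tensor is reused. A sizing sketch for one output (shape and type are illustrative):

```cpp
// For a DT_FLOAT output with shape {2, 3} (illustrative values):
//   data_num = out_desc.GetShape().GetShapeSize() = 6   (scalar shapes fall back to 1)
//   buffer   = new (std::nothrow) float[6]()            // 24 zero-initialized bytes
// The buffer is handed to GeTensor::SetData and published in named_outputs
// under op_desc->GetOutputNameByIndex(i).
```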
| @@ -146,6 +199,7 @@ Status HostCpuEngine::Run(NodePtr &node, const vector<ConstGeTensorPtr> &inputs, | |||||
| std::map<std::string, const Tensor> named_inputs; | std::map<std::string, const Tensor> named_inputs; | ||||
| std::vector<GeTensorPtr> tmp_outputs; | std::vector<GeTensorPtr> tmp_outputs; | ||||
| tmp_outputs.swap(outputs); | |||||
| std::map<std::string, Tensor> named_outputs; | std::map<std::string, Tensor> named_outputs; | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHK_STATUS_RET_NOLOG(PrepareInputs(op_desc, inputs, named_inputs)); | GE_CHK_STATUS_RET_NOLOG(PrepareInputs(op_desc, inputs, named_inputs)); | ||||
| @@ -233,6 +287,15 @@ Status HostCpuEngine::LoadLib(const std::string &lib_path) { | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| auto initialize = (Status(*)(const HostCpuContext &))dlsym(handle, "Initialize"); | |||||
| if (initialize != nullptr) { | |||||
| GELOGI("Invoke function Initialize in lib: %s", lib_path.c_str()); | |||||
| if (initialize(HostCpuContext()) != SUCCESS) { | |||||
| GELOGW("Failed to invoke function Initialize in lib: %s", lib_path.c_str()); | |||||
| } | |||||
| } | |||||
| GELOGI("Lib: %s has been opened", lib_path.c_str()); | |||||
| lib_handles_.emplace_back(handle); | lib_handles_.emplace_back(handle); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
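LoadLib now also looks up an optional `Initialize` symbol in each host-CPU kernel library and calls it with a default-constructed HostCpuContext, only warning if the call does not succeed, so libraries without the hook keep working. A sketch of the hook such a library could export; the exact signature is an assumption derived from the dlsym cast above:

```cpp
// Optional per-library entry point a host-CPU kernel .so could provide.
// extern "C" keeps the symbol name unmangled so dlsym(handle, "Initialize") finds it.
extern "C" ge::Status Initialize(const ge::HostCpuContext &context) {
  // one-time setup for this library would go here
  return ge::SUCCESS;
}
```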
| @@ -247,4 +310,4 @@ Status HostCpuEngine::GetRealPath(std::string &path) { | |||||
| path = real_path; | path = real_path; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | |||||
| } // namespace ge | |||||
| @@ -42,6 +42,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/build/stream_graph_optimizer.cc \ | graph/build/stream_graph_optimizer.cc \ | ||||
| graph/build/task_generator.cc \ | graph/build/task_generator.cc \ | ||||
| graph/common/bcast.cc \ | graph/common/bcast.cc \ | ||||
| graph/common/local_context.cc \ | |||||
| graph/common/omg_util.cc \ | graph/common/omg_util.cc \ | ||||
| graph/common/transop_util.cc \ | graph/common/transop_util.cc \ | ||||
| graph/execute/graph_execute.cc \ | graph/execute/graph_execute.cc \ | ||||
| @@ -88,6 +89,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/manager/graph_mem_allocator.cc \ | graph/manager/graph_mem_allocator.cc \ | ||||
| graph/manager/graph_caching_allocator.cc \ | graph/manager/graph_caching_allocator.cc \ | ||||
| graph/manager/graph_var_manager.cc \ | graph/manager/graph_var_manager.cc \ | ||||
| graph/manager/rdma_pool_allocator.cc \ | |||||
| graph/manager/model_manager/event_manager.cc \ | graph/manager/model_manager/event_manager.cc \ | ||||
| graph/manager/trans_var_data_utils.cc \ | graph/manager/trans_var_data_utils.cc \ | ||||
| graph/manager/util/debug.cc \ | graph/manager/util/debug.cc \ | ||||
| @@ -289,6 +291,7 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| hybrid/node_executor/task_context.cc \ | hybrid/node_executor/task_context.cc \ | ||||
| hybrid/hybrid_davinci_model.cc \ | hybrid/hybrid_davinci_model.cc \ | ||||
| executor/ge_executor.cc \ | executor/ge_executor.cc \ | ||||
| analyzer/analyzer.cc \ | |||||
| LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
| proto/ge_api.proto \ | proto/ge_api.proto \ | ||||
| @@ -308,11 +311,13 @@ RUNNER_LOCAL_C_INCLUDES := \ | |||||
| $(TOPDIR)inc/runtime \ | $(TOPDIR)inc/runtime \ | ||||
| $(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
| $(TOPDIR)ops/built-in/op_proto/inc \ | $(TOPDIR)ops/built-in/op_proto/inc \ | ||||
| $(TOPDIR)framework/domi/analyzer \ | |||||
| proto/fwk_adapter.proto \ | proto/fwk_adapter.proto \ | ||||
| proto/ge_ir.proto \ | proto/ge_ir.proto \ | ||||
| proto/insert_op.proto \ | proto/insert_op.proto \ | ||||
| proto/om.proto \ | proto/om.proto \ | ||||
| proto/op_mapping_info.proto \ | proto/op_mapping_info.proto \ | ||||
| proto/dump_task.proto \ | |||||
| proto/task.proto \ | proto/task.proto \ | ||||
| proto/tensorflow/attr_value.proto \ | proto/tensorflow/attr_value.proto \ | ||||
| proto/tensorflow/function.proto \ | proto/tensorflow/function.proto \ | ||||
| @@ -75,7 +75,8 @@ bool AicpuTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| flag = rtMemcpy(ext_info_, ext_size, reinterpret_cast<void *>(ext_info.data()), ext_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| flag = rtMemcpy(ext_info_, ext_size, const_cast<void *>(reinterpret_cast<const void *>(ext_info.data())), ext_size, | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (flag != RT_ERROR_NONE) { | if (flag != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); | GELOGE(RT_FAILED, "Call rt api(rtMemCpy) failed, ret: 0x%X.", flag); | ||||
| return false; | return false; | ||||
| @@ -15,6 +15,9 @@ | |||||
| */ | */ | ||||
| #include "generator/ge_generator.h" | #include "generator/ge_generator.h" | ||||
| #include <atomic> | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/ge/plugin_manager.h" | #include "common/ge/plugin_manager.h" | ||||
| #include "common/helper/model_helper.h" | #include "common/helper/model_helper.h" | ||||
| @@ -212,6 +215,9 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||||
| class GeGenerator::Impl { | class GeGenerator::Impl { | ||||
| public: | public: | ||||
| Impl(OmgContext &omg_context) : omg_context_(omg_context), graph_manager_(omg_context) {} | |||||
| ~Impl() = default; | |||||
| Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models); | Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GeRootModelPtr &ge_models); | ||||
| Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); | Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); | ||||
| @@ -221,10 +227,14 @@ class GeGenerator::Impl { | |||||
| Status GenerateInfershapeGraph(const Graph &graph); | Status GenerateInfershapeGraph(const Graph &graph); | ||||
| OmgContext &omg_context_; | |||||
| GraphManager graph_manager_; | GraphManager graph_manager_; | ||||
| SaveParam save_param_; | SaveParam save_param_; | ||||
| bool is_offline_ = true; | bool is_offline_ = true; | ||||
| bool is_singleop_unregistered_ = false; | bool is_singleop_unregistered_ = false; | ||||
| std::string build_mode_; | |||||
| std::string build_step_; | |||||
| static std::mutex mutex_; | |||||
| private: | private: | ||||
| static std::string Trim(const std::string &str); | static std::string Trim(const std::string &str); | ||||
| @@ -234,8 +244,10 @@ class GeGenerator::Impl { | |||||
| bool SetOppVersionInfo(AttrHolder &obj); | bool SetOppVersionInfo(AttrHolder &obj); | ||||
| }; | }; | ||||
| Status GeGenerator::Initialize(const map<string, string> &options) { | |||||
| impl_ = ge::MakeShared<Impl>(); | |||||
| Status GeGenerator::Initialize(const map<string, string> &options) { return Initialize(options, domi::GetContext()); } | |||||
| Status GeGenerator::Initialize(const map<string, string> &options, OmgContext &omg_context) { | |||||
| impl_ = ge::MakeShared<Impl>(omg_context); | |||||
| if (impl_ == nullptr) { | if (impl_ == nullptr) { | ||||
| GELOGE(MEMALLOC_FAILED, "Make shared failed"); | GELOGE(MEMALLOC_FAILED, "Make shared failed"); | ||||
| return MEMALLOC_FAILED; | return MEMALLOC_FAILED; | ||||
| @@ -273,6 +285,17 @@ Status GeGenerator::Initialize(const map<string, string> &options) { | |||||
| if (iter != options.end()) { | if (iter != options.end()) { | ||||
| impl_->save_param_.pri_key_file = iter->second; | impl_->save_param_.pri_key_file = iter->second; | ||||
| } | } | ||||
| // get build mode | |||||
| iter = options.find(BUILD_MODE); | |||||
| if (iter != options.end()) { | |||||
| impl_->build_mode_ = iter->second; | |||||
| } | |||||
| // get build step | |||||
| iter = options.find(BUILD_STEP); | |||||
| if (iter != options.end()) { | |||||
| impl_->build_step_ = iter->second; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
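The new Initialize overload above reads BUILD_MODE and BUILD_STEP from the options map. A hedged usage fragment, assuming the BUILD_* constants live in namespace ge; the helper name is ours, not part of this change.

    #include <map>
    #include <string>
    #include "generator/ge_generator.h"
    #include "omg/omg_inner_types.h"

    // Sketch: initialize a GeGenerator for the tuning flow handled later in GenerateModel.
    ge::Status InitGeneratorForTuning(ge::GeGenerator &generator) {
      std::map<std::string, std::string> options;
      options[ge::BUILD_MODE] = ge::BUILD_MODE_TUNING;          // GenerateModel will skip SaveModel
      options[ge::BUILD_STEP] = ge::BUILD_STEP_AFTER_BUILDER;   // one of the steps checked there
      return generator.Initialize(options, domi::GetContext()); // overload added in this change
    }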
| @@ -312,6 +335,8 @@ Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| std::mutex GeGenerator::Impl::mutex_; | |||||
| // Remove the space and tab before and after the string | // Remove the space and tab before and after the string | ||||
| std::string GeGenerator::Impl::Trim(const std::string &str) { | std::string GeGenerator::Impl::Trim(const std::string &str) { | ||||
| if (str.empty()) { | if (str.empty()) { | ||||
| @@ -436,8 +461,7 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| auto rt = rtCtxGetCurrent(&ctx); | auto rt = rtCtxGetCurrent(&ctx); | ||||
| if (rt != RT_ERROR_NONE) { | if (rt != RT_ERROR_NONE) { | ||||
| GELOGW("Current ctx is null."); | GELOGW("Current ctx is null."); | ||||
| } else { | |||||
| ge::RtContextUtil::GetInstance().SetNormalModeContext(ctx); | |||||
| ctx = nullptr; | |||||
| } | } | ||||
| GeRootModelPtr ge_root_model = nullptr; | GeRootModelPtr ge_root_model = nullptr; | ||||
| @@ -451,6 +475,17 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| /// BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH does not need to save the model; |||||
| /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER does not need to save the model; |||||
| /// BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB does not need to save the model. |||||
| if ((impl_->build_mode_ == BUILD_MODE_TUNING) && | |||||
| (impl_->build_step_ == BUILD_STEP_BEFORE_UB_MATCH || impl_->build_step_ == BUILD_STEP_AFTER_BUILDER || | |||||
| impl_->build_step_ == BUILD_STEP_AFTER_BUILDER_SUB)) { | |||||
| GELOGI("Build mode:%s with step:%s no need SaveModel.", impl_->build_mode_.c_str(), impl_->build_step_.c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| GE_CHECK_NOTNULL(ge_root_model); | GE_CHECK_NOTNULL(ge_root_model); | ||||
| GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | ||||
| ModelHelper model_helper; | ModelHelper model_helper; | ||||
| @@ -474,8 +509,8 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| if (RtContextUtil::GetInstance().GetNormalModeContext() != nullptr) { | |||||
| (void)rtCtxSetCurrent(RtContextUtil::GetInstance().GetNormalModeContext()); | |||||
| if (ctx != nullptr) { | |||||
| (void)rtCtxSetCurrent(ctx); | |||||
| } | } | ||||
| GELOGI("GenerateOfflineModel success."); | GELOGI("GenerateOfflineModel success."); | ||||
| @@ -495,7 +530,8 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| domi::GetContext().is_dynamic_input = ContainsDynamicInpus(*op_desc); | |||||
| OmgContext &omg_context = (impl_ == nullptr) ? domi::GetContext() : impl_->omg_context_; | |||||
| omg_context.is_dynamic_input = ContainsDynamicInpus(*op_desc); | |||||
| if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { | if (op_desc->HasAttr(ATTR_NAME_UNREGST_OPPATH)) { | ||||
| impl_->is_singleop_unregistered_ = true; | impl_->is_singleop_unregistered_ = true; | ||||
| @@ -633,35 +669,32 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr & | |||||
| Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, | Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, | ||||
| GeRootModelPtr &ge_root_model) { | GeRootModelPtr &ge_root_model) { | ||||
| static GraphId id = 0; | |||||
| static std::atomic<GraphId> atomic_graph_id(0); | |||||
| auto graph_id = atomic_graph_id.fetch_add(1); | |||||
| const std::map<std::string, std::string> options; | const std::map<std::string, std::string> options; | ||||
| Status ret = graph_manager_.AddGraph(id, graph, options); | |||||
| Status ret = graph_manager_.AddGraph(graph_id, graph, options); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", id); | |||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph fail, graph id: %u", graph_id); | |||||
| (void)graph_manager_.Finalize(); | (void)graph_manager_.Finalize(); | ||||
| return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | ||||
| } | } | ||||
| GELOGI("Model inputs size is %zu", inputs.size()); | GELOGI("Model inputs size is %zu", inputs.size()); | ||||
| graph_manager_.SetOptionsRunGraphFlag(false); | graph_manager_.SetOptionsRunGraphFlag(false); | ||||
| struct timeval tv; | |||||
| if (gettimeofday(&tv, nullptr) != 0) { | |||||
| GELOGE(INTERNAL_ERROR, "get the time of day failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| uint64_t session_id = static_cast<uint64_t>(tv.tv_sec * 1000000 + tv.tv_usec); // 1000000us | |||||
| static std::atomic<uint64_t> atomic_session_id(0); | |||||
| auto session_id = atomic_session_id.fetch_add(1); | |||||
| if (is_singleop_unregistered_) { | if (is_singleop_unregistered_) { | ||||
| ret = graph_manager_.BuildGraphForUnregisteredOp(id, inputs, ge_root_model, session_id); | |||||
| ret = graph_manager_.BuildGraphForUnregisteredOp(graph_id, inputs, ge_root_model, session_id); | |||||
| } else { | } else { | ||||
| ret = graph_manager_.BuildGraph(id, inputs, ge_root_model, session_id); | |||||
| ret = graph_manager_.BuildGraph(graph_id, inputs, ge_root_model, session_id); | |||||
| } | } | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", id); | |||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph fail, graph id: %u", graph_id); | |||||
| VarManagerPool::Instance().RemoveVarManager(session_id); | VarManagerPool::Instance().RemoveVarManager(session_id); | ||||
| return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | ||||
| } | } | ||||
| id += 1; | |||||
| VarManagerPool::Instance().RemoveVarManager(session_id); | VarManagerPool::Instance().RemoveVarManager(session_id); | ||||
| @@ -669,21 +702,21 @@ Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> | |||||
| } | } | ||||
| Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) { | Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph) { | ||||
| static GraphId id = 0; | |||||
| static std::atomic<GraphId> atomic_graph_id(0); | |||||
| auto graph_id = atomic_graph_id.fetch_add(1); | |||||
| const std::map<std::string, std::string> options; | const std::map<std::string, std::string> options; | ||||
| Status ret = graph_manager_.AddGraph(id, graph, options); | |||||
| Status ret = graph_manager_.AddGraph(graph_id, graph, options); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", id); | |||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, graph id: %u", graph_id); | |||||
| (void)graph_manager_.Finalize(); | (void)graph_manager_.Finalize(); | ||||
| return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | ||||
| } | } | ||||
| ret = graph_manager_.GenerateInfershapeGraph(id); | |||||
| ret = graph_manager_.GenerateInfershapeGraph(graph_id); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed"); | GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager generate graph failed"); | ||||
| return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | ||||
| } | } | ||||
| id += 1; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
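Both functions above replace a plain static counter (and, in BuildModel, a gettimeofday-based session id) with std::atomic counters. A self-contained sketch of the pattern, with GraphId assumed to be an unsigned integer type:

    #include <atomic>
    #include <cstdint>

    using GraphId = uint32_t;  // assumption for this sketch

    // Each caller gets a distinct id even when several threads build graphs concurrently,
    // unlike the removed `static GraphId id = 0; ... id += 1;` sequence.
    GraphId NextGraphId() {
      static std::atomic<GraphId> counter(0);
      return counter.fetch_add(1);
    }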
| @@ -63,7 +63,7 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { | |||||
| std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); | std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); | ||||
| if (kernel_lib_name.empty()) { | if (kernel_lib_name.empty()) { | ||||
| // reset op kernel lib | // reset op kernel lib | ||||
| (void)instance_ptr->DNNEngineManagerObj().GetDNNEngineName(node_ptr->GetOpDesc()); | |||||
| (void)instance_ptr->DNNEngineManagerObj().GetDNNEngineName(node_ptr); | |||||
| kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); | kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); | ||||
| if (kernel_lib_name.empty()) { | if (kernel_lib_name.empty()) { | ||||
| GELOGE(INTERNAL_ERROR, "Get node:%s(%s) kernel lib failed.", node_ptr->GetName().c_str(), | GELOGE(INTERNAL_ERROR, "Get node:%s(%s) kernel lib failed.", node_ptr->GetName().c_str(), | ||||
| @@ -84,6 +84,7 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { | |||||
| GELOGE(ret, "Calculate op running param failed, node name is %s", node_ptr->GetName().c_str()); | GELOGE(ret, "Calculate op running param failed, node name is %s", node_ptr->GetName().c_str()); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(AddOutputMemTypeForNode(node_ptr)); | |||||
| } else { | } else { | ||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node_ptr->GetName().c_str()); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "Get op %s ops kernel info store failed", node_ptr->GetName().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -497,4 +498,24 @@ Status GraphBuilder::SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge: | |||||
| GE_TIMESTAMP_END(GraphPartition2, "GraphPartitioner::Partition2"); | GE_TIMESTAMP_END(GraphPartition2, "GraphPartitioner::Partition2"); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| Status GraphBuilder::AddOutputMemTypeForNode(const NodePtr &node) { | |||||
| int64_t mem_type; | |||||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_INPUT_MEMORY_TYPE, mem_type)) { | |||||
| GELOGD("[%s] has attr input_memory_type %ld", node->GetName().c_str(), mem_type); | |||||
| for (const auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
| const auto &peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | |||||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | |||||
| const auto &src_node = peer_out_anchor->GetOwnerNode(); | |||||
| const auto &src_op = src_node->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(src_op == nullptr, continue); | |||||
| if (!AttrUtils::SetInt(src_op, ATTR_OUTPUT_MEMORY_TYPE, mem_type)) { | |||||
| GELOGE(INTERNAL_ERROR, "Set out_memory_type attr failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
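AddOutputMemTypeForNode above copies a node's ATTR_INPUT_MEMORY_TYPE onto the producer of its first connected input as ATTR_OUTPUT_MEMORY_TYPE. A hedged, stripped-down illustration of that relationship; the helper name and the producer/consumer placeholders are ours.

    // Sketch: tag one producer so both ends of the edge agree on the memory pool.
    void TagProducerMemType(const ge::NodePtr &consumer, const ge::NodePtr &producer) {
      int64_t mem_type = 0;
      if (ge::AttrUtils::GetInt(consumer->GetOpDesc(), ge::ATTR_INPUT_MEMORY_TYPE, mem_type)) {
        (void)ge::AttrUtils::SetInt(producer->GetOpDesc(), ge::ATTR_OUTPUT_MEMORY_TYPE, mem_type);
      }
    }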
| @@ -67,6 +67,7 @@ class GraphBuilder { | |||||
| GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); | GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); | ||||
| Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, | Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, | ||||
| uint64_t session_id = INVALID_SESSION_ID); | uint64_t session_id = INVALID_SESSION_ID); | ||||
| Status AddOutputMemTypeForNode(const NodePtr &node); | |||||
| Status BuildForHostCpuGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, | Status BuildForHostCpuGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, | ||||
| uint64_t session_id = INVALID_SESSION_ID); | uint64_t session_id = INVALID_SESSION_ID); | ||||
| int build_mode_; | int build_mode_; | ||||
| @@ -24,7 +24,9 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class BinaryBlockMemAssigner : public BlockMemAssigner { | class BinaryBlockMemAssigner : public BlockMemAssigner { | ||||
| public: | public: | ||||
| explicit BinaryBlockMemAssigner(ge::ComputeGraphPtr compute_graph) : BlockMemAssigner(std::move(compute_graph)) {} | |||||
| BinaryBlockMemAssigner(ComputeGraphPtr compute_graph, const std::map<std::string, std::string> &anchor_to_symbol, | |||||
| const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors) | |||||
| : BlockMemAssigner(std::move(compute_graph), anchor_to_symbol, symbol_to_anchors) {} | |||||
| BinaryBlockMemAssigner(const BinaryBlockMemAssigner &) = delete; | BinaryBlockMemAssigner(const BinaryBlockMemAssigner &) = delete; | ||||
| @@ -32,10 +32,12 @@ | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/optimize/common/params.h" | #include "graph/optimize/common/params.h" | ||||
| #include "omg/omg_inner_types.h" | #include "omg/omg_inner_types.h" | ||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| using std::list; | |||||
| using std::map; | using std::map; | ||||
| using std::pair; | using std::pair; | ||||
| using std::set; | using std::set; | ||||
| @@ -402,8 +404,13 @@ string MemoryBlock::String() { | |||||
| return ss.str(); | return ss.str(); | ||||
| } | } | ||||
| BlockMemAssigner::BlockMemAssigner(ge::ComputeGraphPtr compute_graph) | |||||
| : mem_offset_(0), compute_graph_(std::move(compute_graph)), life_time_(0) {} | |||||
| BlockMemAssigner::BlockMemAssigner(ComputeGraphPtr compute_graph, const map<string, string> &anchor_to_symbol, | |||||
| const map<string, list<NodeIndexIO>> &symbol_to_anchors) | |||||
| : mem_offset_(0), | |||||
| compute_graph_(std::move(compute_graph)), | |||||
| symbol_to_anchors_(symbol_to_anchors), | |||||
| anchor_to_symbol_(anchor_to_symbol), | |||||
| life_time_(0) {} | |||||
| BlockMemAssigner::~BlockMemAssigner() { | BlockMemAssigner::~BlockMemAssigner() { | ||||
| for (MemoryBlock *memory_block : memory_blocks_) { | for (MemoryBlock *memory_block : memory_blocks_) { | ||||
| @@ -412,11 +419,6 @@ BlockMemAssigner::~BlockMemAssigner() { | |||||
| } | } | ||||
| void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | void BlockMemAssigner::GetOutAndWorkSpaceMem(vector<int64_t> &all_memory_size) { | ||||
| if (GraphUtils::GetRefMapping(compute_graph_, symbol_to_anchors_, anchor_to_symbol_) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Get ref-mapping for graph %s failed.", compute_graph_->GetName().c_str()); | |||||
| return; | |||||
| } | |||||
| vector<int64_t> temp; | vector<int64_t> temp; | ||||
| for (const NodePtr &n : compute_graph_->GetAllNodes()) { | for (const NodePtr &n : compute_graph_->GetAllNodes()) { | ||||
| auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
| @@ -692,13 +694,16 @@ bool BlockMemAssigner::IsPostReuse(const MemoryBlock *mem_block) const { | |||||
| /// @ingroup GE | /// @ingroup GE | ||||
| /// @brief check if symbol of cur node_index_io has block | /// @brief check if symbol of cur node_index_io has block | ||||
| /// @param [in] node_index_io | /// @param [in] node_index_io | ||||
| /// @param [out] symbol | |||||
| /// @return bool | /// @return bool | ||||
| /// | /// | ||||
| bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io) { | |||||
| bool BlockMemAssigner::IsSymbolExist(const NodeIndexIO &node_index_io, string &symbol) { | |||||
| auto iter = anchor_to_symbol_.find(node_index_io.ToString()); | auto iter = anchor_to_symbol_.find(node_index_io.ToString()); | ||||
| if (iter == anchor_to_symbol_.end()) { | if (iter == anchor_to_symbol_.end()) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| symbol = iter->second; | |||||
| return symbol_blocks_.find(iter->second) != symbol_blocks_.end(); | return symbol_blocks_.find(iter->second) != symbol_blocks_.end(); | ||||
| } | } | ||||
| @@ -883,8 +888,8 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetNoAlignSize(*node_op_desc, index, no_align_size) != SUCCESS, return nullptr, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetNoAlignSize(*node_op_desc, index, no_align_size) != SUCCESS, return nullptr, | ||||
| "Get no align size failed"); | "Get no align size failed"); | ||||
| if (IsSymbolExist(node_index_io)) { | |||||
| const std::string &symbol = anchor_to_symbol_[node_index_io.ToString()]; | |||||
| std::string symbol; | |||||
| if (IsSymbolExist(node_index_io, symbol)) { | |||||
| block = symbol_blocks_[symbol]; | block = symbol_blocks_[symbol]; | ||||
| block->AddNodeTypeIndex({n, kOutput, index, true}, size, no_align_size); | block->AddNodeTypeIndex({n, kOutput, index, true}, size, no_align_size); | ||||
| block->ref_count_++; | block->ref_count_++; | ||||
| @@ -949,8 +954,8 @@ bool IsOutputBlock(const ge::InDataAnchorPtr &in_data_anchor) { | |||||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, GELOGE(FAILED, "Peer out anchor is nullptr."); return false); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, GELOGE(FAILED, "Peer out anchor is nullptr."); return false); | ||||
| auto src = peer_out_anchor->GetOwnerNode(); | auto src = peer_out_anchor->GetOwnerNode(); | ||||
| int32_t index = peer_out_anchor->GetIdx(); | int32_t index = peer_out_anchor->GetIdx(); | ||||
| auto iter = domi::GetContext().out_nodes_map.find(src->GetName()); | |||||
| if (iter != domi::GetContext().out_nodes_map.end()) { | |||||
| auto iter = GetLocalOmgContext().out_nodes_map.find(src->GetName()); | |||||
| if (iter != GetLocalOmgContext().out_nodes_map.end()) { | |||||
| for (auto id : iter->second) { | for (auto id : iter->second) { | ||||
| if (index == id) { | if (index == id) { | ||||
| return true; | return true; | ||||
| @@ -159,7 +159,8 @@ class MemoryBlock { | |||||
| class BlockMemAssigner : public MemAssigner { | class BlockMemAssigner : public MemAssigner { | ||||
| public: | public: | ||||
| explicit BlockMemAssigner(ge::ComputeGraphPtr compute_graph); | |||||
| BlockMemAssigner(ComputeGraphPtr compute_graph, const std::map<std::string, std::string> &anchor_to_symbol, | |||||
| const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors); | |||||
| BlockMemAssigner(const BlockMemAssigner &) = delete; | BlockMemAssigner(const BlockMemAssigner &) = delete; | ||||
| @@ -241,9 +242,10 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// @ingroup GE | /// @ingroup GE | ||||
| /// @brief check if symbol of cur node_index_io has block | /// @brief check if symbol of cur node_index_io has block | ||||
| /// @param [in] node_index_io | /// @param [in] node_index_io | ||||
| /// @param [out] symbol | |||||
| /// @return bool | /// @return bool | ||||
| /// | /// | ||||
| bool IsSymbolExist(const NodeIndexIO &node_index_io); | |||||
| bool IsSymbolExist(const NodeIndexIO &node_index_io, std::string &symbol); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -261,8 +263,8 @@ class BlockMemAssigner : public MemAssigner { | |||||
| std::vector<NodeTypeIndex> zero_memory_list_; | std::vector<NodeTypeIndex> zero_memory_list_; | ||||
| // ref mapping | // ref mapping | ||||
| std::map<std::string, std::list<NodeIndexIO>> symbol_to_anchors_; | |||||
| std::map<std::string, std::string> anchor_to_symbol_; | |||||
| const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors_; | |||||
| const std::map<std::string, std::string> &anchor_to_symbol_; | |||||
| std::map<std::string, bool> pre_reuse_flag_; | std::map<std::string, bool> pre_reuse_flag_; | ||||
| std::map<std::string, bool> post_reuse_flag_; | std::map<std::string, bool> post_reuse_flag_; | ||||
| std::map<std::string, size_t> symbol_size_; | std::map<std::string, size_t> symbol_size_; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include <cstring> | #include <cstring> | ||||
| #include <set> | #include <set> | ||||
| #include "common/math/math_util.h" | #include "common/math/math_util.h" | ||||
| #include "common/util/error_manager/error_manager.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/build/memory/hybrid_mem_assigner.h" | #include "graph/build/memory/hybrid_mem_assigner.h" | ||||
| #include "graph/build/memory/var_mem_assign_util.h" | #include "graph/build/memory/var_mem_assign_util.h" | ||||
| @@ -226,6 +227,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, size_t &mem_offse | |||||
| if (mem_offset > VarManager::Instance(session_id)->GetGraphMemoryMaxSize()) { | if (mem_offset > VarManager::Instance(session_id)->GetGraphMemoryMaxSize()) { | ||||
| GELOGE(ge::FAILED, "Current memoffset %zu is greater than memory manager malloc max size %zu", mem_offset, | GELOGE(ge::FAILED, "Current memoffset %zu is greater than memory manager malloc max size %zu", mem_offset, | ||||
| VarManager::Instance(session_id)->GetGraphMemoryMaxSize()); | VarManager::Instance(session_id)->GetGraphMemoryMaxSize()); | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19022"); | |||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -41,10 +41,17 @@ Status HybridMemAssigner::AssignMemory(std::unique_ptr<BlockMemAssigner> &block_ | |||||
| } | } | ||||
| Status HybridMemAssigner::Assign() { | Status HybridMemAssigner::Assign() { | ||||
| std::unique_ptr<BlockMemAssigner> binary_assigner(new (std::nothrow) BinaryBlockMemAssigner(compute_graph_)); | |||||
| if (GraphUtils::GetRefMapping(compute_graph_, symbol_to_anchors_, anchor_to_symbol_) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Get ref-mapping for graph %s failed.", compute_graph_->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| std::unique_ptr<BlockMemAssigner> binary_assigner( | |||||
| new (std::nothrow) BinaryBlockMemAssigner(compute_graph_, anchor_to_symbol_, symbol_to_anchors_)); | |||||
| GE_CHECK_NOTNULL(binary_assigner); | GE_CHECK_NOTNULL(binary_assigner); | ||||
| std::unique_ptr<BlockMemAssigner> max_assigner(new (std::nothrow) MaxBlockMemAssigner(compute_graph_)); | |||||
| std::unique_ptr<BlockMemAssigner> max_assigner( | |||||
| new (std::nothrow) MaxBlockMemAssigner(compute_graph_, anchor_to_symbol_, symbol_to_anchors_)); | |||||
| GE_CHECK_NOTNULL(max_assigner); | GE_CHECK_NOTNULL(max_assigner); | ||||
| size_t bin_mem_size = 0; | size_t bin_mem_size = 0; | ||||
| @@ -54,6 +54,9 @@ class HybridMemAssigner : public MemAssigner { | |||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| BlockMemAssignerPtr priority_assigner_; | BlockMemAssignerPtr priority_assigner_; | ||||
| std::map<std::string, std::string> anchor_to_symbol_; | |||||
| std::map<std::string, std::list<NodeIndexIO>> symbol_to_anchors_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_BUILD_MEMORY_HYBRID_MEM_ASSIGNER_H_ | #endif // GE_GRAPH_BUILD_MEMORY_HYBRID_MEM_ASSIGNER_H_ | ||||
| @@ -23,7 +23,9 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class MaxBlockMemAssigner : public BlockMemAssigner { | class MaxBlockMemAssigner : public BlockMemAssigner { | ||||
| public: | public: | ||||
| explicit MaxBlockMemAssigner(ge::ComputeGraphPtr compute_graph) : BlockMemAssigner(std::move(compute_graph)) {} | |||||
| MaxBlockMemAssigner(ComputeGraphPtr compute_graph, const std::map<std::string, std::string> &anchor_to_symbol, | |||||
| const std::map<std::string, std::list<NodeIndexIO>> &symbol_to_anchors) | |||||
| : BlockMemAssigner(std::move(compute_graph), anchor_to_symbol, symbol_to_anchors) {} | |||||
| MaxBlockMemAssigner(const MaxBlockMemAssigner &) = delete; | MaxBlockMemAssigner(const MaxBlockMemAssigner &) = delete; | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "graph/build/stream_allocator.h" | #include "graph/build/stream_allocator.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| @@ -244,7 +245,7 @@ Status ModelBuilder::SetInputOutputDesc() { | |||||
| } | } | ||||
| // if user set input node format ND, the expected node for data and netoutput format is ND in | // if user set input node format ND, the expected node for data and netoutput format is ND in | ||||
| // final graph. | // final graph. | ||||
| if ((domi::GetContext().format == domi::DOMI_TENSOR_ND) && (!node_op_desc->HasAttr("_is_single_op")) && | |||||
| if ((GetLocalOmgContext().format == domi::DOMI_TENSOR_ND) && (!node_op_desc->HasAttr("_is_single_op")) && | |||||
| ((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { | ((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { | ||||
| GELOGI("The node [%s] format should be set ND.", node_op_desc->GetName().c_str()); | GELOGI("The node [%s] format should be set ND.", node_op_desc->GetName().c_str()); | ||||
| auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); | auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); | ||||
| @@ -406,7 +407,7 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { | |||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, zero_copy_mem_size_), | GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, zero_copy_mem_size_), | ||||
| GELOGE(FAILED, "SetInt of ATTR_MODEL_ZERO_COPY_MEMORY_SIZE failed."); | GELOGE(FAILED, "SetInt of ATTR_MODEL_ZERO_COPY_MEMORY_SIZE failed."); | ||||
| return FAILED); | return FAILED); | ||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(&model, ATTR_MODEL_OUT_NODES_NAME, domi::GetContext().net_out_nodes), | |||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(&model, ATTR_MODEL_OUT_NODES_NAME, GetLocalOmgContext().net_out_nodes), | |||||
| GELOGE(FAILED, "SetListStr of ATTR_MODEL_OUT_NODES_NAME failed."); | GELOGE(FAILED, "SetListStr of ATTR_MODEL_OUT_NODES_NAME failed."); | ||||
| return FAILED); | return FAILED); | ||||
| GELOGI("For model, max_mem_offset_: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, zero_copy_mem_size_); | GELOGI("For model, max_mem_offset_: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, zero_copy_mem_size_); | ||||
| @@ -571,26 +572,59 @@ Status ModelBuilder::SaveDataToModel(ge::Model &model, ge::GeModel &ge_model) { | |||||
| // Add weight | // Add weight | ||||
| ge_model.SetWeight(weight_buffer_); | ge_model.SetWeight(weight_buffer_); | ||||
| // Add TBE Kernels | |||||
| std::set<std::string> name_set; | |||||
| // Add TBE Kernels and custom aicpu op bin | |||||
| std::set<std::string> tbe_name_set; | |||||
| std::set<std::string> aicpu_name_set; | |||||
| for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | ||||
| auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | ||||
| TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | TBEKernelPtr tbe_kernel = node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, TBEKernelPtr()); | ||||
| if (tbe_kernel == nullptr) { | |||||
| std::string kernel_name; | |||||
| GeAttrValue::BYTES kernel_buffer; | |||||
| (void)AttrUtils::GetStr(node_op_desc, ATTR_NAME_TBE_KERNEL_NAME, kernel_name); | |||||
| (void)AttrUtils::GetBytes(node_op_desc, ATTR_NAME_TBE_KERNEL_BUFFER, kernel_buffer); | |||||
| if (!kernel_name.empty() && (kernel_buffer.GetSize() > 0)) { | |||||
| GE_CHECK_NOTNULL(kernel_buffer.GetData()); | |||||
| std::vector<char> data(kernel_buffer.GetData(), kernel_buffer.GetData() + kernel_buffer.GetSize()); | |||||
| tbe_kernel = std::make_shared<OpKernelBin>(kernel_name, std::move(data)); | |||||
| } | |||||
| } | |||||
| GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | GE_IF_BOOL_EXEC(tbe_kernel == nullptr, continue); | ||||
| if (name_set.count(tbe_kernel->GetName()) > 0) { | |||||
| if (tbe_name_set.count(tbe_kernel->GetName()) > 0) { | |||||
| GELOGE(FAILED, "tbe_kernel name %s can't be the same", tbe_kernel->GetName().c_str()); | GELOGE(FAILED, "tbe_kernel name %s can't be the same", tbe_kernel->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| name_set.insert(tbe_kernel->GetName()); | |||||
| tbe_name_set.insert(tbe_kernel->GetName()); | |||||
| tbe_kernel_store_.AddTBEKernel(tbe_kernel); | tbe_kernel_store_.AddTBEKernel(tbe_kernel); | ||||
| GELOGD("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||||
| GELOGI("Add tbe kernel bin %s", tbe_kernel->GetName().c_str()); | |||||
| } | |||||
| for (const ge::NodePtr &n : compute_graph_->GetNodes(compute_graph_->GetGraphUnknownFlag())) { | |||||
| auto node_op_desc = n->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | |||||
| CustAICPUKernelPtr cust_aicpu_kernel = | |||||
| node_op_desc->TryGetExtAttr(ge::OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr()); | |||||
| GE_IF_BOOL_EXEC(cust_aicpu_kernel == nullptr, continue); | |||||
| if (aicpu_name_set.count(cust_aicpu_kernel->GetName()) > 0) { | |||||
| GELOGE(FAILED, "aicpu_kernel name %s can't be the same", cust_aicpu_kernel->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| aicpu_name_set.insert(cust_aicpu_kernel->GetName()); | |||||
| cust_aicpu_kernel_store_.AddCustAICPUKernel(cust_aicpu_kernel); | |||||
| GELOGI("Add cust aicpu kernel bin %s", cust_aicpu_kernel->GetName().c_str()); | |||||
| } | } | ||||
| if (!tbe_kernel_store_.Build()) { | if (!tbe_kernel_store_.Build()) { | ||||
| GELOGE(FAILED, "TBE Kernels store build failed!"); | GELOGE(FAILED, "TBE Kernels store build failed!"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (!cust_aicpu_kernel_store_.Build()) { | |||||
| GELOGE(FAILED, "custom AICPU kernels store build failed!"); | |||||
| return FAILED; | |||||
| } | |||||
| ge_model.SetTBEKernelStore(tbe_kernel_store_); | ge_model.SetTBEKernelStore(tbe_kernel_store_); | ||||
| ge_model.SetCustAICPUKernelStore(cust_aicpu_kernel_store_); | |||||
| // Add task | // Add task | ||||
| GeAttrValue::BYTES task_def_bytes; | GeAttrValue::BYTES task_def_bytes; | ||||
| @@ -744,7 +778,7 @@ Status ModelBuilder::CompileSingleOp() { | |||||
| string kernel_lib_name = op_desc->GetOpKernelLibName(); | string kernel_lib_name = op_desc->GetOpKernelLibName(); | ||||
| if (kernel_lib_name.empty()) { | if (kernel_lib_name.empty()) { | ||||
| // Reset op kernel lib | // Reset op kernel lib | ||||
| (void)instance->DNNEngineManagerObj().GetDNNEngineName(op_desc); | |||||
| (void)instance->DNNEngineManagerObj().GetDNNEngineName(node); | |||||
| kernel_lib_name = op_desc->GetOpKernelLibName(); | kernel_lib_name = op_desc->GetOpKernelLibName(); | ||||
| if (kernel_lib_name.empty()) { | if (kernel_lib_name.empty()) { | ||||
| GELOGE(ge::INTERNAL_ERROR, "Get node:%s(%s) kernel lib failed.", node->GetName().c_str(), | GELOGE(ge::INTERNAL_ERROR, "Get node:%s(%s) kernel lib failed.", node->GetName().c_str(), | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "common/op/ge_op_utils.h" | #include "common/op/ge_op_utils.h" | ||||
| #include "common/tbe_kernel_store.h" | #include "common/tbe_kernel_store.h" | ||||
| #include "common/cust_aicpu_kernel_store.h" | |||||
| #include "common/types.h" | #include "common/types.h" | ||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| @@ -108,6 +109,7 @@ class ModelBuilder { | |||||
| size_t zero_copy_mem_size_; | size_t zero_copy_mem_size_; | ||||
| TBEKernelStore tbe_kernel_store_; | TBEKernelStore tbe_kernel_store_; | ||||
| CustAICPUKernelStore cust_aicpu_kernel_store_; | |||||
| uint8_t platform_type_; | uint8_t platform_type_; | ||||
| bool is_loop_graph_; | bool is_loop_graph_; | ||||
| @@ -15,8 +15,8 @@ | |||||
| */ | */ | ||||
| #include "graph/build/stream_allocator.h" | #include "graph/build/stream_allocator.h" | ||||
| #include <memory> | |||||
| #include <algorithm> | #include <algorithm> | ||||
| #include <memory> | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
| @@ -1062,12 +1062,12 @@ Status StreamAllocator::SetActiveStreamsForLoop() { | |||||
| GELOGI("there are %zu next iterator target streams has streamswitch node.", streams_skip_iterator_event.size()); | GELOGI("there are %zu next iterator target streams has streamswitch node.", streams_skip_iterator_event.size()); | ||||
| for (auto iter : stream_id_to_last_node) { | for (auto iter : stream_id_to_last_node) { | ||||
| if (streams_skip_iterator_event.find(iter.first) != streams_skip_iterator_event.end()) { | if (streams_skip_iterator_event.find(iter.first) != streams_skip_iterator_event.end()) { | ||||
| GELOGI("skip stream %ld which has streamswitch node when add event to next iterator active node", | |||||
| GELOGI("Skip stream %ld which has streamswitch node when adding event to next iterator active node", | |||||
| iter.first); | iter.first); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (iter.second->GetOwnerComputeGraph()->GetParentGraph() != nullptr) { | if (iter.second->GetOwnerComputeGraph()->GetParentGraph() != nullptr) { | ||||
| GELOGI("skip stream %ld which last node in subgraph when add event to next iterator active node", | |||||
| GELOGI("Skip stream %ld which is last node in subgraph when adding event to next iterator active node", | |||||
| iter.first); | iter.first); | ||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -1264,15 +1264,6 @@ void StreamAllocator::DumpEvents() { | |||||
| } | } | ||||
| Status StreamAllocator::GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count) { | Status StreamAllocator::GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count) { | ||||
| const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); | |||||
| if (buffer_optimize_on != nullptr) { | |||||
| rtError_t ret = rtSetPlatformType(PLATFORM_MINI_V1); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "Get max stream and task count by rts failed."); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| uint32_t stream_type = RT_NORMAL_STREAM; | uint32_t stream_type = RT_NORMAL_STREAM; | ||||
| if (huge_stream) { | if (huge_stream) { | ||||
| stream_type = RT_HUGE_STREAM; | stream_type = RT_HUGE_STREAM; | ||||
| @@ -102,12 +102,9 @@ Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &com | |||||
| continue; | continue; | ||||
| } | } | ||||
| const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); | |||||
| if (buffer_optimize_on == nullptr) { | |||||
| if (!IsSameStreamId(subgraph)) { | |||||
| GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| if (!IsSameStreamId(subgraph)) { | |||||
| GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); | |||||
| continue; | |||||
| } | } | ||||
| OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); | OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| @@ -31,6 +31,8 @@ | |||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "graph/ge_local_context.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| using domi::LogTimeStampDef; | using domi::LogTimeStampDef; | ||||
| using domi::ModelTaskDef; | using domi::ModelTaskDef; | ||||
| @@ -527,7 +529,7 @@ Status TaskGenerator::MarkNodeAndSetIndex(ComputeGraphPtr &graph) { | |||||
| // Reset op kernel lib name | // Reset op kernel lib name | ||||
| if (op_desc->GetOpKernelLibName().empty()) { | if (op_desc->GetOpKernelLibName().empty()) { | ||||
| (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(op_desc); | |||||
| (void)ge_lib->DNNEngineManagerObj().GetDNNEngineName(node); | |||||
| } | } | ||||
| all_stream_ops[op_desc->GetStreamId()].emplace_back(op_desc); | all_stream_ops[op_desc->GetStreamId()].emplace_back(op_desc); | ||||
| @@ -762,24 +764,26 @@ Status TaskGenerator::FindBpOfEnv(const ComputeGraphPtr &graph, const std::strin | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | |||||
| vector<uint32_t> &all_reduce_nodes) const { | |||||
| GELOGI("Start FindProfilingTaskIndex."); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| const char *profiling_mode = std::getenv(kProfilingMode); | |||||
| bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn(); | |||||
| if (!is_profiling) { | |||||
| Status TaskGenerator::GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | |||||
| vector<uint32_t> &all_reduce_nodes, std::string &fp_point_str, | |||||
| std::string &bp_point_str) const { | |||||
| if (ge::GetContext().GetOption(OPTION_EXEC_PROFILING_FPPONIT_OPTIONS, fp_point_str) == SUCCESS && | |||||
| ge::GetContext().GetOption(OPTION_EXEC_PROFILING_BPPONIT_OPTIONS, bp_point_str) == SUCCESS && | |||||
| !fp_point_str.empty() && !bp_point_str.empty()) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ret = SUCCESS; | |||||
| const char *fp_point = std::getenv(kProfilingFpPoint); | const char *fp_point = std::getenv(kProfilingFpPoint); | ||||
| Status ret; | |||||
| if (fp_point == nullptr) { | if (fp_point == nullptr) { | ||||
| ret = AutoFindFpOpIndex(graph, profiling_point); | ret = AutoFindFpOpIndex(graph, profiling_point); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); | GELOGW("First forward profiling op_index not set and FindFpOpIndex failed."); | ||||
| return SUCCESS; | |||||
| return FAILED; | |||||
| } | } | ||||
| } else { | |||||
| fp_point_str = string(fp_point); | |||||
| GELOGI("Get fp_point_str from env %s", fp_point_str.c_str()); | |||||
| } | } | ||||
| const char *bp_point = std::getenv(kProfilingBpPoint); | const char *bp_point = std::getenv(kProfilingBpPoint); | ||||
| @@ -787,20 +791,47 @@ Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, Profi | |||||
| ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); | ret = AutoFindBpOpIndex(graph, profiling_point, all_reduce_nodes); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); | GELOGW("Last backward profiling op_index not set and FindBpOpIndex failed."); | ||||
| return SUCCESS; | |||||
| return FAILED; | |||||
| } | } | ||||
| } else { | |||||
| bp_point_str = string(bp_point); | |||||
| GELOGI("Get bp_point_str from env %s", bp_point_str.c_str()); | |||||
| } | } | ||||
| if (fp_point != nullptr) { | |||||
| string fp_point_str = string(fp_point); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status TaskGenerator::FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | |||||
| vector<uint32_t> &all_reduce_nodes) const { | |||||
| GELOGI("Start FindProfilingTaskIndex."); | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| const char *profiling_mode = std::getenv(kProfilingMode); | |||||
| bool is_profiling = (profiling_mode != nullptr) || ProfilingManager::Instance().ProfilingOn(); | |||||
| if (!is_profiling) { | |||||
| GELOGW("Profiling is not open."); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGI("Start get FP/BP index."); | |||||
| std::string fp_point_str; | |||||
| std::string bp_point_str; | |||||
| Status ret = GetFpBpIndex(graph, profiling_point, all_reduce_nodes, fp_point_str, bp_point_str); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("Get FP_POINT BP_POINT failed."); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGI("fp_point_str:%s, bp_point_str:%s.", fp_point_str.c_str(), bp_point_str.c_str()); | |||||
| if (!fp_point_str.empty()) { | |||||
| ret = FindFpOfEnv(graph, fp_point_str, profiling_point); | ret = FindFpOfEnv(graph, fp_point_str, profiling_point); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("First backward profiling op name set but FindFpOfEnv failed."); | GELOGW("First backward profiling op name set but FindFpOfEnv failed."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } | } | ||||
| if (bp_point != nullptr) { | |||||
| string bp_point_str = string(bp_point); | |||||
| if (!bp_point_str.empty()) { | |||||
| ret = FindBpOfEnv(graph, bp_point_str, profiling_point, all_reduce_nodes); | ret = FindBpOfEnv(graph, bp_point_str, profiling_point, all_reduce_nodes); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("Last backward profiling op name set but FindBpOfEnv failed."); | GELOGW("Last backward profiling op name set but FindBpOfEnv failed."); | ||||
| @@ -118,6 +118,9 @@ class TaskGenerator { | |||||
| Status FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, ProfilingPoint &profiling_point, | Status FindBpOfEnv(const ComputeGraphPtr &graph, const std::string &bp_point_str, ProfilingPoint &profiling_point, | ||||
| vector<uint32_t> &all_reduce_nodes) const; | vector<uint32_t> &all_reduce_nodes) const; | ||||
| Status GetFpBpIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, vector<uint32_t> &all_reduce_nodes, | |||||
| std::string &fp_point_str, std::string &bp_point_str) const; | |||||
| Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | Status FindProfilingTaskIndex(const ComputeGraphPtr &graph, ProfilingPoint &profiling_point, | ||||
| std::vector<uint32_t> &all_reduce_nodes) const; | std::vector<uint32_t> &all_reduce_nodes) const; | ||||
| Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point, | Status InsertProfilingTaskBefore(const OpDescPtr &op_desc, const ProfilingPoint &profiling_point, | ||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/common/local_context.h" | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/debug/ge_log.h" | |||||
| #include "omg/omg_inner_types.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| thread_local OmgContext *omg_context = nullptr; | |||||
| } | |||||
| void SetLocalOmgContext(OmgContext &context) { omg_context = &context; } | |||||
| OmgContext &GetLocalOmgContext() { | |||||
| if (omg_context != nullptr) { | |||||
| return *omg_context; | |||||
| } else { | |||||
| GELOGW("omg_context is nullptr."); | |||||
| return domi::GetContext(); | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,26 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_COMMON_LOCAL_CONTEXT_H_ | |||||
| #define GE_GRAPH_COMMON_LOCAL_CONTEXT_H_ | |||||
| #include "omg/omg_inner_types.h" | |||||
| namespace ge { | |||||
| void SetLocalOmgContext(OmgContext &context); | |||||
| OmgContext &GetLocalOmgContext(); | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_COMMON_LOCAL_CONTEXT_H_ | |||||
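The new local_context files give each thread its own OmgContext, which call sites in this change (ModelBuilder, BlockMemAssigner, and others) now read through GetLocalOmgContext() instead of domi::GetContext(). A hedged usage sketch; the helper name is ours, the ge:: qualification of OmgContext is assumed, and the context object must outlive the thread that registers it.

    #include "graph/common/local_context.h"
    #include "omg/omg_inner_types.h"

    // Sketch: install a per-thread context before building, then read it back.
    void PrepareBuildThread(ge::OmgContext &context) {
      context.format = domi::DOMI_TENSOR_ND;               // example field consulted in ModelBuilder
      ge::SetLocalOmgContext(context);                      // stores a thread_local pointer
      ge::OmgContext &current = ge::GetLocalOmgContext();   // returns `context` on this thread
      (void)current;
    }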
| @@ -121,70 +121,50 @@ Status GraphLoader::GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size) { | |||||
| Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string &key_path, int32_t priority, | Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string &key_path, int32_t priority, | ||||
| ModelData &model_data) { | ModelData &model_data) { | ||||
| Status ret; | Status ret; | ||||
| try { | |||||
| if (!CheckInputPathValid(path)) { | |||||
| GELOGE(GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str()); | |||||
| return GE_EXEC_MODEL_PATH_INVALID; | |||||
| } | |||||
| GELOGI("Load model begin, model path is: %s", path.c_str()); | |||||
| if (!key_path.empty() && !CheckInputPathValid(key_path)) { | |||||
| GELOGE(GE_EXEC_MODEL_KEY_PATH_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); | |||||
| return GE_EXEC_MODEL_KEY_PATH_INVALID; | |||||
| } | |||||
| ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret); | |||||
| return ret; | |||||
| } | |||||
| if (!CheckInputPathValid(path)) { | |||||
| GELOGE(GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str()); | |||||
| return GE_EXEC_MODEL_PATH_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } catch (std::bad_alloc &) { | |||||
| GELOGE(MEMALLOC_FAILED, "Load model from file failed, bad memory allocation"); | |||||
| ret = MEMALLOC_FAILED; | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Load model from file failed with exception"); | |||||
| ret = FAILED; | |||||
| GELOGI("Load model begin, model path is: %s", path.c_str()); | |||||
| if (!key_path.empty() && !CheckInputPathValid(key_path)) { | |||||
| GELOGE(GE_EXEC_MODEL_KEY_PATH_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); | |||||
| return GE_EXEC_MODEL_KEY_PATH_INVALID; | |||||
| } | } | ||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| ret = DavinciModelParser::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | |||||
| return ret; | |||||
| } | } | ||||
| return ret; | |||||
| return SUCCESS; | |||||
| } | } | ||||
| Status GraphLoader::LoadModelFromFile(const std::string &path, const std::string &key_path, int32_t priority, | Status GraphLoader::LoadModelFromFile(const std::string &path, const std::string &key_path, int32_t priority, | ||||
| const std::shared_ptr<ModelListener> &listener, uint32_t &model_id) { | const std::shared_ptr<ModelListener> &listener, uint32_t &model_id) { | ||||
| Status ret; | Status ret; | ||||
| ModelData model_data; | ModelData model_data; | ||||
| try { | |||||
| ret = LoadDataFromFile(path, key_path, priority, model_data); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | |||||
| return ret; | |||||
| ret = LoadDataFromFile(path, key_path, priority, model_data); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | } | ||||
| return ret; | |||||
| } | |||||
| ret = LoadModel(model_data, listener, model_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | |||||
| ret = LoadModel(model_data, listener, model_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); | |||||
| if (model_data.model_data != nullptr) { | |||||
| delete[] static_cast<char *>(model_data.model_data); | |||||
| model_data.model_data = nullptr; | |||||
| } | } | ||||
| } catch (std::bad_alloc &) { | |||||
| GELOGE(MEMALLOC_FAILED, "Load model from file failed, bad memory allocation"); | |||||
| ret = MEMALLOC_FAILED; | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Load model from file failed with exception"); | |||||
| ret = FAILED; | |||||
| } | } | ||||
| if (model_data.model_data != nullptr) { | if (model_data.model_data != nullptr) { | ||||
| @@ -197,36 +177,27 @@ Status GraphLoader::LoadModelFromFile(const std::string &path, const std::string | |||||
| Status GraphLoader::LoadModel(const ModelData &model_data, const std::shared_ptr<ModelListener> &listener, | Status GraphLoader::LoadModel(const ModelData &model_data, const std::shared_ptr<ModelListener> &listener, | ||||
| uint32_t &model_id) { | uint32_t &model_id) { | ||||
| try { | |||||
| GELOGI("Load model begin, model_id:%u.", model_id); | |||||
| GELOGI("Load model begin, model_id:%u.", model_id); | |||||
| // For GeOp, Open Device 0 here. | |||||
| GE_CHK_RT_RET(rtSetDevice(0)); | |||||
| auto model_manager = ModelManager::GetInstance(); | |||||
| GE_CHECK_NOTNULL(model_manager); | |||||
| Status ret = model_manager->LoadModelOffline(model_id, model_data, listener); | |||||
| if (ret != SUCCESS) { | |||||
| GE_CHK_RT(rtDeviceReset(0)); | |||||
| GELOGE(ret, "LoadModel: Load failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = model_manager->Start(model_id); | |||||
| if (ret != SUCCESS) { | |||||
| if (model_manager->Unload(model_id) != SUCCESS) { | |||||
| GELOGE(FAILED, "LoadModel: Unload failed while trying to unload after a failed start."); | |||||
| } | |||||
| GELOGE(ret, "LoadModel: Start failed."); | |||||
| return ret; | |||||
| // For GeOp, Open Device 0 here. | |||||
| GE_CHK_RT_RET(rtSetDevice(0)); | |||||
| auto model_manager = ModelManager::GetInstance(); | |||||
| GE_CHECK_NOTNULL(model_manager); | |||||
| Status ret = model_manager->LoadModelOffline(model_id, model_data, listener); | |||||
| if (ret != SUCCESS) { | |||||
| GE_CHK_RT(rtDeviceReset(0)); | |||||
| GELOGE(ret, "LoadModel: Load failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = model_manager->Start(model_id); | |||||
| if (ret != SUCCESS) { | |||||
| if (model_manager->Unload(model_id) != SUCCESS) { | |||||
| GELOGE(FAILED, "LoadModel: Unload failed while trying to unload after a failed start."); | |||||
| } | } | ||||
| GELOGI("LoadModel: Start model success, model_id:%u.", model_id); | |||||
| } catch (std::bad_alloc &) { | |||||
| GELOGE(MEMALLOC_FAILED, "Load model failed, bad memory allocation occur !"); | |||||
| return MEMALLOC_FAILED; | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Load model failed, some exceptions occur !"); | |||||
| return FAILED; | |||||
| GELOGE(ret, "LoadModel: Start failed."); | |||||
| return ret; | |||||
| } | } | ||||
| GELOGI("LoadModel: Start model success, model_id:%u.", model_id); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -255,28 +226,16 @@ Status GraphLoader::CommandHandle(const Command &command) { | |||||
| Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, size_t memsize, | Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model_data, void *dev_ptr, size_t memsize, | ||||
| void *weight_ptr, size_t weightsize) { | void *weight_ptr, size_t weightsize) { | ||||
| try { | |||||
| GELOGI("Load model begin, model_id:%u.", model_id); | |||||
| // For ACL, Open Device from App. | |||||
| auto model_manager = ModelManager::GetInstance(); | |||||
| GE_CHECK_NOTNULL(model_manager); | |||||
| Status ret = | |||||
| model_manager->LoadModelOffline(model_id, model_data, nullptr, dev_ptr, memsize, weight_ptr, weightsize); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Load model failed, model_id:%u.", model_id); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("Load model success, model_id:%u.", model_id); | |||||
| } catch (std::bad_alloc &) { | |||||
| GELOGE(MEMALLOC_FAILED, "Load model failed, bad memory allocation occur !"); | |||||
| return MEMALLOC_FAILED; | |||||
| } catch (...) { | |||||
| GELOGE(FAILED, "Load model failed, some exceptions occur !"); | |||||
| return FAILED; | |||||
| GELOGI("Load model begin, model_id:%u.", model_id); | |||||
| // For ACL, Open Device from App. | |||||
| auto model_manager = ModelManager::GetInstance(); | |||||
| GE_CHECK_NOTNULL(model_manager); | |||||
| Status ret = model_manager->LoadModelOffline(model_id, model_data, nullptr, dev_ptr, memsize, weight_ptr, weightsize); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Load model failed, model_id:%u.", model_id); | |||||
| return ret; | |||||
| } | } | ||||
| GELOGI("Load model success, model_id:%u.", model_id); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -16,21 +16,28 @@ | |||||
| #include "graph/load/new_model_manager/data_dumper.h" | #include "graph/load/new_model_manager/data_dumper.h" | ||||
| #include <sys/time.h> | |||||
| #include <cstdlib> | |||||
| #include <ctime> | #include <ctime> | ||||
| #include <map> | #include <map> | ||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "common/debug/memory_dumper.h" | |||||
| #include "common/properties_manager.h" | #include "common/properties_manager.h" | ||||
| #include "common/util.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/load/new_model_manager/model_utils.h" | #include "graph/load/new_model_manager/model_utils.h" | ||||
| #include "graph/manager/util/debug.h" | |||||
| #include "graph/utils/attr_utils.h" | #include "graph/utils/attr_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "proto/dump_task.pb.h" | |||||
| #include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
| #include "proto/op_mapping_info.pb.h" | #include "proto/op_mapping_info.pb.h" | ||||
| #include "runtime/base.h" | |||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| namespace { | namespace { | ||||
| @@ -66,6 +73,16 @@ static bool ParseNameIndex(const std::string &node_name_index, std::string &node | |||||
| static bool IsTensorDescWithSkipDumpAddrType(bool has_mem_type_attr, vector<int64_t> v_memory_type, size_t i) { | static bool IsTensorDescWithSkipDumpAddrType(bool has_mem_type_attr, vector<int64_t> v_memory_type, size_t i) { | ||||
| return has_mem_type_attr && (v_memory_type[i] == RT_MEMORY_L1); | return has_mem_type_attr && (v_memory_type[i] == RT_MEMORY_L1); | ||||
| } | } | ||||
| static uint64_t GetNowTime() { | |||||
| uint64_t ret = 0; | |||||
| struct timeval tv; | |||||
| if (gettimeofday(&tv, NULL) == 0) { | |||||
| ret = tv.tv_sec * 1000000ULL + tv.tv_usec; | |||||
| } | |||||
| return ret; | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| static int32_t GetIrDataType(ge::DataType data_type) { | static int32_t GetIrDataType(ge::DataType data_type) { | ||||
| @@ -176,6 +193,7 @@ void DataDumper::SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr | |||||
| GELOGD("Start SaveDumpOpInfo of task_id: %u, stream_id: %u", task_id, stream_id); | GELOGD("Start SaveDumpOpInfo of task_id: %u, stream_id: %u", task_id, stream_id); | ||||
| OpDescInfo op_desc_info; | OpDescInfo op_desc_info; | ||||
| op_desc_info.op_name = op->GetName(); | op_desc_info.op_name = op->GetName(); | ||||
| op_desc_info.op_type = op->GetType(); | |||||
| op_desc_info.task_id = task_id; | op_desc_info.task_id = task_id; | ||||
| op_desc_info.stream_id = stream_id; | op_desc_info.stream_id = stream_id; | ||||
| for (size_t i = 0; i < op->GetInputsSize(); ++i) { | for (size_t i = 0; i < op->GetInputsSize(); ++i) { | ||||
| @@ -183,12 +201,28 @@ void DataDumper::SaveDumpOpInfo(const RuntimeParam &model_param, const OpDescPtr | |||||
| op_desc_info.input_format.emplace_back(input_desc.GetFormat()); | op_desc_info.input_format.emplace_back(input_desc.GetFormat()); | ||||
| op_desc_info.input_shape.emplace_back(input_desc.GetShape().GetDims()); | op_desc_info.input_shape.emplace_back(input_desc.GetShape().GetDims()); | ||||
| op_desc_info.input_data_type.emplace_back(input_desc.GetDataType()); | op_desc_info.input_data_type.emplace_back(input_desc.GetDataType()); | ||||
| int64_t input_size = 0; | |||||
| auto tensor_descs = op->GetAllInputsDesc(); | |||||
| if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(i), input_size) != SUCCESS) { | |||||
| GELOGW("Get input size failed"); | |||||
| return; | |||||
| } | |||||
| GELOGI("Save dump op info, the input size is %ld", input_size); | |||||
| op_desc_info.input_size.emplace_back(input_size); | |||||
| } | } | ||||
| for (size_t j = 0; j < op->GetOutputsSize(); ++j) { | for (size_t j = 0; j < op->GetOutputsSize(); ++j) { | ||||
| GeTensorDesc output_desc = op->GetOutputDesc(j); | GeTensorDesc output_desc = op->GetOutputDesc(j); | ||||
| op_desc_info.output_format.emplace_back(output_desc.GetFormat()); | op_desc_info.output_format.emplace_back(output_desc.GetFormat()); | ||||
| op_desc_info.output_shape.emplace_back(output_desc.GetShape().GetDims()); | op_desc_info.output_shape.emplace_back(output_desc.GetShape().GetDims()); | ||||
| op_desc_info.output_data_type.emplace_back(output_desc.GetDataType()); | op_desc_info.output_data_type.emplace_back(output_desc.GetDataType()); | ||||
| int64_t output_size = 0; | |||||
| auto tensor_descs = op->GetAllOutputsDesc(); | |||||
| if (TensorUtils::GetTensorSizeInBytes(tensor_descs.at(j), output_size) != SUCCESS) { | |||||
| GELOGW("Get input size failed"); | |||||
| return; | |||||
| } | |||||
| GELOGI("Save dump op info, the output size is %ld", output_size); | |||||
| op_desc_info.output_size.emplace_back(output_size); | |||||
| } | } | ||||
| op_desc_info.input_addrs = ModelUtils::GetInputDataAddrs(model_param, op); | op_desc_info.input_addrs = ModelUtils::GetInputDataAddrs(model_param, op); | ||||
| op_desc_info.output_addrs = ModelUtils::GetOutputDataAddrs(model_param, op); | op_desc_info.output_addrs = ModelUtils::GetOutputDataAddrs(model_param, op); | ||||
| @@ -810,4 +844,90 @@ void DataDumper::PrintCheckLog(string &dump_list_key) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| Status DataDumper::DumpExceptionInput(const OpDescInfo &op_desc_info, const string &dump_file) { | |||||
| GELOGI("Start to dump exception input"); | |||||
| for (size_t i = 0; i < op_desc_info.input_addrs.size(); i++) { | |||||
| if (Debug::DumpDevMem(dump_file.data(), op_desc_info.input_addrs.at(i), op_desc_info.input_size.at(i)) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Dump the %zu input data failed", i); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataDumper::DumpExceptionOutput(const OpDescInfo &op_desc_info, const string &dump_file) { | |||||
| GELOGI("Start to dump exception output"); | |||||
| for (size_t i = 0; i < op_desc_info.output_addrs.size(); i++) { | |||||
| if (Debug::DumpDevMem(dump_file.data(), op_desc_info.output_addrs.at(i), op_desc_info.output_size.at(i)) != | |||||
| SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Dump the %zu input data failed", i); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DataDumper::DumpExceptionInfo(const std::vector<rtExceptionInfo> exception_infos) { | |||||
| GELOGI("Start to dump exception info"); | |||||
| for (const rtExceptionInfo &iter : exception_infos) { | |||||
| OpDescInfo op_desc_info; | |||||
| if (GetOpDescInfo(iter.streamid, iter.taskid, op_desc_info)) { | |||||
| toolkit::dumpdata::DumpData dump_data; | |||||
| dump_data.set_version("2.0"); | |||||
| dump_data.set_dump_time(GetNowTime()); | |||||
| for (size_t i = 0; i < op_desc_info.input_format.size(); ++i) { | |||||
| toolkit::dumpdata::OpInput input; | |||||
| input.set_data_type(toolkit::dumpdata::OutputDataType(GetIrDataType(op_desc_info.input_data_type[i]))); | |||||
| input.set_format(toolkit::dumpdata::OutputFormat(op_desc_info.input_format[i])); | |||||
| for (auto dim : op_desc_info.input_shape[i]) { | |||||
| input.mutable_shape()->add_dim(dim); | |||||
| } | |||||
| input.set_size(op_desc_info.input_size[i]); | |||||
| GELOGI("The input size int exception is %ld", op_desc_info.input_size[i]); | |||||
| dump_data.mutable_input()->Add(std::move(input)); | |||||
| } | |||||
| for (size_t j = 0; j < op_desc_info.output_format.size(); ++j) { | |||||
| toolkit::dumpdata::OpOutput output; | |||||
| output.set_data_type(toolkit::dumpdata::OutputDataType(GetIrDataType(op_desc_info.output_data_type[j]))); | |||||
| output.set_format(toolkit::dumpdata::OutputFormat(op_desc_info.output_format[j])); | |||||
| for (auto dim : op_desc_info.output_shape[j]) { | |||||
| output.mutable_shape()->add_dim(dim); | |||||
| } | |||||
| output.set_size(op_desc_info.output_size[j]); | |||||
| GELOGI("The output size int exception is %ld", op_desc_info.output_size[j]); | |||||
| dump_data.mutable_output()->Add(std::move(output)); | |||||
| } | |||||
| uint64_t now_time = GetNowTime(); | |||||
| string dump_file_path = "./" + op_desc_info.op_type + "." + op_desc_info.op_name + "." + | |||||
| to_string(op_desc_info.task_id) + "." + to_string(now_time); | |||||
| uint64_t proto_size = dump_data.ByteSizeLong(); | |||||
| unique_ptr<char[]> proto_msg(new (std::nothrow) char[proto_size]); | |||||
| bool ret = dump_data.SerializeToArray(proto_msg.get(), proto_size); | |||||
| if (!ret || proto_size == 0) { | |||||
| GELOGE(PARAM_INVALID, "Dump data proto serialize failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GE_CHK_STATUS_RET(MemoryDumper::DumpToFile(dump_file_path.c_str(), &proto_size, sizeof(uint64_t)), | |||||
| "Failed to dump proto size"); | |||||
| GE_CHK_STATUS_RET(MemoryDumper::DumpToFile(dump_file_path.c_str(), proto_msg.get(), proto_size), | |||||
| "Failed to dump proto msg"); | |||||
| if (DumpExceptionInput(op_desc_info, dump_file_path) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Dump exception input failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| if (DumpExceptionOutput(op_desc_info, dump_file_path) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Dump exception output failed"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GELOGI("Dump exception info SUCCESS"); | |||||
| } else { | |||||
| GELOGE(PARAM_INVALID, "Get op desc info failed,task id:%u,stream id:%u", iter.taskid, iter.streamid); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
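Editor's note on the file produced by DumpExceptionInfo above: one file per failing op is written as "./<op_type>.<op_name>.<task_id>.<time_us>", containing an 8-byte proto length, the serialized toolkit::dumpdata::DumpData message, and then the raw input/output tensors written by DumpExceptionInput/DumpExceptionOutput. The reader below is only an illustrative sketch under the assumption that MemoryDumper::DumpToFile and Debug::DumpDevMem append to the same file in call order; the helper name ParseExceptionDump is hypothetical and not part of this change.

#include <cstdint>
#include <fstream>
#include <iterator>
#include <string>
#include <vector>
#include "proto/dump_task.pb.h"

// Illustrative reader: 8-byte length prefix, serialized DumpData, then concatenated raw tensors.
bool ParseExceptionDump(const std::string &path, toolkit::dumpdata::DumpData &dump_data,
                        std::vector<char> &raw_tensors) {
  std::ifstream file(path, std::ios::binary);
  if (!file) {
    return false;
  }
  uint64_t proto_size = 0;
  file.read(reinterpret_cast<char *>(&proto_size), sizeof(proto_size));  // length prefix
  std::vector<char> proto_buf(proto_size);
  file.read(proto_buf.data(), static_cast<std::streamsize>(proto_size));  // serialized DumpData
  if (!file || !dump_data.ParseFromArray(proto_buf.data(), static_cast<int>(proto_size))) {
    return false;
  }
  // The remainder is the raw input/output data; the per-tensor sizes recorded in the proto
  // (input(i).size() / output(j).size()) describe how to split it.
  raw_tensors.assign(std::istreambuf_iterator<char>(file), std::istreambuf_iterator<char>());
  return true;
}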
| @@ -31,6 +31,7 @@ | |||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| #include "task_info/task_info.h" | #include "task_info/task_info.h" | ||||
| #include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
| #include "runtime/base.h" | |||||
| namespace ge { | namespace ge { | ||||
| class DataDumper { | class DataDumper { | ||||
| @@ -88,6 +89,11 @@ class DataDumper { | |||||
| const DumpProperties &GetDumpProperties() const { return dump_properties_; } | const DumpProperties &GetDumpProperties() const { return dump_properties_; } | ||||
| bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const; | bool GetOpDescInfo(uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info) const; | ||||
| // Dump exception info | |||||
| Status DumpExceptionInput(const OpDescInfo &op_desc_info, const string &dump_file); | |||||
| Status DumpExceptionOutput(const OpDescInfo &op_desc_info, const string &dump_file); | |||||
| Status DumpExceptionInfo(const std::vector<rtExceptionInfo> exception_infos); | |||||
| private: | private: | ||||
| void ReleaseDevMem(void **ptr) noexcept; | void ReleaseDevMem(void **ptr) noexcept; | ||||
| @@ -43,6 +43,7 @@ | |||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| #include "graph/load/new_model_manager/cpu_queue_schedule.h" | #include "graph/load/new_model_manager/cpu_queue_schedule.h" | ||||
| #include "graph/load/new_model_manager/tbe_handle_store.h" | #include "graph/load/new_model_manager/tbe_handle_store.h" | ||||
| #include "graph/load/new_model_manager/model_manager.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/manager/trans_var_data_utils.h" | #include "graph/manager/trans_var_data_utils.h" | ||||
| @@ -253,13 +254,7 @@ Status DavinciModel::Assign(const GeModelPtr &ge_model) { | |||||
| /// | /// | ||||
| void DavinciModel::Shrink() { | void DavinciModel::Shrink() { | ||||
| ge_model_.reset(); // delete object. | ge_model_.reset(); // delete object. | ||||
| // Old dump need op list, clear when closed. | |||||
| char *ge_dump_env = std::getenv("DUMP_OP"); | |||||
| int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; | |||||
| if (dump_op_switch == 0) { | |||||
| op_list_.clear(); | |||||
| } | |||||
| op_list_.clear(); | |||||
| } | } | ||||
| Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { | Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_ptr, size_t weight_size) { | ||||
| @@ -295,8 +290,8 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p | |||||
| GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | GELOGE(GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED, "Alloc feature map memory failed. size: %zu", data_size); | ||||
| return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | return GE_EXEC_ALLOC_FEATURE_MAP_MEM_FAILED; | ||||
| } | } | ||||
| GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| mem_base_, data_size); | |||||
| GEEVENT("[IMAS]InitModelMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| mem_base_, data_size); | |||||
| weights_mem_base_ = mem_base_; | weights_mem_base_ = mem_base_; | ||||
| @@ -337,8 +332,8 @@ Status DavinciModel::InitVariableMem() { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | ||||
| GELOGI("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| var_mem_base_, TotalVarMemSize()); | |||||
| GEEVENT("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| var_mem_base_, TotalVarMemSize()); | |||||
| } | } | ||||
| runtime_param_.var_base = var_mem_base_; | runtime_param_.var_base = var_mem_base_; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -774,6 +769,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
| map<uint32_t, OpDescPtr> data_by_index; | map<uint32_t, OpDescPtr> data_by_index; | ||||
| auto nodes = compute_graph->GetAllNodes(); | auto nodes = compute_graph->GetAllNodes(); | ||||
| const TBEKernelStore &tbekernel_store = ge_model_->GetTBEKernelStore(); | const TBEKernelStore &tbekernel_store = ge_model_->GetTBEKernelStore(); | ||||
| const CustAICPUKernelStore &aicpu_kernel_store = ge_model_->GetCustAICPUKernelStore(); | |||||
| for (size_t i = 0; i < nodes.size(); i++) { | for (size_t i = 0; i < nodes.size(); i++) { | ||||
| auto node = nodes.at(i); | auto node = nodes.at(i); | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| @@ -786,6 +782,7 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
| GE_TIMESTAMP_RESTART(LoadTBEKernelBinToOpDesc); | GE_TIMESTAMP_RESTART(LoadTBEKernelBinToOpDesc); | ||||
| tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc); | tbekernel_store.LoadTBEKernelBinToOpDesc(op_desc); | ||||
| aicpu_kernel_store.LoadCustAICPUKernelBinToOpDesc(op_desc); | |||||
| GE_TIMESTAMP_ADD(LoadTBEKernelBinToOpDesc); | GE_TIMESTAMP_ADD(LoadTBEKernelBinToOpDesc); | ||||
| if (IsDataOp(op_desc->GetType())) { | if (IsDataOp(op_desc->GetType())) { | ||||
| @@ -1076,30 +1073,42 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief output zero copy node Initialize. | /// @brief output zero copy node Initialize. | ||||
| /// @param [in] NodePtr: netoutput Op or merge op. | |||||
| /// @param [in] NodePtr: netoutput Op. | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status DavinciModel::InitOutputZeroCopyNodes(const NodePtr &node) { | Status DavinciModel::InitOutputZeroCopyNodes(const NodePtr &node) { | ||||
| set<NodePtr> nodes_need_record; | |||||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
| if (peer_out_data_anchor == nullptr) { | if (peer_out_data_anchor == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto node = peer_out_data_anchor->GetOwnerNode(); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(FAILED, "Op desc is nullptr"); | |||||
| return FAILED; | |||||
| } | |||||
| auto peer_node = peer_out_data_anchor->GetOwnerNode(); | |||||
| nodes_need_record.emplace(peer_node); | |||||
| // Merge node output multiplexed input, upstream nodes need to be considered in multiple batch scenarios | // Merge node output multiplexed input, upstream nodes need to be considered in multiple batch scenarios | ||||
| if (node->GetType() == MERGE) { | |||||
| if (InitOutputZeroCopyNodes(node) != SUCCESS) { | |||||
| GELOGE(PARAM_INVALID, "Output merge zero copy nodes init failed!"); | |||||
| return PARAM_INVALID; | |||||
| if (peer_node->GetType() == MERGE) { | |||||
| for (const auto &merge_peer_in_data_anchor : peer_node->GetAllInDataAnchors()) { | |||||
| auto merge_peer_out_data_anchor = merge_peer_in_data_anchor->GetPeerOutAnchor(); | |||||
| if (merge_peer_out_data_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto merge_peer_node = merge_peer_out_data_anchor->GetOwnerNode(); | |||||
| nodes_need_record.emplace(merge_peer_node); | |||||
| } | |||||
| } else { | |||||
| for (const auto &other_in_data_anchor : peer_out_data_anchor->GetPeerInDataAnchors()) { | |||||
| auto other_in_node = other_in_data_anchor->GetOwnerNode(); | |||||
| if (other_in_node->GetType() != NETOUTPUT) { | |||||
| nodes_need_record.emplace(other_in_node); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| for (const auto &node_need_record : nodes_need_record) { | |||||
| auto op_desc = node_need_record->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| string batch_label; | string batch_label; | ||||
| (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| if (batch_label.empty()) { | if (batch_label.empty()) { | ||||
| @@ -2152,7 +2161,6 @@ void DavinciModel::SetProfileTime(ModelProcStage stage, int64_t endTime) { | |||||
| Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind) { | Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, rtMemcpyKind_t kind) { | ||||
| if (output_op_list_.empty()) { | if (output_op_list_.empty()) { | ||||
| Status ret = SyncVarData(); | Status ret = SyncVarData(); | ||||
| DumpOpInputOutput(); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -2198,8 +2206,6 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data, r | |||||
| runtime_param_.graph_id, output.first, output.second.GetBasicAddr(), data_size, buffer_length); | runtime_param_.graph_id, output.first, output.second.GetBasicAddr(), data_size, buffer_length); | ||||
| GE_CHK_RT_RET(rtMemcpy(buffer_addr, buffer_length, output.second.GetBasicAddr(), data_size, kind)); | GE_CHK_RT_RET(rtMemcpy(buffer_addr, buffer_length, output.second.GetBasicAddr(), data_size, kind)); | ||||
| } | } | ||||
| DumpOpInputOutput(); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -2264,6 +2270,14 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
| // return result is not required | // return result is not required | ||||
| if (!rslt_flg && !seq_end_flag) { | if (!rslt_flg && !seq_end_flag) { | ||||
| GELOGW("Compute failed, model id: %u", model_id_); | GELOGW("Compute failed, model id: %u", model_id_); | ||||
| auto model_manager = ModelManager::GetInstance(); | |||||
| GE_CHECK_NOTNULL(model_manager); | |||||
| auto exception_infos = model_manager->GetExceptionInfos(); | |||||
| if (exception_infos.size() > 0) { | |||||
| GE_CHK_STATUS_RET(data_dumper_.DumpExceptionInfo(exception_infos), "Dump exception info failed"); | |||||
| } else { | |||||
| GELOGI("Exception info is null"); | |||||
| } | |||||
| GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); | GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -2302,7 +2316,6 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
| GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed"); | GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief return not output to upper layer for cloud case | /// @brief return not output to upper layer for cloud case | ||||
| @@ -2318,114 +2331,12 @@ Status DavinciModel::ReturnNoOutput(uint32_t data_id) { | |||||
| op_desc->GetName().c_str()); | op_desc->GetName().c_str()); | ||||
| } | } | ||||
| DumpOpInputOutput(); | |||||
| GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null!"); | GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null!"); | ||||
| std::vector<ge::OutputTensorInfo> outputs; | std::vector<ge::OutputTensorInfo> outputs; | ||||
| GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed."); | GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief dump all op input and output information | |||||
| /// @return void | |||||
| /// | |||||
| void DavinciModel::DumpOpInputOutput() { | |||||
| char *ge_dump_env = std::getenv("DUMP_OP"); | |||||
| int dump_op_switch = (ge_dump_env != nullptr) ? std::strtol(ge_dump_env, nullptr, kDecimal) : 0; | |||||
| if (dump_op_switch == 0) { | |||||
| GELOGI("need to set DUMP_OP for dump op input and output"); | |||||
| return; | |||||
| } | |||||
| if (op_list_.empty()) { | |||||
| GELOGW("op list is empty"); | |||||
| return; | |||||
| } | |||||
| int64_t cnt = 1; | |||||
| for (auto it : op_list_) { | |||||
| if (maxDumpOpNum_ != 0 && cnt > maxDumpOpNum_) { | |||||
| GELOGW("dump op cnt > maxDumpOpNum, maxDumpOpNum: %ld", maxDumpOpNum_); | |||||
| return; | |||||
| } | |||||
| cnt++; | |||||
| if (DumpSingleOpInputOutput(it.second) != SUCCESS) { | |||||
| GELOGW("dump single op failed, model_id: %u", model_id_); | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief dump single op input and output information | |||||
| /// @param [in] op_def: the op_desc which will be dump | |||||
| /// @return Status result | |||||
| /// | |||||
| Status DavinciModel::DumpSingleOpInputOutput(const OpDescPtr &op_def) { | |||||
| GE_CHK_BOOL_EXEC(nullptr != op_def, return PARAM_INVALID, "op_def is null!"); | |||||
| string op_name = ge::StringUtils::ReplaceAll(op_def->GetName(), "/", "-"); | |||||
| GELOGI("dump op name:%s, type:%s, model_id: %u.", op_def->GetName().c_str(), op_def->GetType().c_str(), model_id_); | |||||
| string model_path = "./dump" + to_string(model_id_); | |||||
| if (mmAccess(model_path.c_str()) != EN_OK) { | |||||
| int32_t ret = mmMkdir(model_path.c_str(), S_IRUSR | S_IWUSR | S_IXUSR); | |||||
| if (ret != EN_OK) { | |||||
| GELOGE(FAILED, "make dir failed, model_id: %u", model_id_); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| const vector<int64_t> input_size_vec = ModelUtils::GetInputSize(op_def); | |||||
| const vector<void *> input_addr_vec = ModelUtils::GetInputDataAddrs(runtime_param_, op_def); | |||||
| vector<int64_t> v_memory_type; | |||||
| bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_INPUT_MEM_TYPE_LIST, v_memory_type); | |||||
| GELOGD("DumpSingleOp[%s], input size[%zu], input memory type size[%zu]", op_def->GetName().c_str(), | |||||
| op_def->GetInputsSize(), v_memory_type.size()); | |||||
| for (size_t i = 0; i < input_addr_vec.size(); i++) { | |||||
| if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { | |||||
| continue; | |||||
| } | |||||
| int64_t input_size = input_size_vec.at(i); | |||||
| char input_file_name[PATH_MAX] = {0}; | |||||
| if ((sprintf_s(input_file_name, PATH_MAX, "%s/dump_%u_%s_%s_input_%zu.bin", model_path.c_str(), model_id_, | |||||
| op_def->GetType().c_str(), op_name.c_str(), i)) == -1) { | |||||
| GELOGE(FAILED, "construct input dump file path failed."); | |||||
| return FAILED; | |||||
| } | |||||
| if ((Debug::DumpDevMem(input_file_name, input_addr_vec.at(i), input_size)) != SUCCESS) { | |||||
| GELOGE(FAILED, "dump to input_file failed"); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| const vector<int64_t> output_size_vec = ModelUtils::GetOutputSize(op_def); | |||||
| const vector<void *> output_addr_vec = ModelUtils::GetOutputDataAddrs(runtime_param_, op_def); | |||||
| v_memory_type.clear(); | |||||
| has_mem_type_attr = ge::AttrUtils::GetListInt(op_def, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, v_memory_type); | |||||
| GELOGD("DumpSingleOp[%s], output size[%zu], output memory type size[%zu]", op_def->GetName().c_str(), | |||||
| op_def->GetOutputsSize(), v_memory_type.size()); | |||||
| if (!(op_def->GetType() == "Const")) { | |||||
| for (size_t i = 0; i < output_addr_vec.size(); i++) { | |||||
| if (has_mem_type_attr && v_memory_type[i] == RT_MEMORY_L1) { | |||||
| continue; | |||||
| } | |||||
| int64_t output_size = output_size_vec.at(i); | |||||
| char output_file_name[PATH_MAX] = {0}; | |||||
| if ((sprintf_s(output_file_name, PATH_MAX, "%s/dump_%u_%s_%s_output_%zu.bin", model_path.c_str(), model_id_, | |||||
| op_def->GetType().c_str(), op_name.c_str(), i)) == -1) { | |||||
| GELOGE(FAILED, "construct output dump file path failed."); | |||||
| return FAILED; | |||||
| } | |||||
| if ((Debug::DumpDevMem(output_file_name, output_addr_vec.at(i), output_size)) != SUCCESS) { | |||||
| GELOGE(FAILED, "dump to output_file failed"); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| void *DavinciModel::Run(DavinciModel *model) { | void *DavinciModel::Run(DavinciModel *model) { | ||||
| GE_CHK_BOOL_EXEC(model != nullptr, | GE_CHK_BOOL_EXEC(model != nullptr, | ||||
| CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); | CsaInteract::GetInstance().WriteErrorCode(FAILED, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); | ||||
| @@ -3127,8 +3038,8 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> & | |||||
| void *addr = data.second.GetDataInfo().at(count).second; | void *addr = data.second.GetDataInfo().at(count).second; | ||||
| void *buffer_addr = | void *buffer_addr = | ||||
| reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data) + data.second.GetRelativeOffset().at(count)); | reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(buffer.data) + data.second.GetRelativeOffset().at(count)); | ||||
| GELOGI("[ZCPY] Copy blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p", data.first, addr, size, | |||||
| buffer_addr); | |||||
| GELOGI("[ZCPY] Copy %s blobs_index %u, virtual_addr: %p, size: %ld, user_data_addr: %p", input_or_output.c_str(), | |||||
| data.first, addr, size, buffer_addr); | |||||
| // For input data, just copy for rts task. | // For input data, just copy for rts task. | ||||
| for (ZeroCopyTask &task : zero_copy_tasks_) { | for (ZeroCopyTask &task : zero_copy_tasks_) { | ||||
| uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); | uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); | ||||
| @@ -3486,7 +3397,6 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa | |||||
| is_async_mode_ = async_mode; | is_async_mode_ = async_mode; | ||||
| GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); | GELOGI("Model Run begin, model id:%u, data index:%u, flag:%d.", model_id_, input_data.index, is_async_mode_); | ||||
| GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); | GE_CHK_STATUS_RET(InitModelStream(stream), "Init model stream failed."); | ||||
| is_dynamic_ = input_data.is_dynamic_batch; | is_dynamic_ = input_data.is_dynamic_batch; | ||||
| if (!is_dynamic_) { | if (!is_dynamic_) { | ||||
| zero_copy_batch_label_addrs_.clear(); | zero_copy_batch_label_addrs_.clear(); | ||||
| @@ -345,21 +345,6 @@ class DavinciModel { | |||||
| Status ReturnNoOutput(uint32_t data_id); | Status ReturnNoOutput(uint32_t data_id); | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief dump all op input and output information | |||||
| /// @return void | |||||
| /// | |||||
| void DumpOpInputOutput(); | |||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief dump single op input and output information | |||||
| /// @param [in] dump_op model_id | |||||
| /// @return Status | |||||
| /// | |||||
| Status DumpSingleOpInputOutput(const OpDescPtr &dump_op); | |||||
| Status ModelRunStart(); | Status ModelRunStart(); | ||||
| /// | /// | ||||
| @@ -18,9 +18,9 @@ | |||||
| #include <string> | #include <string> | ||||
| #include "common/dump/dump_manager.h" | |||||
| #include "common/l2_cache_optimize.h" | #include "common/l2_cache_optimize.h" | ||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| #include "common/dump/dump_manager.h" | |||||
| #include "common/properties_manager.h" | #include "common/properties_manager.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| @@ -38,6 +38,7 @@ const int kDumpCmdPairSize = 2; | |||||
| } // namespace | } // namespace | ||||
| DumpProperties ModelManager::dump_properties_; | DumpProperties ModelManager::dump_properties_; | ||||
| std::mutex ModelManager::exeception_infos_mutex_; | |||||
| std::shared_ptr<ModelManager> ModelManager::GetInstance() { | std::shared_ptr<ModelManager> ModelManager::GetInstance() { | ||||
| static const std::shared_ptr<ModelManager> instance_ptr = | static const std::shared_ptr<ModelManager> instance_ptr = | ||||
| @@ -154,6 +155,7 @@ void ModelManager::DestroyAicpuSession(uint64_t session_id) { | |||||
| GELOGI("The session: %lu not created.", session_id); | GELOGI("The session: %lu not created.", session_id); | ||||
| return; | return; | ||||
| } else { | } else { | ||||
| GE_CHK_RT(rtSetDevice(static_cast<int32_t>(GetContext().DeviceId()))); | |||||
| Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_SESSION_DESTROY, session_id, 0); | Status ret = KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType::FWK_ADPT_SESSION_DESTROY, session_id, 0); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("The session: %lu destroy failed.", session_id); | GELOGW("The session: %lu destroy failed.", session_id); | ||||
| @@ -161,6 +163,7 @@ void ModelManager::DestroyAicpuSession(uint64_t session_id) { | |||||
| (void)sess_ids_.erase(session_id); | (void)sess_ids_.erase(session_id); | ||||
| GELOGI("The session: %lu destroyed.", session_id); | GELOGI("The session: %lu destroyed.", session_id); | ||||
| } | } | ||||
| GE_CHK_RT(rtDeviceReset(static_cast<int32_t>(GetContext().DeviceId()))); | |||||
| } | } | ||||
| } | } | ||||
| @@ -369,7 +372,8 @@ Status ModelManager::Unload(uint32_t model_id) { | |||||
| } else { | } else { | ||||
| GELOGI("Unload model %u success.no need reset device,device_count: %u", model_id, device_count); | GELOGI("Unload model %u success.no need reset device,device_count: %u", model_id, device_count); | ||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(exeception_infos_mutex_); | |||||
| exception_infos_.clear(); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1106,4 +1110,23 @@ Status ModelManager::GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint3 | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| Status ModelManager::EnableExceptionDump(const std::map<string, string> &options) { | |||||
| auto iter = options.find(OPTION_EXEC_ENABLE_EXCEPTION_DUMP); | |||||
| if (iter != options.end()) { | |||||
| GELOGI("Find option enable_exeception_dump is %s", iter->second.c_str()); | |||||
| if (iter->second == "1") { | |||||
| rtError_t rt_ret = rtSetTaskFailCallback(ExceptionCallback); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "rtSetTaskFailCallback failed"); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | |||||
| } else { | |||||
| GELOGI("Option enable exception dump is %s", iter->second.c_str()); | |||||
| } | |||||
| } else { | |||||
| GELOGI("Not find option enable exception dump"); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
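A brief illustration of how the new switch is consumed (editor's sketch, not part of the change): EnableExceptionDump only registers ExceptionCallback through rtSetTaskFailCallback when the option keyed by OPTION_EXEC_ENABLE_EXCEPTION_DUMP, i.e. "ge.exec.enable_exception_dump", is set to "1". The wrapper function below is hypothetical; it assumes model_manager.h and ge::SUCCESS are available to the caller.

#include <map>
#include <string>
#include "graph/load/new_model_manager/model_manager.h"

void EnableExceptionDumpExample() {
  std::map<std::string, std::string> options;
  options["ge.exec.enable_exception_dump"] = "1";  // any other value leaves the callback unregistered
  auto model_manager = ge::ModelManager::GetInstance();
  if (model_manager == nullptr || model_manager->EnableExceptionDump(options) != ge::SUCCESS) {
    // registration failed; exception dump stays disabled
  }
}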
| @@ -274,6 +274,22 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| bool IsDynamicShape(uint32_t model_id); | bool IsDynamicShape(uint32_t model_id); | ||||
| ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ge::Status GetOpDescInfo(uint32_t device_id, uint32_t stream_id, uint32_t task_id, OpDescInfo &op_desc_info); | ||||
| ge::Status EnableExceptionDump(const std::map<string, string> &options); | |||||
| const std::vector<rtExceptionInfo> &GetExceptionInfos() { return exception_infos_; } | |||||
| void AddExceptionInfo(const rtExceptionInfo &exception_info) { exception_infos_.emplace_back(exception_info); } | |||||
| static void ExceptionCallback(rtExceptionInfo *exception_info) { | |||||
| std::lock_guard<std::mutex> lock(exeception_infos_mutex_); | |||||
| auto instance = ModelManager::GetInstance(); | |||||
| if (instance == nullptr) { | |||||
| GELOGE(FAILED, "Instance is nullptr"); | |||||
| return; | |||||
| } | |||||
| instance->AddExceptionInfo(*exception_info); | |||||
| } | |||||
| private: | private: | ||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| @@ -309,8 +325,10 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| std::mutex map_mutex_; | std::mutex map_mutex_; | ||||
| std::mutex sess_ids_mutex_; | std::mutex sess_ids_mutex_; | ||||
| std::mutex session_id_create_mutex_; | std::mutex session_id_create_mutex_; | ||||
| static ::std::mutex exeception_infos_mutex_; | |||||
| uint64_t session_id_bias_; | uint64_t session_id_bias_; | ||||
| std::set<uint64_t> sess_ids_; | std::set<uint64_t> sess_ids_; | ||||
| std::vector<rtExceptionInfo> exception_infos_; | |||||
| static DumpProperties dump_properties_; | static DumpProperties dump_properties_; | ||||
| }; | }; | ||||
| @@ -258,7 +258,7 @@ Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc, | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| hcclRedOp_t op_type = HCCL_REP_OP_SUM; | |||||
| HcclReduceOp op_type = HCCL_REDUCE_SUM; | |||||
| GE_CHECK_NOTNULL(davinci_model_); | GE_CHECK_NOTNULL(davinci_model_); | ||||
| GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); | GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); | ||||
| if (!davinci_model_->IsKnownNode()) { | if (!davinci_model_->IsKnownNode()) { | ||||
| @@ -37,11 +37,17 @@ const uint8_t kL2NotLoadToDdr = 0; | |||||
| // for skt | // for skt | ||||
| constexpr int64_t kInvalidGroupKey = -1; | constexpr int64_t kInvalidGroupKey = -1; | ||||
| constexpr uint32_t kSKTSingleSize = 1; | constexpr uint32_t kSKTSingleSize = 1; | ||||
| constexpr uint32_t kSKTMaxSizeLimit = 20000; | |||||
| const char *kIsLastNode = "is_last_node"; | const char *kIsLastNode = "is_last_node"; | ||||
| const char *kIsFirstNode = "is_first_node"; | const char *kIsFirstNode = "is_first_node"; | ||||
| const int64_t kCloseSkt = 100; | const int64_t kCloseSkt = 100; | ||||
| const uint32_t kAddrLen = sizeof(void *); | const uint32_t kAddrLen = sizeof(void *); | ||||
| const char *const kLoadOpFromBuf = "loadOpFromBuf"; | |||||
| struct CustAicpuSoBuf { | |||||
| uint64_t kernelSoBuf; | |||||
| uint32_t kernelSoBufLen; | |||||
| uint64_t kernelSoName; | |||||
| uint32_t kernelSoNameLen; | |||||
| } __attribute__((packed)); | |||||
| } // namespace | } // namespace | ||||
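As a side note on CustAicpuSoBuf above: it is the argument block handed to the loadOpFromBuf CPU kernel as two {device address, length} pairs, one for the custom AICPU .so image and one for its name. A compile-time check such as the following (an editor's illustration, assuming the device side relies on the packed 24-byte layout) would make that contract explicit:

static_assert(sizeof(CustAicpuSoBuf) == 2 * (sizeof(uint64_t) + sizeof(uint32_t)),
              "CustAicpuSoBuf must remain packed (24 bytes) to match the loadOpFromBuf kernel args");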
| namespace ge { | namespace ge { | ||||
| @@ -49,10 +55,7 @@ KernelTaskInfo::SuperKernelTaskInfo KernelTaskInfo::skt_info_ = { | |||||
| 0, 0, 0, 0, nullptr, nullptr, {}, {}, {}, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; | 0, 0, 0, 0, nullptr, nullptr, {}, {}, {}, {}, {}, RT_KERNEL_DEFAULT, kInvalidGroupKey, 0, nullptr}; | ||||
| Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | ||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "davinci model is null!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GE_CHECK_NOTNULL(davinci_model); | |||||
| davinci_model_ = davinci_model; | davinci_model_ = davinci_model; | ||||
| is_l1_fusion_enable_ = davinci_model_->GetL1FusionEnableOption(); | is_l1_fusion_enable_ = davinci_model_->GetL1FusionEnableOption(); | ||||
| GELOGD("KernelTaskInfo init start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); | GELOGD("KernelTaskInfo init start, ge.enableL1Fusion in davinci model is %d.", is_l1_fusion_enable_); | ||||
| @@ -71,16 +74,12 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| kernel_type_ = static_cast<cce::ccKernelType>(context.kernel_type()); | kernel_type_ = static_cast<cce::ccKernelType>(context.kernel_type()); | ||||
| // get opdesc | // get opdesc | ||||
| op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); | op_desc_ = davinci_model_->GetOpByIndex(context.op_index()); | ||||
| if (op_desc_ == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Get op desc failed, index is out of range!"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| GE_CHECK_NOTNULL(op_desc_); | |||||
| (void)AttrUtils::GetBool(*op_desc_, ATTR_N_BATCH_SPILT, is_n_batch_spilt_); | (void)AttrUtils::GetBool(*op_desc_, ATTR_N_BATCH_SPILT, is_n_batch_spilt_); | ||||
| GELOGD("node[%s] is_n_batch_spilt %d", op_desc_->GetName().c_str(), is_n_batch_spilt_); | GELOGD("node[%s] is_n_batch_spilt %d", op_desc_->GetName().c_str(), is_n_batch_spilt_); | ||||
| (void)AttrUtils::GetInt(*op_desc_, ATTR_NAME_FUSION_GROUP_KEY, group_key_); | (void)AttrUtils::GetInt(*op_desc_, ATTR_NAME_FUSION_GROUP_KEY, group_key_); | ||||
| has_group_key_ = (group_key_ != kInvalidGroupKey); | has_group_key_ = (group_key_ != kInvalidGroupKey); | ||||
| GELOGD("node[%s] has_group_key_ %ld, group key is [%ld]", op_desc_->GetName().c_str(), has_group_key_, group_key_); | GELOGD("node[%s] has_group_key_ %ld, group key is [%ld]", op_desc_->GetName().c_str(), has_group_key_, group_key_); | ||||
| // fusion_op_info | // fusion_op_info | ||||
| vector<std::string> original_op_names; | vector<std::string> original_op_names; | ||||
| bool result = AttrUtils::GetListStr(op_desc_, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names); | bool result = AttrUtils::GetListStr(op_desc_, ge::ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names); | ||||
| @@ -99,7 +98,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(RT_FAILED, "execute rtGetFunctionByName failed. stub_func: %s", | ||||
| kernel_def.stub_func().c_str()); | kernel_def.stub_func().c_str()); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);); | return RT_ERROR_TO_GE_STATUS(rt_ret);); | ||||
| } else if (kernel_type_ != cce::ccKernelType::AI_CPU) { | |||||
| } else if (kernel_type_ == cce::ccKernelType::TE) { | |||||
| rtError_t rt_ret; | rtError_t rt_ret; | ||||
| rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); | rt_ret = rtGetFunctionByName(bin_file_key, &stub_func_); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | ||||
| @@ -127,7 +126,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| ret = InitTVMTask(args_offset_tmp[0], kernel_def); | ret = InitTVMTask(args_offset_tmp[0], kernel_def); | ||||
| } else if (kernel_type_ == cce::ccKernelType::CUSTOMIZED) { | } else if (kernel_type_ == cce::ccKernelType::CUSTOMIZED) { | ||||
| ret = InitAICPUCustomTask(context.op_index(), kernel_def); | ret = InitAICPUCustomTask(context.op_index(), kernel_def); | ||||
| } else if (kernel_type_ == cce::ccKernelType::AI_CPU) { | |||||
| } else if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| ret = InitAicpuTask(context.op_index(), kernel_def); | ret = InitAicpuTask(context.op_index(), kernel_def); | ||||
| } else { | } else { | ||||
| if (kernel_def.args().empty() || args_size_ == 0) { | if (kernel_def.args().empty() || args_size_ == 0) { | ||||
| @@ -332,10 +331,6 @@ bool KernelTaskInfo::DoubleCallSKTSaveCheck() { return (!is_n_batch_spilt_ && !h | |||||
| Status KernelTaskInfo::SuperKernelDistribute() { | Status KernelTaskInfo::SuperKernelDistribute() { | ||||
| Status ret; | Status ret; | ||||
| char *skt_task_num = getenv("SKT_TASK_NUM"); | |||||
| auto task_num = static_cast<uint64_t>((skt_task_num != nullptr) ? strtol(skt_task_num, nullptr, 10) | |||||
| : kSKTMaxSizeLimit); // 10 for decimal number | |||||
| GELOGI("SKT: SuperKernel Distribute Task num[skt_id:%lu]", task_num); | |||||
| if (FirstCallSKTLaunchCheck()) { | if (FirstCallSKTLaunchCheck()) { | ||||
| ret = SuperKernelLaunch(); | ret = SuperKernelLaunch(); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -381,7 +376,8 @@ Status KernelTaskInfo::Distribute() { | |||||
| char *skt_enable_env = getenv("SKT_ENABLE"); | char *skt_enable_env = getenv("SKT_ENABLE"); | ||||
| int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | ||||
| bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | ||||
| if (kernel_type_ == cce::ccKernelType::AI_CPU) { | |||||
| if (kernel_type_ == cce::ccKernelType::AI_CPU || kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| GELOGI("distribute task info kernel_type %d, flag %d", kernel_type_, dump_flag_); | |||||
| // blockDim is reserved parameter, set to 1 | // blockDim is reserved parameter, set to 1 | ||||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | ||||
| reinterpret_cast<const void *>(kernel_name_.c_str()), 1, args_, args_size_, | reinterpret_cast<const void *>(kernel_name_.c_str()), 1, args_, args_size_, | ||||
| @@ -865,10 +861,98 @@ Status KernelTaskInfo::InitCceTask(const domi::KernelDef &kernel_def) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status KernelTaskInfo::LaunchCustAicpuSo(const OpDescPtr op_desc, const domi::KernelDef &kernel_def) { | |||||
| CustAICPUKernelPtr aicpu_kernel = op_desc->TryGetExtAttr(OP_EXTATTR_CUSTAICPU_KERNEL, CustAICPUKernelPtr()); | |||||
| if (aicpu_kernel == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "cust aicpu op %s can't find kernel!", op_desc->GetName().c_str()); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| const void *aicpu_data = aicpu_kernel->GetBinData(); | |||||
| uint32_t aicpu_data_length = aicpu_kernel->GetBinDataSize(); | |||||
| void *d_aicpu_data = nullptr; | |||||
| rtError_t status = rtMalloc(&d_aicpu_data, aicpu_data_length, RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| status = rtMemcpy(d_aicpu_data, aicpu_data_length, aicpu_data, aicpu_data_length, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| void *d_so_name = nullptr; | |||||
| status = rtMalloc(&d_so_name, so_name_.size(), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| status = rtMemcpy(d_so_name, so_name_.size(), reinterpret_cast<const void *>(so_name_.c_str()), so_name_.size(), | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| CustAicpuSoBuf cust_aicpu_so_buf; | |||||
| cust_aicpu_so_buf.kernelSoBuf = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_aicpu_data)); | |||||
| cust_aicpu_so_buf.kernelSoBufLen = aicpu_data_length; | |||||
| cust_aicpu_so_buf.kernelSoName = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(d_so_name)); | |||||
| cust_aicpu_so_buf.kernelSoNameLen = so_name_.size(); | |||||
| void *args = nullptr; | |||||
| uint32_t args_size = sizeof(CustAicpuSoBuf); | |||||
| status = rtMalloc(&args, args_size, RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GELOGI("loadOpFromBuf kernelSoBuf %p, kernelSoBufLen %u, kernelSoName %p, kernelSoNameLen %u.", d_aicpu_data, | |||||
| aicpu_data_length, d_so_name, so_name_.size()); | |||||
| status = rtMemcpy(args, args_size, static_cast<void *>(&cust_aicpu_so_buf), args_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| rtStream_t stream = nullptr; | |||||
| status = rtStreamCreate(&stream, 0); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt create stream failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| status = rtCpuKernelLaunch(nullptr, kLoadOpFromBuf, 1, args, args_size, nullptr, stream); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt CpuKernelLaunch loadOpFromBuf failed, status: 0x%X", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GELOGI("Cpu kernel launch loadOpFromBuf."); | |||||
| status = rtStreamSynchronize(stream); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt stream sync failed, status: 0x%x", status); | |||||
| return RT_ERROR_TO_GE_STATUS(status); | |||||
| } | |||||
| GE_CHK_RT(rtFree(args)); | |||||
| GE_CHK_RT(rtFree(d_aicpu_data)); | |||||
| GE_CHK_RT(rtFree(d_so_name)); | |||||
| GELOGI("Cpu kernel launch loadOpFromBuf task success."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &kernel_def) { | Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &kernel_def) { | ||||
| GELOGI("Do InitAicpuTask"); | GELOGI("Do InitAicpuTask"); | ||||
| so_name_ = kernel_def.so_name(); | so_name_ = kernel_def.so_name(); | ||||
| kernel_name_ = kernel_def.kernel_name(); | kernel_name_ = kernel_def.kernel_name(); | ||||
| GELOGI("node[%s] test so name %s, kernel name %s", op_desc_->GetName().c_str(), so_name_.c_str(), | |||||
| kernel_name_.c_str()); | |||||
| OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); | OpDescPtr op_desc = davinci_model_->GetOpByIndex(op_index); | ||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| @@ -876,6 +960,10 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| GE_CHK_STATUS_RET(LaunchCustAicpuSo(op_desc, kernel_def), "launch cust aicpu so failed"); | |||||
| } | |||||
| // copy args to new host memory | // copy args to new host memory | ||||
| std::unique_ptr<uint8_t[]> args_addr(new (std::nothrow) uint8_t[args_size_]); | std::unique_ptr<uint8_t[]> args_addr(new (std::nothrow) uint8_t[args_size_]); | ||||
| GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | GE_PRINT_DYNAMIC_MEMORY(new, "cce task physical memory.", sizeof(uint8_t) * args_size_) | ||||
| @@ -940,6 +1028,9 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| } | } | ||||
| dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | ||||
| } | } | ||||
| if (kernel_type_ == cce::ccKernelType::CUST_AI_CPU) { | |||||
| dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; | |||||
| } | |||||
| davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, args_addr.get(), args_, args_size_, sizeof(aicpu::AicpuParamHead)); | davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, args_addr.get(), args_, args_size_, sizeof(aicpu::AicpuParamHead)); | ||||
| @@ -1195,16 +1286,6 @@ uint8_t KernelTaskInfo::IsL2CpToDDR(uint8_t origain_L2_load_to_ddr) { | |||||
| if (dump_flag_ == RT_KERNEL_DUMPFLAG) { | if (dump_flag_ == RT_KERNEL_DUMPFLAG) { | ||||
| return kL2LoadToDdr; | return kL2LoadToDdr; | ||||
| } | } | ||||
| static char *ge_dump_env = std::getenv("DUMP_OP"); | |||||
| if (ge_dump_env != nullptr) { | |||||
| static std::string ge_dump_str(ge_dump_env); | |||||
| static std::string open_ge_dump("1"); | |||||
| if (ge_dump_str == open_ge_dump) { | |||||
| return kL2LoadToDdr; | |||||
| } | |||||
| } | |||||
| return kL2NotLoadToDdr; | return kL2NotLoadToDdr; | ||||
| } | } | ||||
| @@ -106,6 +106,8 @@ class KernelTaskInfo : public TaskInfo { | |||||
| Status InitAicpuTaskExtInfo(const std::string &ext_info); | Status InitAicpuTaskExtInfo(const std::string &ext_info); | ||||
| Status LaunchCustAicpuSo(const OpDescPtr op_desc, const domi::KernelDef &kernel_def); | |||||
| Status StoreInputOutputTensor(const std::vector<void *> &input_data_addrs, | Status StoreInputOutputTensor(const std::vector<void *> &input_data_addrs, | ||||
| const std::vector<void *> &output_data_addrs, | const std::vector<void *> &output_data_addrs, | ||||
| const std::vector<::tagCcAICPUTensor> &input_descs, | const std::vector<::tagCcAICPUTensor> &input_descs, | ||||
| @@ -130,8 +130,8 @@ Status ZeroCopyTask::UpdateTaskParam(uintptr_t addr, void *buffer_addr, const ma | |||||
| } | } | ||||
| auto dst_addr = static_cast<uint8_t *>(buffer_addr); | auto dst_addr = static_cast<uint8_t *>(buffer_addr); | ||||
| GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx", name_.c_str(), | |||||
| args_addr_, args_size_, offset, addr); | |||||
| GELOGI("[ZCPY] %s update task, args_addr: %p, size: %zu, offset: %zu, virtual_addr: 0x%lx, user_data_addr: %p", | |||||
| name_.c_str(), args_addr_, args_size_, offset, addr, buffer_addr); | |||||
| *(uintptr_t *)(args_info + offset) = reinterpret_cast<uintptr_t>(dst_addr); | *(uintptr_t *)(args_info + offset) = reinterpret_cast<uintptr_t>(dst_addr); | ||||
| is_updated_ = true; | is_updated_ = true; | ||||
| } | } | ||||
| @@ -29,6 +29,7 @@ | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "graph/manager/block_memory.h" | |||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -38,30 +39,8 @@ constexpr size_t kKByteSize = 1024; | |||||
| constexpr size_t kMByteSize = 1024 * 1024; | constexpr size_t kMByteSize = 1024 * 1024; | ||||
| constexpr size_t kGByteSize = 1024 * 1024 * 1024; | constexpr size_t kGByteSize = 1024 * 1024 * 1024; | ||||
| struct Block; | |||||
| typedef bool (*Comparison)(const Block *, const Block *); | |||||
| using BlockBin = std::set<Block *, Comparison>; | |||||
| static const uint32_t kNumBins = 8; | static const uint32_t kNumBins = 8; | ||||
| struct Block { | |||||
| uint32_t device_id; // npu device id | |||||
| size_t size; // block size in bytes | |||||
| BlockBin *bin; // owning block bin | |||||
| uint8_t *ptr; // memory address | |||||
| bool allocated; // in-use flag | |||||
| Block *prev; // prev block if split from a larger allocation | |||||
| Block *next; // next block if split from a larger allocation | |||||
| Block(uint32_t device, size_t size, BlockBin *bin, uint8_t *ptr) | |||||
| : device_id(device), size(size), bin(bin), ptr(ptr), allocated(0), prev(nullptr), next(nullptr) {} | |||||
| // constructor for search key | |||||
| Block(uint32_t device, size_t size, uint8_t *ptr) | |||||
| : device_id(device), size(size), bin(nullptr), ptr(ptr), allocated(0), prev(nullptr), next(nullptr) {} | |||||
| bool IsSplit() const { return (prev != nullptr) || (next != nullptr); } | |||||
| }; | |||||
| class MemoryAllocator; | class MemoryAllocator; | ||||
| class CachingAllocator { | class CachingAllocator { | ||||
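The Block, Comparison, and BlockBin definitions removed above are evidently relocated rather than dropped: the same hunk adds #include "graph/manager/block_memory.h", so they presumably now live in that shared header where the RDMA pool allocator further down can reuse them. A sketch of what that header is expected to contain, reconstructed from the removed lines (the include-guard name is an assumption):

```cpp
// block_memory.h (sketch reconstructed from the definitions removed above)
#ifndef GE_GRAPH_MANAGER_BLOCK_MEMORY_H_  // hypothetical guard name
#define GE_GRAPH_MANAGER_BLOCK_MEMORY_H_

#include <cstddef>
#include <cstdint>
#include <set>

namespace ge {
struct Block;
typedef bool (*Comparison)(const Block *, const Block *);
using BlockBin = std::set<Block *, Comparison>;

struct Block {
  uint32_t device_id;  // npu device id
  size_t size;         // block size in bytes
  BlockBin *bin;       // owning block bin
  uint8_t *ptr;        // memory address
  bool allocated;      // in-use flag
  Block *prev;         // prev block if split from a larger allocation
  Block *next;         // next block if split from a larger allocation

  Block(uint32_t device, size_t block_size, BlockBin *block_bin, uint8_t *addr)
      : device_id(device), size(block_size), bin(block_bin), ptr(addr),
        allocated(false), prev(nullptr), next(nullptr) {}

  // constructor for a search key (no owning bin)
  Block(uint32_t device, size_t block_size, uint8_t *addr)
      : device_id(device), size(block_size), bin(nullptr), ptr(addr),
        allocated(false), prev(nullptr), next(nullptr) {}

  bool IsSplit() const { return (prev != nullptr) || (next != nullptr); }
};
}  // namespace ge
#endif  // GE_GRAPH_MANAGER_BLOCK_MEMORY_H_
```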
| @@ -33,7 +33,9 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
| #include "analyzer/analyzer.h" | |||||
| #include "graph/common/ge_call_wrapper.h" | #include "graph/common/ge_call_wrapper.h" | ||||
| #include "graph/common/local_context.h" | |||||
| #include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| @@ -42,6 +44,7 @@ | |||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include "graph/manager/util/rt_context_util.h" | #include "graph/manager/util/rt_context_util.h" | ||||
| #include "graph/partition/dynamic_shape_partition.h" | #include "graph/partition/dynamic_shape_partition.h" | ||||
| #include "graph/passes/enter_pass.h" | |||||
| #include "graph/passes/addn_pass.h" | #include "graph/passes/addn_pass.h" | ||||
| #include "graph/passes/bitcast_pass.h" | #include "graph/passes/bitcast_pass.h" | ||||
| #include "graph/passes/atomic_addr_clean_pass.h" | #include "graph/passes/atomic_addr_clean_pass.h" | ||||
| @@ -110,6 +113,9 @@ const char *const kSend = "Send"; | |||||
| const char *const kRecv = "Recv"; | const char *const kRecv = "Recv"; | ||||
| const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar"; | const char *const kCheckPointForGetVar = "CheckPointGraphForGetVar"; | ||||
| const char *const kCheckPointGraph = "checkpoint_graph"; | const char *const kCheckPointGraph = "checkpoint_graph"; | ||||
| const char *const kVectorEngine = "VectorEngine"; | |||||
| const char *const kAIcoreEngine = "AIcoreEngine"; | |||||
| const char *const kOffOptimize = "off_optimize"; | |||||
| bool IsTailingOptimization() { | bool IsTailingOptimization() { | ||||
| string is_tailing_optimization_option; | string is_tailing_optimization_option; | ||||
| @@ -125,7 +131,10 @@ bool IsTailingOptimization() { | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| GraphManager::GraphManager() : thread_run_flag_(false), graph_run_listener_(nullptr), init_flag_(false) {} | |||||
| GraphManager::GraphManager(OmgContext &omg_context) | |||||
| : thread_run_flag_(false), graph_run_listener_(nullptr), init_flag_(false), omg_context_(omg_context) { | |||||
| SetLocalOmgContext(omg_context); | |||||
| } | |||||
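With this constructor the manager carries a per-instance OmgContext instead of reading the global domi::GetContext(), and it binds that context to the constructing thread via SetLocalOmgContext (the worker threads further down re-bind it the same way). A rough wiring sketch from the owner's side; the session class name is illustrative, and it assumes OmgContext is default-constructible and visible through the GraphManager header:

```cpp
#include <map>
#include <string>

#include "graph/manager/graph_manager.h"

// Illustrative owner of a GraphManager. Note the member order: GraphManager keeps
// the context by reference (OmgContext &omg_context_), so the context must be
// constructed first and must outlive the manager.
class SessionSketch {
 public:
  SessionSketch() : graph_manager_(omg_context_) {}

  ge::Status Init(const std::map<std::string, std::string> &options) {
    return graph_manager_.Initialize(options);
  }

 private:
  ge::OmgContext omg_context_;      // per-session parser/graph context (assumed default-constructible)
  ge::GraphManager graph_manager_;  // binds omg_context_ for its worker threads
};
```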
| Status GraphManager::Initialize(const std::map<string, string> &options) { | Status GraphManager::Initialize(const std::map<string, string> &options) { | ||||
| if (init_flag_) { | if (init_flag_) { | ||||
| @@ -321,14 +330,56 @@ Status GraphManager::MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::Com | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph) { | |||||
| Status GraphManager::CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_graph, | |||||
| Graph2SubGraphInfoList &sub_graph_map, | |||||
| std::unordered_map<std::string, ComputeGraphPtr> ©_graphs) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| vector<ComputeGraphPtr> old_compute_graphs; | |||||
| const auto &root_subgraph_list = sub_graph_map[compute_graph]; | |||||
| for (const auto &subgraph : root_subgraph_list) { | |||||
| old_compute_graphs.emplace_back(subgraph->GetSubGraph()); | |||||
| } | |||||
| for (const auto &function_graph : compute_graph->GetAllSubgraphs()) { | |||||
| const auto &subgraph_list = sub_graph_map[function_graph]; | |||||
| for (const auto &subgraph : subgraph_list) { | |||||
| old_compute_graphs.emplace_back(subgraph->GetSubGraph()); | |||||
| } | |||||
| } | |||||
| for (const auto &old_compute_graph : old_compute_graphs) { | |||||
| std::vector<NodePtr> input_nodes; | |||||
| std::vector<NodePtr> output_nodes; | |||||
| ComputeGraphPtr new_compute_graph = GraphUtils::CloneGraph(old_compute_graph, "", input_nodes, output_nodes); | |||||
| if (new_compute_graph == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Clone graph failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| copy_graphs.emplace(old_compute_graph->GetName(), new_compute_graph); | |||||
| if (!AttrUtils::SetBool(old_compute_graph, ATTR_NAME_NEED_LX_FUSION, true)) { | |||||
| GELOGE(INTERNAL_ERROR, "Set attr lx_fusion to graph failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| GELOGI("Copy %zu graphs successfully.", copy_graphs.size()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, | |||||
| Graph2SubGraphInfoList &sub_graph_map, uint64_t session_id) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| // use default 16 multi thread | // use default 16 multi thread | ||||
| const uint32_t thread_num = 16; | const uint32_t thread_num = 16; | ||||
| ThreadPool executor(thread_num); | ThreadPool executor(thread_num); | ||||
| auto sub_graph_map = graph_partitioner_.GetSubGraphMap(); | |||||
| std::vector<std::future<Status>> vector_future; | std::vector<std::future<Status>> vector_future; | ||||
| const auto &root_subgraph_list = sub_graph_map[compute_graph]; | const auto &root_subgraph_list = sub_graph_map[compute_graph]; | ||||
| std::string op_compile_strategy; | |||||
| (void)AttrUtils::GetStr(compute_graph, ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | |||||
| GELOGI("OptimizeSubGraphWithMultiThreads Process op_compile_strategy:%s", op_compile_strategy.c_str()); | |||||
| for (const auto &subgraph : root_subgraph_list) { | for (const auto &subgraph : root_subgraph_list) { | ||||
| if (!op_compile_strategy.empty()) { | |||||
| (void)AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | |||||
| } | |||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, subgraph, session_id, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, subgraph, session_id, | ||||
| GetThreadLocalContext()); | GetThreadLocalContext()); | ||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| @@ -341,6 +392,9 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr | |||||
| for (auto &function_graph : compute_graph->GetAllSubgraphs()) { | for (auto &function_graph : compute_graph->GetAllSubgraphs()) { | ||||
| auto subgraph_list = sub_graph_map[function_graph]; | auto subgraph_list = sub_graph_map[function_graph]; | ||||
| for (const auto &subgraph : subgraph_list) { | for (const auto &subgraph : subgraph_list) { | ||||
| if (!op_compile_strategy.empty()) { | |||||
| (void)AttrUtils::SetStr(subgraph->GetSubGraph(), ATTR_NAME_OP_COMPILE_STRATEGY, op_compile_strategy); | |||||
| } | |||||
| std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, subgraph, session_id, | std::future<Status> f = executor.commit(GraphManager::ProcessSubGraphWithMultiThreads, this, subgraph, session_id, | ||||
| GetThreadLocalContext()); | GetThreadLocalContext()); | ||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| @@ -361,6 +415,130 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
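The two loops above fan the per-subgraph optimize jobs out to a 16-thread pool and collect one std::future&lt;Status&gt; per job; the join-and-fold over vector_future presumably happens in the part of the function not shown in these hunks. A standalone restatement of that fan-out/fan-in pattern, using std::async in place of ge::ThreadPool so the sketch compiles on its own (names are illustrative):

```cpp
#include <functional>
#include <future>
#include <vector>

#include "framework/common/ge_inner_error_codes.h"

// Fan-out/fan-in sketch: one optimize job per subgraph, keep every future,
// then join them all and fold the statuses into a single result.
template <typename SubGraph>
ge::Status OptimizeAll(const std::vector<SubGraph> &subgraphs,
                       const std::function<ge::Status(const SubGraph &)> &optimize_one) {
  std::vector<std::future<ge::Status>> futures;
  futures.reserve(subgraphs.size());
  for (const auto &sub_graph : subgraphs) {
    // ge::ThreadPool::commit plays this role in the diff; std::async keeps the sketch standalone.
    futures.emplace_back(std::async(std::launch::async, optimize_one, std::cref(sub_graph)));
  }
  ge::Status ret = ge::SUCCESS;
  for (auto &f : futures) {
    if (!f.valid() || f.get() != ge::SUCCESS) {
      ret = ge::FAILED;  // keep joining the remaining futures so no job is left detached
    }
  }
  return ret;
}
```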
| bool GraphManager::CheckAllFusionOptimizeSuccess(const ComputeGraphPtr &compute_graph, | |||||
| Graph2SubGraphInfoList &sub_graph_map) { | |||||
| if (compute_graph == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "Input param compute_graph is nullptr."); | |||||
| return false; | |||||
| } | |||||
| /// 1. FE sets attr optimize_group to true (false) when lx fusion succeeds (fails); | |||||
| /// 2. FE does not set attr optimize_group when fe.ini disables l2fusion; | |||||
| /// 3. Other engines do not set attr optimize_group. | |||||
| const auto &root_subgraph_list = sub_graph_map[compute_graph]; | |||||
| for (const auto &subgraph : root_subgraph_list) { | |||||
| bool optimize_group = true; | |||||
| (void)AttrUtils::GetBool(subgraph->GetSubGraph(), ATTR_NAME_OPTIMIZE_GROUP, optimize_group); | |||||
| if (!optimize_group) { | |||||
| GELOGW("Run lx optimize for subgraph:%s failed.", subgraph->GetSubGraph()->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| for (auto &function_graph : compute_graph->GetAllSubgraphs()) { | |||||
| const auto &subgraph_list = sub_graph_map[function_graph]; | |||||
| for (const auto &subgraph : subgraph_list) { | |||||
| bool optimize_group = true; | |||||
| (void)AttrUtils::GetBool(subgraph->GetSubGraph(), ATTR_NAME_OPTIMIZE_GROUP, optimize_group); | |||||
| if (!optimize_group) { | |||||
| GELOGW("Run lx optimize for subgraph:%s failed.", subgraph->GetSubGraph()->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| } | |||||
| GELOGI("All subgraph are optimized successfully, no need to reuse buffer optimize."); | |||||
| return true; | |||||
| } | |||||
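This check consumes the contract documented in the comments above: FE writes ATTR_NAME_OPTIMIZE_GROUP on each subgraph it handled, other engines leave it unset, and the local default of true makes "unset" count as success. For contrast, the producer side of that contract could look roughly like the sketch below (illustrative only, not FE's actual code; ComputeGraphPtr and AttrUtils come from GE graph headers not shown in this hunk):

```cpp
#include "framework/common/debug/ge_log.h"
#include "framework/common/ge_inner_error_codes.h"
#include "graph/debug/ge_attr_define.h"

// Producer-side sketch of the optimize_group contract: an engine optimizer
// records whether lx fusion worked on the subgraph it was given.
ge::Status MarkLxFusionResult(const ge::ComputeGraphPtr &subgraph, bool fusion_ok) {
  GE_CHECK_NOTNULL(subgraph);
  if (!ge::AttrUtils::SetBool(subgraph, ge::ATTR_NAME_OPTIMIZE_GROUP, fusion_ok)) {
    GELOGE(ge::INTERNAL_ERROR, "Set attr optimize_group on graph %s failed.",
           subgraph->GetName().c_str());
    return ge::INTERNAL_ERROR;
  }
  return ge::SUCCESS;
}
```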
| Status GraphManager::ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_graph, | |||||
| Graph2SubGraphInfoList &sub_graph_map, | |||||
| std::unordered_map<std::string, ComputeGraphPtr> ©_graphs) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| const auto &root_subgraph_list = sub_graph_map[compute_graph]; | |||||
| for (const auto &subgraph : root_subgraph_list) { | |||||
| auto iter = copy_graphs.find(subgraph->GetSubGraph()->GetName()); | |||||
| if (iter == copy_graphs.end()) { | |||||
| GELOGE(FAILED, "Can not find subgraph:%s in copy graphs.", subgraph->GetSubGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| subgraph->SetSubGraph(iter->second); | |||||
| } | |||||
| for (auto &function_graph : compute_graph->GetAllSubgraphs()) { | |||||
| const auto &subgraph_list = sub_graph_map[function_graph]; | |||||
| for (const auto &subgraph : subgraph_list) { | |||||
| auto iter = copy_graphs.find(subgraph->GetSubGraph()->GetName()); | |||||
| if (iter == copy_graphs.end()) { | |||||
| GELOGE(FAILED, "Can not find subgraph:%s in copy graphs.", subgraph->GetSubGraph()->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| subgraph->SetSubGraph(iter->second); | |||||
| } | |||||
| } | |||||
| GELOGI("All subgraphs are successfully replaced."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| auto sub_graph_map = graph_partitioner_.GetSubGraphMap(); | |||||
| std::string buffer_optimize; | |||||
| graphStatus graph_status = ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); | |||||
| bool need_lx_fusion = (graph_status == GRAPH_SUCCESS) && (buffer_optimize != kOffOptimize); | |||||
| if (options_.build_mode.empty() && need_lx_fusion) { | |||||
| GELOGI("Enter normal mode with buffer_optimize:%s.", buffer_optimize.c_str()); | |||||
| /// 1. Copy subgraphs so buffer optimize can run on the originals if lx fusion fails. | |||||
| /// 2. Set attr "lx_fusion" on the graphs for fusion optimize. | |||||
| std::unordered_map<std::string, ComputeGraphPtr> copy_graphs; | |||||
| GE_TIMESTAMP_START(CopySubGraphAndMarkFusion); | |||||
| Status ret = CopySubGraphAndMarkFusion(compute_graph, sub_graph_map, copy_graphs); | |||||
| GE_TIMESTAMP_EVENT_END(CopySubGraphAndMarkFusion, "SetSubgraph:CopySubGraphAndMarkFusion"); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "CopySubGraphAndMarkFusion failed."); | |||||
| return ret; | |||||
| } | |||||
| // Multi-thread optimize subgraphs with lx fusion | |||||
| ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph with lx fusion failed."); | |||||
| return ret; | |||||
| } | |||||
| // Check whether lx fusion succeeded for all subgraphs | |||||
| GE_TIMESTAMP_START(CheckAllFusionOptimizeSuccess); | |||||
| if (CheckAllFusionOptimizeSuccess(compute_graph, sub_graph_map)) { | |||||
| GE_TIMESTAMP_EVENT_END(CheckAllFusionOptimizeSuccess, "SetSubgraph:CheckAllFusionOptimizeSuccess"); | |||||
| return SUCCESS; | |||||
| } | |||||
| // Replace subgraph with original graph for lx buffer | |||||
| ret = ReplaceSubgraphWithOriGraph(compute_graph, sub_graph_map, copy_graphs); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Replace subgraph with original graph failed."); | |||||
| return ret; | |||||
| } | |||||
| // Multi-thread optimize subgraphs with lx buffer | |||||
| ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph with lx buffer failed."); | |||||
| return ret; | |||||
| } | |||||
| } else { | |||||
| /// Multi-thread optimize subgraphs: | |||||
| /// 1. run lx buffer when build_mode is normal and buffer_optimize is empty or "off_optimize"; | |||||
| /// 2. run lx fusion or lx buffer according to build_mode and build_step in FE. | |||||
| GELOGI("Directly optimize subgraph with build mode:%s, and step:%s, buffer_optimize:%s.", | |||||
| options_.build_mode.c_str(), options_.build_step.c_str(), buffer_optimize.c_str()); | |||||
| Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Multiply optimize subgraph with lx buffer"); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
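SetSubgraph is essentially clone-then-try: keep pristine copies of every subgraph, attempt the lx-fusion optimize pass, and if any subgraph reports failure, swap the copies back in and rerun the plain lx-buffer pass. The same control flow, condensed into a standalone sketch with illustrative names (not GE API):

```cpp
#include <functional>
#include <utility>
#include <vector>

#include "framework/common/ge_inner_error_codes.h"

// Generic try-with-rollback sketch mirroring SetSubgraph above.
template <typename Graph>
ge::Status OptimizeWithFallback(std::vector<Graph> &graphs,
                                const std::function<ge::Status(std::vector<Graph> &)> &try_fusion,
                                const std::function<bool(const std::vector<Graph> &)> &all_fused,
                                const std::function<ge::Status(std::vector<Graph> &)> &plain_optimize) {
  std::vector<Graph> backup = graphs;      // CopySubGraphAndMarkFusion: keep pristine copies
  ge::Status ret = try_fusion(graphs);     // OptimizeSubGraphWithMultiThreads (lx fusion attempt)
  if (ret != ge::SUCCESS) {
    return ret;
  }
  if (all_fused(graphs)) {                 // CheckAllFusionOptimizeSuccess
    return ge::SUCCESS;
  }
  graphs = std::move(backup);              // ReplaceSubgraphWithOriGraph: roll back to the copies
  return plain_optimize(graphs);           // second pass with lx buffer only
}
```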
| #define GM_RUN_AND_DUMP_PERF(name, func, ...) \ | #define GM_RUN_AND_DUMP_PERF(name, func, ...) \ | ||||
| do { \ | do { \ | ||||
| GE_RUN_PERF(GraphManager, func, __VA_ARGS__); \ | GE_RUN_PERF(GraphManager, func, __VA_ARGS__); \ | ||||
| @@ -368,18 +546,10 @@ Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_gr | |||||
| GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ | GELOGI("Run %s on graph %s(%u) success.", name, compute_graph->GetName().c_str(), graph_node->GetGraphId()); \ | ||||
| } while (0) | } while (0) | ||||
| Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | |||||
| GeRootModelPtr &ge_root_model, uint64_t session_id) { | |||||
| Status GraphManager::PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | |||||
| ge::ComputeGraphPtr &compute_graph, uint64_t session_id) { | |||||
| GE_CHECK_NOTNULL(graph_node); | GE_CHECK_NOTNULL(graph_node); | ||||
| GE_CHECK_NOTNULL(graph_node->GetGraph()); | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); | |||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| GEEVENT("PreRun start, graph node size %zu, session id %lu, graph id %u, graph name %s", | |||||
| compute_graph->GetDirectNodesSize(), session_id, compute_graph->GetGraphID(), | |||||
| compute_graph->GetName().c_str()); | |||||
| GE_DUMP(compute_graph, "PreRunBegin"); | |||||
| GM_RUN_AND_DUMP_PERF("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); | GM_RUN_AND_DUMP_PERF("OptimizeGraphPrepare", graph_optimize_.OptimizeOriginalGraphForQuantize, compute_graph); | ||||
| GM_RUN_AND_DUMP_PERF("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); | GM_RUN_AND_DUMP_PERF("HandleSummaryOp", graph_optimize_.HandleSummaryOp, compute_graph); | ||||
| GM_RUN_AND_DUMP_PERF("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, | GM_RUN_AND_DUMP_PERF("Prepare", graph_preparer_.PrepareDynShape, graph_node->GetGraph(), inputs, compute_graph, | ||||
| @@ -388,10 +558,6 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
| GM_RUN_AND_DUMP_PERF("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); | GM_RUN_AND_DUMP_PERF("PrepareRunningFormatRefiner", graph_preparer_.PrepareRunningFormatRefiner); | ||||
| GM_RUN_AND_DUMP_PERF("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); | GM_RUN_AND_DUMP_PERF("RefineRunningFormat", graph_optimize_.OptimizeOriginalGraphJudgeInsert, compute_graph); | ||||
| if (std::getenv("AnalyzeMode")) { | |||||
| GELOGI("Do return failed after refine_running_format when in analyze mode!"); | |||||
| return FAILED; | |||||
| } | |||||
| GM_RUN_AND_DUMP_PERF("SubexpressionMigration", SubexpressionMigration, compute_graph); | GM_RUN_AND_DUMP_PERF("SubexpressionMigration", SubexpressionMigration, compute_graph); | ||||
| GE_RUN(GraphManager, graph_preparer_.RecordAIPPInfo, compute_graph); | GE_RUN(GraphManager, graph_preparer_.RecordAIPPInfo, compute_graph); | ||||
| if (IsTailingOptimization()) { | if (IsTailingOptimization()) { | ||||
| @@ -399,18 +565,124 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
| } | } | ||||
| GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph); | GM_RUN_AND_DUMP_PERF("Optimize1", OptimizeStage1, compute_graph); | ||||
| GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed); | GM_RUN_AND_DUMP_PERF("InferShape2", compute_graph->InferShapeInNeed); | ||||
| const char *unknown_shape_skip = std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION"); | |||||
| if (unknown_shape_skip != nullptr) { | |||||
| PassManager graph_pass; | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::CtrlEdgeTransferPass", new (std::nothrow) CtrlEdgeTransferPass)) | |||||
| GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); | |||||
| } | |||||
| PassManager graph_pass; | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::CtrlEdgeTransferPass", new (std::nothrow) CtrlEdgeTransferPass)) | |||||
| GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); | |||||
| GE_CHK_STATUS_RET(graph_optimize_.IdentifyReference(compute_graph), "Identify reference failed."); | GE_CHK_STATUS_RET(graph_optimize_.IdentifyReference(compute_graph), "Identify reference failed."); | ||||
| GELOGI("PreRun:PreRunOptimizeOriginalGraph success."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::PreRunOptimizeSubGraph(const GraphNodePtr &graph_node, ge::ComputeGraphPtr &compute_graph, | |||||
| uint64_t session_id) { | |||||
| GE_CHECK_NOTNULL(graph_node); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| GM_RUN_AND_DUMP_PERF("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); | GM_RUN_AND_DUMP_PERF("OptimizeSubgraph", OptimizeSubgraph, graph_node, compute_graph, session_id); | ||||
| // Dump graph to tuning path | |||||
| if (options_.build_mode == BUILD_MODE_TUNING && options_.build_step == BUILD_STEP_AFTER_UB_MATCH) { | |||||
| std::string tuning_path; | |||||
| (void)GetContext().GetOption(TUNING_PATH, tuning_path); | |||||
| GELOGI("Dump path:%s.", tuning_path.c_str()); | |||||
| GraphUtils::DumpGEGraph(compute_graph, "", true, tuning_path); | |||||
| } | |||||
| GELOGI("PreRun:PreRunOptimizeSubGraph success."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | |||||
| GeRootModelPtr &ge_root_model, uint64_t session_id) { | |||||
| GE_CHECK_NOTNULL(graph_node); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); | GM_RUN_AND_DUMP_PERF("Optimize2", OptimizeStage2, compute_graph); | ||||
| GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", graph_optimize_.OptimizeGraphBeforeBuildForRts, compute_graph); | GM_RUN_AND_DUMP_PERF("OptimizeGraphBeforeBuildForRts", graph_optimize_.OptimizeGraphBeforeBuildForRts, compute_graph); | ||||
| GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); | GM_RUN_AND_DUMP_PERF("Build", Build, graph_node, compute_graph, ge_root_model, session_id); | ||||
| GELOGI("PreRun:PreRunAfterOptimizeSubGraph success."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint64_t session_id, uint32_t graph_id) { | |||||
| GELOGI("set rt_context, session id: %lu, graph id: %u, mode %d, device id:%u.", session_id, graph_id, | |||||
| static_cast<int>(mode), ge::GetContext().DeviceId()); | |||||
| rtError_t rt_ret = rtCtxCreate(&rt_context, mode, ge::GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return FAILED; | |||||
| } | |||||
| rt_ret = rtCtxSetCurrent(rt_context); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return FAILED; | |||||
| } | |||||
| RtContextUtil::GetInstance().AddRtContext(session_id, graph_id, rt_context); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | |||||
| GeRootModelPtr &ge_root_model, uint64_t session_id) { | |||||
| GE_CHECK_NOTNULL(graph_node); | |||||
| GE_CHECK_NOTNULL(graph_node->GetGraph()); | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(*graph_node->GetGraph()); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| compute_graph->SetSessionID(session_id); | |||||
| auto analyzer_instance = Analyzer::GetInstance(); | |||||
| GE_CHK_STATUS_RET(analyzer_instance->BuildJsonObject(session_id, compute_graph->GetGraphID()), | |||||
| "BuildJsonObject Failed") | |||||
| GEEVENT("PreRun start, graph node size %zu, session id %lu, graph id %u, graph name %s", | |||||
| compute_graph->GetDirectNodesSize(), session_id, compute_graph->GetGraphID(), | |||||
| compute_graph->GetName().c_str()); | |||||
| GE_DUMP(compute_graph, "PreRunBegin"); | |||||
| // Create rt context in generate mode for this graph | |||||
| Status ret = SetRtContext(rtContext_t(), RT_CTX_GEN_MODE, session_id, compute_graph->GetGraphID()); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Set rt context failed."); | |||||
| return ret; | |||||
| } | |||||
| /// 1. BUILD_MODE_TUNING with BUILD_STEP_AFTER_UB_MATCH does not need PreRunOptimizeOriginalGraph; | |||||
| /// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE does not need PreRunOptimizeOriginalGraph; | |||||
| /// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB does not need PreRunOptimizeOriginalGraph. | |||||
| bool run_optimize_original_graph = | |||||
| !((options_.build_mode == BUILD_MODE_TUNING) && | |||||
| (options_.build_step == BUILD_STEP_AFTER_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_MERGE || | |||||
| options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)); | |||||
| if (run_optimize_original_graph) { | |||||
| Status ret = PreRunOptimizeOriginalGraph(graph_node, inputs, compute_graph, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Run PreRunOptimizeOriginalGraph failed for graph:%s.", compute_graph->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| // BUILD_MODE_TUNING with BUILD_STEP_AFTER_MERGE does not need PreRunOptimizeSubGraph. | |||||
| bool run_optimize_subgraph = | |||||
| !((options_.build_mode == BUILD_MODE_TUNING) && (options_.build_step == BUILD_STEP_AFTER_MERGE)); | |||||
| if (run_optimize_subgraph) { | |||||
| Status ret = PreRunOptimizeSubGraph(graph_node, compute_graph, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Run PreRunOptimizeSubGraph failed for graph:%s.", compute_graph->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| /// 1. BUILD_MODE_TUNING with BUILD_STEP_BEFORE_UB_MATCH does not need PreRunAfterOptimizeSubGraph; | |||||
| /// 2. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER does not need PreRunAfterOptimizeSubGraph; | |||||
| /// 3. BUILD_MODE_TUNING with BUILD_STEP_AFTER_BUILDER_SUB does not need PreRunAfterOptimizeSubGraph. | |||||
| bool run_after_optimize_subgraph = | |||||
| !((options_.build_mode == BUILD_MODE_TUNING) && | |||||
| (options_.build_step == BUILD_STEP_BEFORE_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_BUILDER || | |||||
| options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)); | |||||
| if (run_after_optimize_subgraph) { | |||||
| Status ret = PreRunAfterOptimizeSubGraph(graph_node, compute_graph, ge_root_model, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Run PreRunAfterOptimizeSubGraph failed for graph:%s.", compute_graph->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| // when set incre build, save om model and var manager | // when set incre build, save om model and var manager | ||||
| GeModelPtr ge_model = nullptr; | GeModelPtr ge_model = nullptr; | ||||
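The three gating booleans above decide which PreRun stages run for each BUILD_STEP under BUILD_MODE_TUNING. The sketch below restates exactly that logic in one place so the matrix is easier to scan; it assumes the BUILD_MODE_/BUILD_STEP_ constants used above are reachable as string constants in namespace ge, and it adds no behavior of its own:

```cpp
#include <string>

// Distillation of the PreRun gating above: which stages run for a given build step.
struct PreRunStages {
  bool optimize_original_graph;   // PreRunOptimizeOriginalGraph
  bool optimize_subgraph;         // PreRunOptimizeSubGraph
  bool after_optimize_subgraph;   // PreRunAfterOptimizeSubGraph (Optimize2 + Build)
};

PreRunStages StagesFor(const std::string &build_mode, const std::string &build_step) {
  if (build_mode != ge::BUILD_MODE_TUNING) {
    return {true, true, true};  // non-tuning builds run every stage
  }
  PreRunStages stages{true, true, true};
  stages.optimize_original_graph = !(build_step == ge::BUILD_STEP_AFTER_UB_MATCH ||
                                     build_step == ge::BUILD_STEP_AFTER_MERGE ||
                                     build_step == ge::BUILD_STEP_AFTER_BUILDER_SUB);
  stages.optimize_subgraph = !(build_step == ge::BUILD_STEP_AFTER_MERGE);
  stages.after_optimize_subgraph = !(build_step == ge::BUILD_STEP_BEFORE_UB_MATCH ||
                                     build_step == ge::BUILD_STEP_AFTER_BUILDER ||
                                     build_step == ge::BUILD_STEP_AFTER_BUILDER_SUB);
  return stages;
}
```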
| @@ -456,7 +728,7 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| ret = PreRun(graph_node, inputs, ge_root_model, session_id); | ret = PreRun(graph_node, inputs, ge_root_model, session_id); | ||||
| // release rts generate context | // release rts generate context | ||||
| RtContextUtil::GetInstance().DestroyRtContexts(session_id); | |||||
| RtContextUtil::GetInstance().DestroyRtContexts(session_id, graph_node->GetGraphId()); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "PreRun Failed."); | GELOGE(ret, "PreRun Failed."); | ||||
| return ret; | return ret; | ||||
| @@ -1065,7 +1337,7 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| // net output node dataType | // net output node dataType | ||||
| ParseOption(options, OUTPUT_DATATYPE, options_.output_datatype); | ParseOption(options, OUTPUT_DATATYPE, options_.output_datatype); | ||||
| if (!options_.output_datatype.empty()) { | if (!options_.output_datatype.empty()) { | ||||
| domi::GetContext().output_type = options_.output_datatype; | |||||
| omg_context_.output_type = options_.output_datatype; | |||||
| } | } | ||||
| // Set save_original_model flag (ge.save_original_model) | // Set save_original_model flag (ge.save_original_model) | ||||
| @@ -1074,6 +1346,10 @@ Status GraphManager::ParseOptions(const std::map<std::string, std::string> &opti | |||||
| // Original model file name | // Original model file name | ||||
| ParseOption(options, ORIGINAL_MODEL_FILE, options_.original_model_file); | ParseOption(options, ORIGINAL_MODEL_FILE, options_.original_model_file); | ||||
| // Set build mode and build step | |||||
| ParseOption(options, BUILD_MODE, options_.build_mode); | |||||
| ParseOption(options, BUILD_STEP, options_.build_step); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
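BUILD_MODE and BUILD_STEP now arrive through the same options map as the other graph options. A hedged example of a caller selecting the tuning flow; the option key strings behind BUILD_MODE/BUILD_STEP are not shown in this diff, so the constants are used directly and are assumed to be string-like members of namespace ge:

```cpp
#include <map>
#include <string>

// Sketch: driving the new tuning branches through the options map that ParseOptions reads.
std::map<std::string, std::string> MakeTuningOptions() {
  std::map<std::string, std::string> options;
  options[ge::BUILD_MODE] = ge::BUILD_MODE_TUNING;          // parsed into options_.build_mode
  options[ge::BUILD_STEP] = ge::BUILD_STEP_AFTER_UB_MATCH;  // parsed into options_.build_step
  return options;
}
// GraphManager::Initialize(options) would then route PreRun through the tuning-specific
// branches added in this change (e.g. dumping the optimized graph to TUNING_PATH after UB match).
```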
| @@ -1659,6 +1935,7 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| ReshapeRemovePass reshape_remove_pass; | ReshapeRemovePass reshape_remove_pass; | ||||
| ConstantFoldingPass constant_folding_pass; | ConstantFoldingPass constant_folding_pass; | ||||
| DimensionAdjustPass dimension_adjust_pass; | DimensionAdjustPass dimension_adjust_pass; | ||||
| EnterPass enter_pass; | |||||
| AddNPass addn_pass; | AddNPass addn_pass; | ||||
| SwitchDeadBranchElimination switch_dead_branch_elimination; | SwitchDeadBranchElimination switch_dead_branch_elimination; | ||||
| SwitchLogicRemovePass switch_logic_remove_pass; | SwitchLogicRemovePass switch_logic_remove_pass; | ||||
| @@ -1667,15 +1944,16 @@ Status GraphManager::OptimizeStage1(ge::ComputeGraphPtr &compute_graph) { | |||||
| TransposeTransDataPass transpose_transdata_pass; | TransposeTransDataPass transpose_transdata_pass; | ||||
| TransOpSymmetryEliminationPass symmetry_elimination_pass; | TransOpSymmetryEliminationPass symmetry_elimination_pass; | ||||
| DimensionComputePass dimension_compute_pass; | DimensionComputePass dimension_compute_pass; | ||||
| names_to_passes.emplace_back("EnterPass", &enter_pass); | |||||
| names_to_passes.emplace_back("AddNPass", &addn_pass); | names_to_passes.emplace_back("AddNPass", &addn_pass); | ||||
| names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | names_to_passes.emplace_back("SwitchDeadBranchElimination", &switch_dead_branch_elimination); | ||||
| names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); | names_to_passes.emplace_back("SwitchLogicRemovePass", &switch_logic_remove_pass); | ||||
| names_to_passes.emplace_back("MergePass", &merge_pass); | names_to_passes.emplace_back("MergePass", &merge_pass); | ||||
| names_to_passes.emplace_back("CastRemovePass", &cast_remove_pass); | names_to_passes.emplace_back("CastRemovePass", &cast_remove_pass); | ||||
| names_to_passes.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); | names_to_passes.emplace_back("TransposeTransDataPass", &transpose_transdata_pass); | ||||
| names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | |||||
| names_to_passes.emplace_back("TransOpSymmetryEliminationPass", &symmetry_elimination_pass); | names_to_passes.emplace_back("TransOpSymmetryEliminationPass", &symmetry_elimination_pass); | ||||
| names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); | names_to_passes.emplace_back("TransOpNearbyAllreduceFusionPass", &trans_op_nearby_allreduce_fusion_pass); | ||||
| names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | |||||
| names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | names_to_passes.emplace_back("DimensionComputePass", &dimension_compute_pass); | ||||
| names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
| names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | names_to_passes.emplace_back("DimensionAdjustPass", &dimension_adjust_pass); | ||||
| @@ -1975,6 +2253,7 @@ Status GraphManager::ProcessSubGraphWithMultiThreads(GraphManager *graph_manager | |||||
| Status ret = SUCCESS; | Status ret = SUCCESS; | ||||
| GetThreadLocalContext() = ge_context; | GetThreadLocalContext() = ge_context; | ||||
| if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { | if (sub_graph_info_ptr != nullptr && graph_manager != nullptr) { | ||||
| SetLocalOmgContext(graph_manager->omg_context_); | |||||
| ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | ComputeGraphPtr compute_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | ||||
| const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); | const std::string &engine_name = sub_graph_info_ptr->GetEngineName(); | ||||
| GELOGI("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", | GELOGI("ProcessSubGraphWithMultiThreads start, graph name is %s, engine_name is %s, thread id is %lu", | ||||
| @@ -2079,6 +2358,8 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
| if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | ||||
| GELOGW("Set thread name failed."); | GELOGW("Set thread name failed."); | ||||
| } | } | ||||
| SetLocalOmgContext(graph_manager->omg_context_); | |||||
| PreRunArgs args; | PreRunArgs args; | ||||
| while (graph_manager->thread_run_flag_) { | while (graph_manager->thread_run_flag_) { | ||||
| bool pop_status = graph_manager->prerun_args_q_.Pop(args); | bool pop_status = graph_manager->prerun_args_q_.Pop(args); | ||||
| @@ -2146,10 +2427,10 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
| if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | ||||
| ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | ret = graph_manager->PreRun(graph_node, ge_inputs, ge_root_model, args.session_id); | ||||
| // release rts generate context | // release rts generate context | ||||
| RtContextUtil::GetInstance().DestroyRtContexts(args.session_id); | |||||
| RtContextUtil::GetInstance().DestroyRtContexts(args.session_id, graph_node->GetGraphId()); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| graph_node->SetRunFlag(false); | graph_node->SetRunFlag(false); | ||||
| if (!std::getenv("AnalyzeMode")) { | |||||
| if (!ge::Analyzer::GetInstance()->IsEnableNetAnalyzeDebug()) { | |||||
| ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | ||||
| graph_node->Unlock(); | graph_node->Unlock(); | ||||
| return; | return; | ||||
| @@ -2176,6 +2457,8 @@ void GraphManager::RunThread(GraphManager *graph_manager) { | |||||
| if (prctl(PR_SET_NAME, ("GE_Run")) != 0) { | if (prctl(PR_SET_NAME, ("GE_Run")) != 0) { | ||||
| GELOGW("Set thread name failed."); | GELOGW("Set thread name failed."); | ||||
| } | } | ||||
| SetLocalOmgContext(graph_manager->omg_context_); | |||||
| RunArgs args; | RunArgs args; | ||||
| while (graph_manager->thread_run_flag_) { | while (graph_manager->thread_run_flag_) { | ||||
| bool pop_status = graph_manager->run_args_q_.Pop(args); | bool pop_status = graph_manager->run_args_q_.Pop(args); | ||||
| @@ -2287,17 +2570,11 @@ void GraphManager::ReturnError(GraphManager *graph_manager, GraphNodePtr &graph_ | |||||
| return; | return; | ||||
| } | } | ||||
| tensor.length = len * size; | tensor.length = len * size; | ||||
| auto pbuff = new (std::nothrow) uint8_t[tensor.length]; | |||||
| if (!pbuff) { | |||||
| GELOGE(MEMALLOC_FAILED, "new buff failed!"); | |||||
| callback(GRAPH_FAILED, outputs); | |||||
| return; | |||||
| } | |||||
| tensor.data.reset(new (std::nothrow) uint8_t[tensor.length]); | |||||
| // To avoid global step too small and can not stop, totally set a bigger value | // To avoid global step too small and can not stop, totally set a bigger value | ||||
| for (int64_t i = 0; i < tensor.length; i++) { | for (int64_t i = 0; i < tensor.length; i++) { | ||||
| *(pbuff + i) = 0x7F; // here stands for a positive max value | |||||
| tensor.data[i] = 0x7F; // here stands for a positive max value | |||||
| } | } | ||||
| tensor.data.reset(pbuff); | |||||
| outputs.emplace_back(std::move(tensor)); | outputs.emplace_back(std::move(tensor)); | ||||
| } | } | ||||
| } | } | ||||
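Filling every byte with 0x7F makes the dummy global-step output read as a large positive number for any signed integer width, because the most significant byte never has the sign bit set. A tiny standalone check of that arithmetic (not GE code):

```cpp
#include <cstdint>
#include <cstring>
#include <iostream>

// Why a 0x7F byte fill yields a "positive max"-like value regardless of integer width.
int main() {
  uint8_t buf[8];
  std::memset(buf, 0x7F, sizeof(buf));

  int64_t as_i64 = 0;
  std::memcpy(&as_i64, buf, sizeof(as_i64));
  int32_t as_i32 = 0;
  std::memcpy(&as_i32, buf, sizeof(as_i32));

  std::cout << as_i64 << "\n";  // 9187201950435737471 (0x7F7F7F7F7F7F7F7F)
  std::cout << as_i32 << "\n";  // 2139062143 (0x7F7F7F7F)
  return 0;
}
```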
| @@ -2373,6 +2650,20 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); | GE_TIMESTAMP_EVENT_END(SetSubgraph, "OptimizeSubgraph::SetSubGraph"); | ||||
| if ((options_.build_mode == BUILD_MODE_TUNING) && | |||||
| (options_.build_step == BUILD_STEP_BEFORE_UB_MATCH || options_.build_step == BUILD_STEP_AFTER_BUILDER || | |||||
| options_.build_step == BUILD_STEP_AFTER_BUILDER_SUB)) { | |||||
| GE_TIMESTAMP_START(ConvertGraphToFile); | |||||
| std::string tuning_path; | |||||
| (void)GetContext().GetOption(TUNING_PATH, tuning_path); | |||||
| Status ret = ConvertGraphToFile(compute_graph, tuning_path, (options_.build_step == BUILD_STEP_AFTER_BUILDER)); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Convert graph[%s] to file failed", compute_graph->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| GE_TIMESTAMP_EVENT_END(ConvertGraphToFile, "OptimizeSubgraph::ConvertGraphToFile"); | |||||
| return SUCCESS; | |||||
| } | |||||
| ComputeGraphPtr merged_compute_graph = nullptr; | ComputeGraphPtr merged_compute_graph = nullptr; | ||||
| std::vector<ComputeGraphPtr> merged_sub_graph_list; | std::vector<ComputeGraphPtr> merged_sub_graph_list; | ||||
| @@ -2400,6 +2691,32 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::ConvertGraphToFile(ComputeGraphPtr &compute_graph, std::string path, bool exe_flag) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| GELOGI("compute_graph [%s] path [%s] Enter ConvertGraphToFile.", compute_graph->GetName().c_str(), path.c_str()); | |||||
| std::vector<ComputeGraphPtr> non_tuning_subgraphs; | |||||
| auto input_node_sub_graph_map = graph_partitioner_.graph_2_input_subgraph_; | |||||
| const auto &input_subgraph_info = input_node_sub_graph_map[compute_graph]; | |||||
| GE_CHECK_NOTNULL(input_subgraph_info); | |||||
| ComputeGraphPtr input_graph_tmp = input_subgraph_info->GetSubGraph(); | |||||
| non_tuning_subgraphs.push_back(input_graph_tmp); | |||||
| auto sub_graph_map = graph_partitioner_.GetSubGraphMap(); | |||||
| const auto &subgraph_infos = sub_graph_map[compute_graph]; | |||||
| std::vector<ComputeGraphPtr> tuning_subgraphs; | |||||
| for (const auto &sub_graph_info_ptr : subgraph_infos) { | |||||
| GE_CHECK_NOTNULL(sub_graph_info_ptr); | |||||
| ComputeGraphPtr sub_graph_tmp = sub_graph_info_ptr->GetSubGraph(); | |||||
| // subgraphs on these engines need tuning | |||||
| if (sub_graph_info_ptr->GetEngineName() == kVectorEngine || sub_graph_info_ptr->GetEngineName() == kAIcoreEngine) { | |||||
| tuning_subgraphs.push_back(sub_graph_tmp); | |||||
| } else { | |||||
| non_tuning_subgraphs.push_back(sub_graph_tmp); | |||||
| } | |||||
| } | |||||
| return TuningUtils::ConvertGraphToFile(tuning_subgraphs, non_tuning_subgraphs, exe_flag, path); | |||||
| } | |||||
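The split between tuning and non-tuning subgraphs is driven purely by engine name: subgraphs assigned to the AI Core or vector engines are handed to tuning, and everything else, including the input-nodes subgraph, is exported as-is. The predicate in isolation (a sketch; the literals mirror kVectorEngine/kAIcoreEngine above):

```cpp
#include <string>

// Engine-name predicate used by ConvertGraphToFile above, pulled out for clarity.
inline bool NeedsTuning(const std::string &engine_name) {
  return engine_name == "VectorEngine" || engine_name == "AIcoreEngine";
}
```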
| Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | Status GraphManager::Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | ||||
| GeRootModelPtr &ge_root_model, uint64_t session_id) { | GeRootModelPtr &ge_root_model, uint64_t session_id) { | ||||
| // build | // build | ||||
| @@ -39,12 +39,13 @@ | |||||
| #include "graph/optimize/graph_optimize.h" | #include "graph/optimize/graph_optimize.h" | ||||
| #include "graph/partition/graph_partition.h" | #include "graph/partition/graph_partition.h" | ||||
| #include "graph/preprocess/graph_preprocess.h" | #include "graph/preprocess/graph_preprocess.h" | ||||
| #include "graph/tuning_utils.h" | |||||
| #include "model/ge_model.h" | #include "model/ge_model.h" | ||||
| namespace ge { | namespace ge { | ||||
| class GraphManager { | class GraphManager { | ||||
| public: | public: | ||||
| GraphManager(); | |||||
| GraphManager(OmgContext &omg_context); | |||||
| ~GraphManager() = default; | ~GraphManager() = default; | ||||
| @@ -248,6 +249,8 @@ class GraphManager { | |||||
| Status MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph); | Status MergeSubGraph(ComputeGraphPtr &compute_graph, const ge::ComputeGraphPtr &original_compute_graph); | ||||
| Status ConvertGraphToFile(ComputeGraphPtr &compute_graph, std::string file_path, bool exe_flag = false); | |||||
| Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph); | Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph); | ||||
| void SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph); | void SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph); | ||||
| @@ -304,6 +307,25 @@ class GraphManager { | |||||
| void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); | void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); | ||||
| Status PreRunOptimizeOriginalGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | |||||
| ge::ComputeGraphPtr &compute_graph, uint64_t session_id); | |||||
| Status PreRunOptimizeSubGraph(const GraphNodePtr &graph_node, ge::ComputeGraphPtr &compute_graph, | |||||
| uint64_t session_id); | |||||
| Status PreRunAfterOptimizeSubGraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | |||||
| GeRootModelPtr &ge_root_model, uint64_t session_id); | |||||
| Status CopySubGraphAndMarkFusion(const ComputeGraphPtr &compute_graph, Graph2SubGraphInfoList &sub_graph_map, | |||||
| std::unordered_map<std::string, ComputeGraphPtr> ©_graphs); | |||||
| Status OptimizeSubGraphWithMultiThreads(ComputeGraphPtr compute_graph, Graph2SubGraphInfoList &sub_graph_map, | |||||
| uint64_t session_id); | |||||
| bool CheckAllFusionOptimizeSuccess(const ComputeGraphPtr &compute_graph, Graph2SubGraphInfoList &sub_graph_map); | |||||
| Status ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_graph, Graph2SubGraphInfoList &sub_graph_map, | |||||
| std::unordered_map<std::string, ComputeGraphPtr> ©_graphs); | |||||
| Status SetRtContext(rtContext_t rt_context, rtCtxMode_t mode, uint64_t session_id, uint32_t graph_id); | |||||
| std::atomic_bool thread_run_flag_; | std::atomic_bool thread_run_flag_; | ||||
| BlockingQueue<PreRunArgs> prerun_args_q_{}; | BlockingQueue<PreRunArgs> prerun_args_q_{}; | ||||
| BlockingQueue<RunArgs> run_args_q_{}; | BlockingQueue<RunArgs> run_args_q_{}; | ||||
| @@ -326,6 +348,7 @@ class GraphManager { | |||||
| bool init_flag_; | bool init_flag_; | ||||
| GraphManagerOptions options_; | GraphManagerOptions options_; | ||||
| OmgContext &omg_context_; | |||||
| GraphPrepare graph_preparer_; | GraphPrepare graph_preparer_; | ||||
| GraphOptimize graph_optimize_; | GraphOptimize graph_optimize_; | ||||
| @@ -163,42 +163,4 @@ bool HasCalcOp(const ComputeGraphPtr &graph) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| Status ParseOutNodes(const string &out_nodes) { | |||||
| try { | |||||
| if (!out_nodes.empty()) { | |||||
| domi::GetContext().out_nodes_map.clear(); | |||||
| domi::GetContext().user_out_nodes.clear(); | |||||
| vector<string> nodes_v = StringUtils::Split(out_nodes, ';'); | |||||
| for (const string &node : nodes_v) { | |||||
| vector<string> key_value_v = StringUtils::Split(node, ':'); | |||||
| if (key_value_v.size() != 2) { // must contain 2 items | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "Invalid outNodes: %s", node.c_str()); | |||||
| return GE_GRAPH_PARAM_NULLPTR; | |||||
| } | |||||
| auto iter = domi::GetContext().out_nodes_map.find(key_value_v[0]); | |||||
| int32_t index = std::stoi(StringUtils::Trim(key_value_v[1])); | |||||
| if (iter != domi::GetContext().out_nodes_map.end()) { | |||||
| iter->second.emplace_back(index); | |||||
| } else { | |||||
| std::vector<int32_t> index_v; | |||||
| index_v.emplace_back(index); | |||||
| domi::GetContext().out_nodes_map.emplace(key_value_v[0], index_v); | |||||
| } | |||||
| domi::GetContext().user_out_nodes.emplace_back(key_value_v[0], index); | |||||
| } | |||||
| } | |||||
| } catch (std::invalid_argument &) { | |||||
| GELOGE(PARAM_INVALID, "out nodes: %s, key value[1] is invalid argument", out_nodes.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } catch (std::out_of_range &) { | |||||
| GELOGE(PARAM_INVALID, "out nodes: %s, key value[1] is out of range", out_nodes.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } catch (...) { | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "Invalid outNodes: %s", out_nodes.c_str()); | |||||
| return GE_GRAPH_PARAM_NULLPTR; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -116,6 +116,7 @@ class SubGraphInfo { | |||||
| using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>; | using SubGraphInfoPtr = std::shared_ptr<ge::SubGraphInfo>; | ||||
| using Graph2SubGraphInfoList = std::unordered_map<ComputeGraphPtr, std::vector<SubGraphInfoPtr>>; | using Graph2SubGraphInfoList = std::unordered_map<ComputeGraphPtr, std::vector<SubGraphInfoPtr>>; | ||||
| using Graph2InputNodesSubGraphInfo = std::unordered_map<ComputeGraphPtr, SubGraphInfoPtr>; | |||||
| // for run graph async listener | // for run graph async listener | ||||
| class RunAsyncListener : public ge::ModelListener { | class RunAsyncListener : public ge::ModelListener { | ||||
| @@ -220,8 +221,6 @@ class GraphModelListener : public ge::ModelListener { | |||||
| std::condition_variable &condition_; | std::condition_variable &condition_; | ||||
| }; | }; | ||||
| Status ParseOutNodes(const string &out_nodes); | |||||
| struct GraphManagerOptions { | struct GraphManagerOptions { | ||||
| int32_t stream_num; | int32_t stream_num; | ||||
| int32_t perf_level; | int32_t perf_level; | ||||
| @@ -248,6 +247,8 @@ struct GraphManagerOptions { | |||||
| std::string output_datatype; | std::string output_datatype; | ||||
| std::string original_model_file; | std::string original_model_file; | ||||
| std::string save_original_model; | std::string save_original_model; | ||||
| std::string build_mode; | |||||
| std::string build_step; | |||||
| GraphManagerOptions() | GraphManagerOptions() | ||||
| : stream_num(1), | : stream_num(1), | ||||
| perf_level(domi::GEN_TASK_WITHOUT_FUSION), | perf_level(domi::GEN_TASK_WITHOUT_FUSION), | ||||
| @@ -269,7 +270,9 @@ struct GraphManagerOptions { | |||||
| hcom_parallel(false), | hcom_parallel(false), | ||||
| enable_print_op_pass(true), | enable_print_op_pass(true), | ||||
| is_single_op(false), | is_single_op(false), | ||||
| save_original_model("false") {} | |||||
| save_original_model("false"), | |||||
| build_mode(""), | |||||
| build_step("") {} | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -15,13 +15,13 @@ | |||||
| */ | */ | ||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include <set> | #include <set> | ||||
| #include <string> | #include <string> | ||||
| #include <utility> | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/rdma_pool_allocator.h" | |||||
| namespace ge { | namespace ge { | ||||
| void MemoryAllocator::Initialize(uint32_t device_id) { | void MemoryAllocator::Initialize(uint32_t device_id) { | ||||
| @@ -185,30 +185,36 @@ Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | |||||
| } | } | ||||
| } | } | ||||
| return InitCachingAllocator(memory_type); | |||||
| if (InitAllocator(memory_type, caching_allocator_map_) != SUCCESS) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Create CachingAllocator failed."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| if (InitAllocator(memory_type, rdma_allocator_map_) != SUCCESS) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Create RdmaAllocator failed."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| void MemManager::Finalize() noexcept { | |||||
| GELOGI("Finalize."); | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| // caching allocator use memory allocator, so finalize it first | |||||
| for (auto &caching_allocator : caching_allocator_map_) { | |||||
| if (caching_allocator.second != nullptr) { | |||||
| caching_allocator.second->Finalize(); | |||||
| delete caching_allocator.second; | |||||
| caching_allocator.second = nullptr; | |||||
| template <typename T> | |||||
| void FinalizeAllocatorMap(std::map<rtMemType_t, T *> &allocate_map) { | |||||
| for (auto &allocator : allocate_map) { | |||||
| if (allocator.second != nullptr) { | |||||
| allocator.second->Finalize(); | |||||
| delete allocator.second; | |||||
| allocator.second = nullptr; | |||||
| } | } | ||||
| } | } | ||||
| caching_allocator_map_.clear(); | |||||
| allocate_map.clear(); | |||||
| } | |||||
| for (auto &memory_allocator : memory_allocator_map_) { | |||||
| if (memory_allocator.second != nullptr) { | |||||
| memory_allocator.second->Finalize(); | |||||
| delete memory_allocator.second; | |||||
| memory_allocator.second = nullptr; | |||||
| } | |||||
| } | |||||
| memory_allocator_map_.clear(); | |||||
| void MemManager::Finalize() noexcept { | |||||
| GELOGI("Finalize."); | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| // caching and rdma allocator use memory allocator, so finalize them first | |||||
| FinalizeAllocatorMap(caching_allocator_map_); | |||||
| FinalizeAllocatorMap(rdma_allocator_map_); | |||||
| FinalizeAllocatorMap(memory_allocator_map_); | |||||
| } | } | ||||
| MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | ||||
| @@ -229,53 +235,11 @@ MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | |||||
| return memory_allocator; | return memory_allocator; | ||||
| } | } | ||||
| Status MemManager::InitCachingAllocator(const std::vector<rtMemType_t> &memory_type) { | |||||
| CachingAllocator *caching_allocator = nullptr; | |||||
| for (unsigned int index : memory_type) { | |||||
| auto it = caching_allocator_map_.find(index); | |||||
| if (it == caching_allocator_map_.end()) { | |||||
| caching_allocator = new (std::nothrow) CachingAllocator(index); | |||||
| if (caching_allocator != nullptr) { | |||||
| caching_allocator_map_[index] = caching_allocator; | |||||
| GELOGI("Create CachingAllocator memory type[%u] success.", index); | |||||
| } else { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Alloc CachingAllocator failed."); | |||||
| } | |||||
| } else { | |||||
| caching_allocator = it->second; | |||||
| } | |||||
| if (caching_allocator == nullptr) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Create CachingAllocator failed."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } else { | |||||
| if (caching_allocator->Initialize() != ge::SUCCESS) { | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| CachingAllocator &MemManager::GetCachingAllocator(rtMemType_t memory_type) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| CachingAllocator *caching_allocator = nullptr; | |||||
| auto it = caching_allocator_map_.find(memory_type); | |||||
| if (it != caching_allocator_map_.end()) { | |||||
| caching_allocator = it->second; | |||||
| } | |||||
| // Usually impossible | |||||
| if (caching_allocator == nullptr) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "GetCachingAllocator failed, memory type is %u.", memory_type); | |||||
| static CachingAllocator default_caching_allocator(RT_MEMORY_RESERVED); | |||||
| return default_caching_allocator; | |||||
| ; | |||||
| } | |||||
| return *caching_allocator; | |||||
| CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { | |||||
| return Instance().GetAllocator(memory_type, caching_allocator_map_); | |||||
| } | } | ||||
| CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { | |||||
| return Instance().GetCachingAllocator(memory_type); | |||||
| RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { | |||||
| return Instance().GetAllocator(memory_type, rdma_allocator_map_); | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| @@ -136,6 +137,7 @@ class MemoryAllocator { | |||||
| using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | ||||
| class CachingAllocator; | class CachingAllocator; | ||||
| class RdmaPoolAllocator; | |||||
| class MemManager { | class MemManager { | ||||
| public: | public: | ||||
| @@ -143,7 +145,8 @@ class MemManager { | |||||
| virtual ~MemManager(); | virtual ~MemManager(); | ||||
| static MemManager &Instance(); | static MemManager &Instance(); | ||||
| static MemoryAllocator *Instance(rtMemType_t memory_type); | static MemoryAllocator *Instance(rtMemType_t memory_type); | ||||
| static CachingAllocator &CachingInstance(rtMemType_t memory_type); | |||||
| CachingAllocator &CachingInstance(rtMemType_t memory_type); | |||||
| RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); | |||||
| MemManager(const MemManager &) = delete; | MemManager(const MemManager &) = delete; | ||||
| MemManager &operator=(const MemManager &) = delete; | MemManager &operator=(const MemManager &) = delete; | ||||
| /// | /// | ||||
| @@ -172,22 +175,65 @@ class MemManager { | |||||
| /// | /// | ||||
| /// @ingroup ge_graph | /// @ingroup ge_graph | ||||
| /// @brief ge caching allocator | |||||
| /// @param [in] memory_type memory type | /// @param [in] memory_type memory type | ||||
| /// @return CachingAllocator ptr | |||||
| /// @param [in] allocate_map memory allocator map | |||||
| /// @return Status result of function | |||||
| /// | /// | ||||
| CachingAllocator &GetCachingAllocator(rtMemType_t memory_type); | |||||
| template <typename T> | |||||
| Status InitAllocator(const std::vector<rtMemType_t> &memory_type, std::map<rtMemType_t, T *> &allocate_map) { | |||||
| T *allocator = nullptr; | |||||
| for (unsigned int index : memory_type) { | |||||
| auto it = allocate_map.find(index); | |||||
| if (it == allocate_map.end()) { | |||||
| allocator = new (std::nothrow) T(index); | |||||
| if (allocator != nullptr) { | |||||
| allocate_map[index] = allocator; | |||||
| GELOGI("Create Allocator memory type[%u] success.", index); | |||||
| } else { | |||||
| GELOGE(INTERNAL_ERROR, "Alloc Allocator failed."); | |||||
| } | |||||
| } else { | |||||
| allocator = it->second; | |||||
| } | |||||
| if (allocator == nullptr) { | |||||
| GELOGE(INTERNAL_ERROR, "Create Allocator failed."); | |||||
| return INTERNAL_ERROR; | |||||
| } else { | |||||
| if (allocator->Initialize() != SUCCESS) { | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup ge_graph | /// @ingroup ge_graph | ||||
| /// @brief ge create caching allocator | |||||
| /// @param [in] memory_type memory type | /// @param [in] memory_type memory type | ||||
| /// @return Status result of function | |||||
| /// | |||||
| Status InitCachingAllocator(const std::vector<rtMemType_t> &memory_type); | |||||
| /// @param [in] allocate_map memory allocator map | |||||
| /// @return Allocator ptr | |||||
| /// | |||||
| template <typename T> | |||||
| T &GetAllocator(rtMemType_t memory_type, std::map<rtMemType_t, T *> allocate_map) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| T *allocator = nullptr; | |||||
| auto it = allocate_map.find(memory_type); | |||||
| if (it != allocate_map.end()) { | |||||
| allocator = it->second; | |||||
| } | |||||
| // Usually impossible | |||||
| if (allocator == nullptr) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Get allocator failed, memory type is %u.", memory_type); | |||||
| static T default_allocator(RT_MEMORY_RESERVED); | |||||
| return default_allocator; | |||||
| } | |||||
| return *allocator; | |||||
| } | |||||
| std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | ||||
| std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | ||||
| std::map<rtMemType_t, RdmaPoolAllocator *> rdma_allocator_map_; | |||||
| std::recursive_mutex allocator_mutex_; | std::recursive_mutex allocator_mutex_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
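InitAllocator, GetAllocator, and the FinalizeAllocatorMap helper in the .cc file replace the CachingAllocator-specific plumbing with one generic pattern: any allocator type constructible from a memory type and exposing Initialize()/Finalize() can be kept in a map keyed by rtMemType_t. A toy, self-contained illustration of that pattern (not GE code):

```cpp
#include <cstdint>
#include <iostream>
#include <map>
#include <new>
#include <vector>

// Toy stand-in for CachingAllocator / RdmaPoolAllocator: constructible from a
// memory type, with Initialize()/Finalize() lifecycle hooks.
using MemType = uint32_t;

struct ToyAllocator {
  explicit ToyAllocator(MemType type) : type_(type) {}
  bool Initialize() { std::cout << "init type " << type_ << "\n"; return true; }
  void Finalize() { std::cout << "finalize type " << type_ << "\n"; }
  MemType type_;
};

template <typename T>
bool InitAllocators(const std::vector<MemType> &types, std::map<MemType, T *> &allocators) {
  for (MemType type : types) {
    if (allocators.find(type) != allocators.end()) {
      continue;  // already created for this memory type
    }
    T *allocator = new (std::nothrow) T(type);
    if (allocator == nullptr || !allocator->Initialize()) {
      delete allocator;
      return false;
    }
    allocators[type] = allocator;
  }
  return true;
}

template <typename T>
void FinalizeAllocators(std::map<MemType, T *> &allocators) {
  for (auto &entry : allocators) {
    if (entry.second != nullptr) {
      entry.second->Finalize();
      delete entry.second;
      entry.second = nullptr;
    }
  }
  allocators.clear();
}

int main() {
  std::map<MemType, ToyAllocator *> allocators;
  InitAllocators<ToyAllocator>({0u, 1u}, allocators);
  FinalizeAllocators(allocators);
  return 0;
}
```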
| @@ -15,7 +15,11 @@ | |||||
| */ | */ | ||||
| #include "graph/manager/rdma_pool_allocator.h" | #include "graph/manager/rdma_pool_allocator.h" | ||||
| #include <framework/common/debug/log.h> | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/ge_context.h" | |||||
| #include "runtime/dev.h" | |||||
| namespace { | namespace { | ||||
| const size_t kAlignedSize = 512; | const size_t kAlignedSize = 512; | ||||
| @@ -52,31 +56,41 @@ Status RdmaPoolAllocator::Initialize() { | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| void RdmaPoolAllocator::Finalize() { | void RdmaPoolAllocator::Finalize() { | ||||
| GELOGD("Rdma pool finalize start."); | |||||
| for (auto it = allocated_blocks_.begin(); it != allocated_blocks_.end();) { | for (auto it = allocated_blocks_.begin(); it != allocated_blocks_.end();) { | ||||
| auto block = it->second; | auto block = it->second; | ||||
| allocated_blocks_.erase(it); | |||||
| it = allocated_blocks_.erase(it); | |||||
| delete block; | delete block; | ||||
| } | } | ||||
| for (auto it = block_bin_.begin(); it != block_bin_.end();) { | for (auto it = block_bin_.begin(); it != block_bin_.end();) { | ||||
| auto block = *it; | auto block = *it; | ||||
| block_bin_.erase(it); | |||||
| it = block_bin_.erase(it); | |||||
| delete block; | delete block; | ||||
| } | } | ||||
| if (rdma_base_addr_ != nullptr) { | if (rdma_base_addr_ != nullptr) { | ||||
| GELOGD("Start to free rdma pool memory."); | |||||
| if (memory_allocator_->FreeMemory(rdma_base_addr_) != SUCCESS) { | if (memory_allocator_->FreeMemory(rdma_base_addr_) != SUCCESS) { | ||||
| GELOGW("Free rdma pool memory failed"); | GELOGW("Free rdma pool memory failed"); | ||||
| } | } | ||||
| rdma_base_addr_ = nullptr; | |||||
| } | } | ||||
| } | } | ||||
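Editor's note: the `Finalize` change replaces `container.erase(it)` with `it = container.erase(it)`, the required idiom when erasing while iterating: `erase` invalidates the erased iterator, so the loop must continue from the iterator that `erase` returns. A tiny sketch of the safe pattern (contents are illustrative):

```cpp
#include <iostream>
#include <map>
#include <set>

int main() {
  std::map<int, int> allocated{{1, 10}, {2, 20}, {3, 30}};
  // Safe erase-while-iterating: erase() invalidates the erased iterator but
  // returns the next valid one, so we must reassign it.
  for (auto it = allocated.begin(); it != allocated.end();) {
    it = allocated.erase(it);
  }

  std::set<int> bin{1, 2, 3};
  for (auto it = bin.begin(); it != bin.end();) {
    it = bin.erase(it);  // same idiom for std::set
  }
  std::cout << allocated.size() << " " << bin.size() << std::endl;  // 0 0
  return 0;
}
```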
| Status RdmaPoolAllocator::InitMemory(size_t mem_size, uint32_t device_id) { | |||||
| Status RdmaPoolAllocator::InitMemory(size_t mem_size) { | |||||
| auto device_id = GetContext().DeviceId(); | |||||
| GELOGD("Init Rdma Memory with size [%zu] for devid:[%u]", mem_size, device_id); | |||||
| if (rdma_base_addr_ != nullptr) { | if (rdma_base_addr_ != nullptr) { | ||||
| GELOGE(GE_MULTI_INIT, "Rdma pool has been malloced"); | GELOGE(GE_MULTI_INIT, "Rdma pool has been malloced"); | ||||
| return GE_MULTI_INIT; | return GE_MULTI_INIT; | ||||
| } | } | ||||
| const std::string purpose = "Memory for rdma pool."; | const std::string purpose = "Memory for rdma pool."; | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| auto dev_id = static_cast<int32_t>(device_id); | |||||
| GE_CHK_RT_RET(rtSetDevice(dev_id)); | |||||
| // The guard below resets the device when this function returns, on success and error paths alike. | |||||
| GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(dev_id)); }); | |||||
| rdma_base_addr_ = memory_allocator_->MallocMemory(purpose, mem_size, device_id); | rdma_base_addr_ = memory_allocator_->MallocMemory(purpose, mem_size, device_id); | ||||
| if (rdma_base_addr_ == nullptr) { | if (rdma_base_addr_ == nullptr) { | ||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "Rdma pool memory malloc failed"); | GELOGE(GE_GRAPH_MALLOC_FAILED, "Rdma pool memory malloc failed"); | ||||
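Editor's note: `InitMemory` now binds the calling thread to the target device with `rtSetDevice` and registers a `GE_MAKE_GUARD` so the device is reset on every exit path, including early error returns. The sketch below illustrates the scope-guard idea with a generic RAII helper; this `ScopeGuard` is a stand-in for illustration, not the GE macro.

```cpp
#include <functional>
#include <iostream>

// Generic scope guard in the spirit of GE_MAKE_GUARD: the callback runs when the
// guard leaves scope, on both success and early-return paths.
class ScopeGuard {
 public:
  explicit ScopeGuard(std::function<void()> on_exit) : on_exit_(std::move(on_exit)) {}
  ~ScopeGuard() {
    if (on_exit_) on_exit_();
  }
  ScopeGuard(const ScopeGuard &) = delete;
  ScopeGuard &operator=(const ScopeGuard &) = delete;

 private:
  std::function<void()> on_exit_;
};

bool InitWithDevice(int device_id, bool simulate_malloc_failure) {
  std::cout << "set device " << device_id << std::endl;  // stands in for rtSetDevice
  ScopeGuard reset_guard([device_id] { std::cout << "reset device " << device_id << std::endl; });
  if (simulate_malloc_failure) {
    return false;  // guard still runs: the device is reset even on the error path
  }
  std::cout << "pool allocated" << std::endl;
  return true;
}

int main() {
  InitWithDevice(0, false);
  InitWithDevice(0, true);
  return 0;
}
```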
| @@ -94,6 +108,7 @@ Status RdmaPoolAllocator::InitMemory(size_t mem_size, uint32_t device_id) { | |||||
| } | } | ||||
| uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { | uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { | ||||
| GELOGI("start to malloc rdma memory size:%zu, device id = %u", size, device_id); | |||||
| auto aligned_size = GetAlignedBlockSize(size); | auto aligned_size = GetAlignedBlockSize(size); | ||||
| Block key(device_id, aligned_size, nullptr); | Block key(device_id, aligned_size, nullptr); | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| @@ -107,9 +122,9 @@ uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| allocated_blocks_.emplace(block->ptr, block); | allocated_blocks_.emplace(block->ptr, block); | ||||
| GELOGI("Find block size = %zu", block->size); | |||||
| if (ShouldSplit(block, aligned_size)) { | if (ShouldSplit(block, aligned_size)) { | ||||
| GELOGD("Block will be splited block size = %zu, aligned_size:%zu", block->size, aligned_size); | |||||
| auto *new_block = | auto *new_block = | ||||
| new (std::nothrow) Block(device_id, block->size - aligned_size, nullptr, block->ptr + aligned_size); | new (std::nothrow) Block(device_id, block->size - aligned_size, nullptr, block->ptr + aligned_size); | ||||
| if (new_block == nullptr) { | if (new_block == nullptr) { | ||||
| @@ -126,12 +141,14 @@ uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { | |||||
| block_bin_.insert(new_block); | block_bin_.insert(new_block); | ||||
| } | } | ||||
| GELOGD("Find block size = %zu", block->size); | |||||
| return block->ptr; | return block->ptr; | ||||
| } | } | ||||
| GELOGW("Memory block not founded."); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
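Editor's note: in `Malloc`, the request is rounded up to the pool's alignment and, when the matched free block is noticeably larger than the aligned request, the block is split and the remainder returned to the bin. The helper bodies below are plausible sketches only; the real `GetAlignedBlockSize` and `ShouldSplit` definitions are not shown in this diff.

```cpp
#include <cstddef>
#include <iostream>

// Illustrative constant; the real value lives in rdma_pool_allocator.cc.
const size_t kAlignedSize = 512;

// Round a request up to the pool's alignment (assumed behaviour of GetAlignedBlockSize).
size_t GetAlignedBlockSize(size_t size) {
  return ((size + kAlignedSize - 1) / kAlignedSize) * kAlignedSize;
}

// A block is worth splitting when the leftover after the aligned request
// is still large enough to be useful (this threshold is an assumption).
bool ShouldSplit(size_t block_size, size_t aligned_size) {
  return block_size - aligned_size >= kAlignedSize;
}

int main() {
  size_t request = 1000;
  size_t aligned = GetAlignedBlockSize(request);  // 1024
  size_t free_block = 4096;
  std::cout << "aligned request: " << aligned << std::endl;
  if (ShouldSplit(free_block, aligned)) {
    std::cout << "split: keep " << aligned << ", return " << (free_block - aligned)
              << " bytes to the bin" << std::endl;
  }
  return 0;
}
```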
| Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { | Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { | ||||
| GELOGI("Free device id = %u", device_id); | |||||
| GELOGI("Free rdma memory, device id = %u", device_id); | |||||
| if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
| GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); | GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); | ||||
| return GE_GRAPH_FREE_FAILED; | return GE_GRAPH_FREE_FAILED; | ||||
| @@ -143,27 +160,41 @@ Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { | |||||
| GELOGE(PARAM_INVALID, "Invalid memory pointer"); | GELOGE(PARAM_INVALID, "Invalid memory pointer"); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| Block *block = it->second; | Block *block = it->second; | ||||
| block->allocated = false; | block->allocated = false; | ||||
| allocated_blocks_.erase(it); | allocated_blocks_.erase(it); | ||||
| Block *merge_blocks[] = {block->prev, block->next}; | |||||
| for (Block *merge_block : merge_blocks) { | |||||
| MergeBlocks(block, merge_block); | |||||
| } | |||||
| block_bin_.insert(block); | block_bin_.insert(block); | ||||
| // Each time merge with its pre and next. | |||||
| MergeBlockNearby(block, block->next); | |||||
| MergeBlockNearby(block->prev, block); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void RdmaPoolAllocator::MergeBlockNearby(Block *pre_block, Block *block) { | |||||
| if (!(CanMerge(pre_block) && CanMerge(block))) { | |||||
| void RdmaPoolAllocator::MergeBlocks(Block *dst, Block *src) { | |||||
| if (!CanMerge(dst) || !CanMerge(src)) { | |||||
| return; | return; | ||||
| } | } | ||||
| pre_block->size += block->size; | |||||
| pre_block->next = block->next; | |||||
| if (block->next != nullptr) { | |||||
| block->next->prev = pre_block; | |||||
| if (dst->prev == src) { | |||||
| dst->ptr = src->ptr; | |||||
| dst->prev = src->prev; | |||||
| if (dst->prev != nullptr) { | |||||
| dst->prev->next = dst; | |||||
| } | |||||
| } else { | |||||
| dst->next = src->next; | |||||
| if (dst->next != nullptr) { | |||||
| dst->next->prev = dst; | |||||
| } | |||||
| } | } | ||||
| block_bin_.erase(block); | |||||
| delete block; | |||||
| dst->size += src->size; | |||||
| block_bin_.erase(src); | |||||
| delete src; | |||||
| } | } | ||||
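Editor's note: `MergeBlocks(dst, src)` coalesces a freed block with a free neighbour on either side: absorbing the previous block moves the start address back, absorbing the next block extends the end forward, and the absorbed block is removed from the bin and deleted. Below is a self-contained sketch of the same coalescing on a toy doubly linked block list; bin bookkeeping is omitted and all names are illustrative.

```cpp
#include <cstddef>
#include <iostream>

// Minimal stand-in for the allocator's Block: address-ordered blocks linked
// front-to-back, coalesced when a neighbour is free.
struct DemoBlock {
  size_t offset;
  size_t size;
  bool allocated = false;
  DemoBlock *prev = nullptr;
  DemoBlock *next = nullptr;
};

bool CanMerge(DemoBlock *b) { return b != nullptr && !b->allocated; }

// Same shape as MergeBlocks(dst, src): dst absorbs a free neighbour on either side.
void MergeDemoBlocks(DemoBlock *dst, DemoBlock *src) {
  if (!CanMerge(dst) || !CanMerge(src)) return;
  if (dst->prev == src) {
    dst->offset = src->offset;  // absorbing the block in front: start moves back
    dst->prev = src->prev;
    if (dst->prev != nullptr) dst->prev->next = dst;
  } else {
    dst->next = src->next;      // absorbing the block behind: end moves forward
    if (dst->next != nullptr) dst->next->prev = dst;
  }
  dst->size += src->size;
  delete src;                   // the real code also erases src from block_bin_
}

int main() {
  auto *a = new DemoBlock{0, 512};
  auto *b = new DemoBlock{512, 1024};
  auto *c = new DemoBlock{1536, 512};
  a->next = b; b->prev = a; b->next = c; c->prev = b;

  // Freeing b: merge with both neighbours, as RdmaPoolAllocator::Free does.
  MergeDemoBlocks(b, b->prev);  // b absorbs a
  MergeDemoBlocks(b, b->next);  // b absorbs c
  std::cout << "offset " << b->offset << ", size " << b->size << std::endl;  // offset 0, size 2048
  delete b;
  return 0;
}
```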
| Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | ||||
| @@ -40,12 +40,12 @@ class RdmaPoolAllocator { | |||||
| RdmaPoolAllocator &operator=(const RdmaPoolAllocator &) = delete; | RdmaPoolAllocator &operator=(const RdmaPoolAllocator &) = delete; | ||||
| ~RdmaPoolAllocator() { Finalize(); } | |||||
| ~RdmaPoolAllocator() = default; | |||||
| Status Initialize(); | Status Initialize(); | ||||
| void Finalize(); | void Finalize(); | ||||
| Status InitMemory(size_t mem_size, uint32_t device_id = 0); | |||||
| Status InitMemory(size_t mem_size); | |||||
| uint8_t *Malloc(size_t size, uint32_t device_id = 0); | uint8_t *Malloc(size_t size, uint32_t device_id = 0); | ||||
| @@ -54,7 +54,7 @@ class RdmaPoolAllocator { | |||||
| Status GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size); | Status GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size); | ||||
| private: | private: | ||||
| void MergeBlockNearby(Block *pre_block, Block *block); | |||||
| void MergeBlocks(Block *dst, Block *src); | |||||
| rtMemType_t memory_type_; | rtMemType_t memory_type_; | ||||
| size_t rdma_mem_size_ = 0; // Total rdma memory size to be allocated. | size_t rdma_mem_size_ = 0; // Total rdma memory size to be allocated. | ||||