| @@ -0,0 +1,51 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_ | |||||
| #define INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_ | |||||
| #include "external/ge/ge_api_error_codes.h" | |||||
| #include "cce/aicpu_engine_struct.h" | |||||
| #include "common/opskernel/ops_kernel_info_types.h" | |||||
| #include "graph/node.h" | |||||
| #include "proto/task.pb.h" | |||||
| namespace ge { | |||||
| class OpsKernelBuilder { | |||||
| public: | |||||
| OpsKernelBuilder() = default; | |||||
| virtual ~OpsKernelBuilder() = default; | |||||
| // initialize OpsKernelBuilder | |||||
| virtual Status Initialize(const std::map<std::string, std::string> &options) = 0; | |||||
| // finalize OpsKernelBuilder | |||||
| virtual Status Finalize() = 0; | |||||
| // memory allocation requirement | |||||
| virtual Status CalcOpRunningParam(Node &node) = 0; | |||||
| // generate task for op | |||||
| virtual Status GenerateTask(const Node &node, RunContext &context, std::vector<domi::TaskDef> &tasks) = 0; | |||||
| // only call aicpu interface to generate task struct | |||||
| virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return FAILED; } | |||||
| // only call aicpu interface to generate task struct | |||||
| virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return FAILED; } | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_COMMON_OPSKERNELUTILS_OPS_KERNEL_INFO_UTILS_H_ | |||||
| @@ -43,10 +43,10 @@ class OpsKernelInfoStore { | |||||
| virtual ~OpsKernelInfoStore() {} | virtual ~OpsKernelInfoStore() {} | ||||
| // initialize opsKernelInfoStore | // initialize opsKernelInfoStore | ||||
| virtual Status Initialize(const map<string, string> &options) = 0; /*lint -e148*/ | |||||
| virtual Status Initialize(const map<string, string> &options) = 0; | |||||
| // close opsKernelInfoStore | // close opsKernelInfoStore | ||||
| virtual Status Finalize() = 0; /*lint -e148*/ | |||||
| virtual Status Finalize() = 0; | |||||
| virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | virtual Status CreateSession(const std::map<std::string, std::string> &session_options) { return SUCCESS; } | ||||
| @@ -65,24 +65,11 @@ class OpsKernelInfoStore { | |||||
| // opsFlag opsFlag[0] indicates constant folding is supported or not | // opsFlag opsFlag[0] indicates constant folding is supported or not | ||||
| virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | virtual void opsFlagCheck(const ge::Node &node, std::string &opsFlag){}; | ||||
| // memory allocation requirement | |||||
| virtual Status CalcOpRunningParam(Node &node) = 0; /*lint -e148*/ | |||||
| // generate task for op。 | |||||
| virtual Status GenerateTask(const Node &node, RunContext &context, | |||||
| std::vector<domi::TaskDef> &tasks) = 0; /*lint -e148*/ | |||||
| // only call fe engine interface to compile single op | // only call fe engine interface to compile single op | ||||
| virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | ||||
| virtual Status CompileOpRun(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | virtual Status CompileOpRun(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | ||||
| // load task for op | // load task for op | ||||
| virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | ||||
| // only call aicpu interface to generate task struct | |||||
| virtual Status GenSingleOpRunTask(const NodePtr &node, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | |||||
| // only call aicpu interface to generate task struct | |||||
| virtual Status GenMemCopyTask(uint64_t count, STR_FWK_OP_KERNEL &task, string &task_info) { return SUCCESS; } | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ | #endif // INC_COMMON_OPSKERNEL_OPS_KERNEL_INFO_STORE_H_ | ||||
| @@ -26,13 +26,14 @@ | |||||
| using std::string; | using std::string; | ||||
| namespace ge { | namespace ge { | ||||
| /*lint -e148*/ | |||||
| struct RunContext { | struct RunContext { | ||||
| rtModel_t model; | rtModel_t model; | ||||
| rtStream_t stream; | rtStream_t stream; | ||||
| uint64_t sessionId; | uint64_t sessionId; | ||||
| uint64_t dataMemSize; | uint64_t dataMemSize; | ||||
| uint8_t *dataMemBase; | uint8_t *dataMemBase; | ||||
| std::map<int64_t, uint64_t> mem_type_data_mem_size; | |||||
| std::map<int64_t, uint8_t *> mem_type_data_mem_base; | |||||
| uint64_t weightMemSize; | uint64_t weightMemSize; | ||||
| uint8_t *weightMemBase; | uint8_t *weightMemBase; | ||||
| ge::Buffer weightsBuffer; | ge::Buffer weightsBuffer; | ||||
| @@ -41,8 +42,6 @@ struct RunContext { | |||||
| std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | std::vector<rtLabel_t> graphLabelList; // all labels of graph, order by ge label id(0,1,...) | ||||
| }; | }; | ||||
| /*lint +e148*/ | |||||
| struct Task { | struct Task { | ||||
| uint32_t id; | uint32_t id; | ||||
| uint16_t type; | uint16_t type; | ||||
| @@ -51,8 +50,7 @@ struct Task { | |||||
| }; | }; | ||||
| struct OpInfo { | struct OpInfo { | ||||
| string engine; // which engin | |||||
| /*lint -e148*/ | |||||
| string engine; // which engin | |||||
| string opKernelLib; // which opsKernelStore | string opKernelLib; // which opsKernelStore | ||||
| int computeCost; // compute cost | int computeCost; // compute cost | ||||
| bool flagPartial; // whether to support is related to shape | bool flagPartial; // whether to support is related to shape | ||||
| @@ -27,7 +27,6 @@ | |||||
| using std::map; | using std::map; | ||||
| using std::string; | using std::string; | ||||
| /*lint -e148*/ | |||||
| namespace ge { | namespace ge { | ||||
| class GraphOptimizer { | class GraphOptimizer { | ||||
| public: | public: | ||||
| @@ -67,5 +66,4 @@ class GraphOptimizer { | |||||
| virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; } | virtual Status OptimizeFusedGraphAfterGraphSlice(ComputeGraph &graph) { return SUCCESS; } | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| /*lint +e148*/ | |||||
| #endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ | #endif // INC_COMMON_OPTIMIZER_GRAPH_OPTIMIZER_H_ | ||||
| @@ -0,0 +1,48 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef AICORE_UTIL_MANAGER_H_ | |||||
| #define AICORE_UTIL_MANAGER_H_ | |||||
| #include <string> | |||||
| #include "register/graph_optimizer/graph_optimize_register_error_codes.h" | |||||
| namespace fe { | |||||
| class AICoreUtilManager { | |||||
| public: | |||||
| static AICoreUtilManager &Instance(); | |||||
| /* | |||||
| * to initialize the aicore configuration | |||||
| * param[in] the options of init | |||||
| * param[in] engine Name | |||||
| * param[in] socVersion soc version from ge | |||||
| * return Status(SUCCESS/FAILED) | |||||
| */ | |||||
| Status Initialize(const std::map<std::string, std::string> &options, std::string &soc_version); | |||||
| /* | |||||
| * to release the source of fusion manager | |||||
| * return Status(SUCCESS/FAILED) | |||||
| */ | |||||
| Status Finalize(); | |||||
| private: | |||||
| AICoreUtilManager(); | |||||
| ~AICoreUtilManager(); | |||||
| bool is_init_; | |||||
| }; | |||||
| } // namespace fe | |||||
| #endif // AICORE_UTIL_MANAGER_H | |||||
| @@ -36,6 +36,14 @@ static const std::string L1_OPTIMIZED = "l1_optimized"; | |||||
| static const std::string L2_OPTIMIZED = "l2_optimized"; | static const std::string L2_OPTIMIZED = "l2_optimized"; | ||||
| static const std::string OP_SLICE_INFO = "_op_slice_info"; | |||||
| static const std::string ATTR_NAME_UNKNOWN_SHAPE = "_unknown_shape"; | |||||
| static const std::string ATTR_NAME_IS_UNKNOWN_GRAPH = "_fe_is_unknown_graph"; | |||||
| static const std::string ATTR_NAME_IS_UNKNOWN_SHAPE_OP = "_fe_is_unknown_shape_op"; | |||||
| static const std::string ATTR_NAME_TVM_CACHE_READ_MODE = "tvm_cache_read_mode"; | |||||
| static const std::string ATTR_NAME_TBE_KERNEL_SIZE = "_tbeKernelSize"; | |||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_COMMON_UTILS_AI_CORE_COMMON_CONSTANTS_H_ | |||||
| #define INC_COMMON_UTILS_AI_CORE_COMMON_CONSTANTS_H_ | |||||
| #include <string> | |||||
| namespace fe { | |||||
| static const std::string CORE_TYPE = "_coretype"; | |||||
| /* engine name of AI core and vector core */ | |||||
| static const std::string AI_CORE_NAME = "AIcoreEngine"; | |||||
| static const std::string VECTOR_CORE_NAME = "VectorEngine"; | |||||
| static const int64_t IS_UNKNOWN_SHAPE_VALUE = 1; | |||||
| static const int64_t SHAPE_UNKNOWN_DIM = -1; | |||||
| static const int64_t SHAPE_UNKNOWN_DIM_NUM = -2; | |||||
| static const std::string SOC_VERSION_ASCEND310 = "Ascend310"; | |||||
| static const std::string SOC_VERSION_ASCEND610 = "Ascend610"; | |||||
| static const std::string SOC_VERSION_ASCEND615 = "Ascend615"; | |||||
| static const std::string SOC_VERSION_ASCEND710 = "Ascend710"; | |||||
| static const std::string SOC_VERSION_ASCEND710P = "Ascend710Pro"; | |||||
| static const std::string SOC_VERSION_ASCEND910A = "Ascend910A"; | |||||
| static const std::string SOC_VERSION_ASCEND910B = "Ascend910B"; | |||||
| static const std::string SOC_VERSION_ASCEND910PROA = "Ascend910ProA"; | |||||
| static const std::string SOC_VERSION_ASCEND910PROB = "Ascend910ProB"; | |||||
| static const std::string SOC_VERSION_ASCEND910PREMIUMA = "Ascend910PremiumA"; | |||||
| static const std::string SOC_VERSION_HI3796CV300ES = "Hi3796CV300ES"; | |||||
| static const std::string SOC_VERSION_HI3796CV300CS = "Hi3796CV300CS"; | |||||
| static const std::vector<std::string> SOC_VERSION_CLOUD_LIST = {SOC_VERSION_ASCEND910A, SOC_VERSION_ASCEND910B, | |||||
| SOC_VERSION_ASCEND910PROA, SOC_VERSION_ASCEND910PROB, | |||||
| SOC_VERSION_ASCEND910PREMIUMA}; | |||||
| static const std::vector<std::string> SOC_VERSION_DC_LIST = {SOC_VERSION_ASCEND610, SOC_VERSION_ASCEND615, | |||||
| SOC_VERSION_ASCEND710, SOC_VERSION_ASCEND710P}; | |||||
| } // namespace fe | |||||
| #endif | |||||
| @@ -42,47 +42,61 @@ struct FusionDataFlow { | |||||
| std::pair<std::string, ge::AnchorPtr> node_dataindex_pair; | std::pair<std::string, ge::AnchorPtr> node_dataindex_pair; | ||||
| }; | }; | ||||
| typedef struct tagL2FusionData { | |||||
| typedef struct tag_l2_fusion_data { | |||||
| uint32_t l2Index; | uint32_t l2Index; | ||||
| uint64_t l2Addr; | uint64_t l2Addr; | ||||
| uint64_t l2PageNum; | uint64_t l2PageNum; | ||||
| } L2FusionData_t; | } L2FusionData_t; | ||||
| typedef std::map<uint64_t, L2FusionData_t> L2FusionDataMap_t; | typedef std::map<uint64_t, L2FusionData_t> L2FusionDataMap_t; | ||||
| typedef struct tagFeSmDesc { | |||||
| typedef struct tag_fe_sm_desc { | |||||
| rtL2Ctrl_t l2ctrl; | rtL2Ctrl_t l2ctrl; | ||||
| std::string nodeName[8]; | |||||
| uint8_t outputIndex[8]; | |||||
| } feSmDesc_t; | |||||
| std::string node_name[8]; | |||||
| uint8_t output_index[8]; | |||||
| } fe_sm_desc_t; | |||||
| typedef struct TagTaskL2FusionInfo { | typedef struct TagTaskL2FusionInfo { | ||||
| std::string nodeName; | |||||
| feSmDesc_t l2Info; | |||||
| std::string node_name; | |||||
| fe_sm_desc_t l2_info; | |||||
| L2FusionDataMap_t input; | L2FusionDataMap_t input; | ||||
| L2FusionDataMap_t output; | L2FusionDataMap_t output; | ||||
| uint32_t isUsed; | |||||
| uint32_t is_used; | |||||
| } TaskL2FusionInfo_t; | } TaskL2FusionInfo_t; | ||||
| using L2FusionInfoPtr = std::shared_ptr<TaskL2FusionInfo_t>; | using L2FusionInfoPtr = std::shared_ptr<TaskL2FusionInfo_t>; | ||||
| typedef struct ToOpStruct { | typedef struct ToOpStruct { | ||||
| int64_t opL1Space = 0; | |||||
| std::vector<int64_t> opL1FusionType; | |||||
| int64_t opL1WorkspaceFlag = 0; // for workspace flag | |||||
| int64_t opL1WorkspaceSize = 0; | |||||
| std::vector<std::vector<int64_t>> validInputShape; | |||||
| std::vector<std::vector<int64_t>> validOutputShape; | |||||
| std::vector<std::vector<int64_t>> sliceInputOffset; // conv & pooling & ReadSelect | |||||
| std::vector<std::vector<int64_t>> sliceOutputOffset; // WriteSelect | |||||
| std::vector<uint32_t> totalShape; | |||||
| uint32_t splitIndex = 0; | |||||
| int64_t op_l1_space = 0; | |||||
| std::vector<int64_t> op_l1_fusion_type; | |||||
| int64_t op_l1_workspace_flag = 0; // for workspace flag | |||||
| int64_t op_l1_workspace_size = 0; | |||||
| std::vector<std::vector<int64_t>> valid_input_shape; | |||||
| std::vector<std::vector<int64_t>> valid_output_shape; | |||||
| std::vector<std::vector<int64_t>> slice_input_offset; // conv & pooling & ReadSelect | |||||
| std::vector<std::vector<int64_t>> slice_output_offset; // WriteSelect | |||||
| std::vector<uint32_t> total_shape; | |||||
| uint32_t split_index = 0; | |||||
| ToOpStruct() { | ToOpStruct() { | ||||
| // set invalid value for essential variable | // set invalid value for essential variable | ||||
| opL1Space = -1; | |||||
| opL1WorkspaceSize = -1; | |||||
| op_l1_space = -1; | |||||
| op_l1_workspace_size = -1; | |||||
| } | } | ||||
| } ToOpStruct_t; | } ToOpStruct_t; | ||||
| enum SlicePattern { | |||||
| ELEMENT_WISE = 0, | |||||
| ELEMENT_WISE_BROADCAST, | |||||
| BROADCAST, | |||||
| SLIDING_WINDOW, | |||||
| SLIDING_WINDOW_DECONV, | |||||
| CUBE_MATMUL, | |||||
| SLICE_PATTERN_REDUCE, | |||||
| SLICE_PATTERN_RESIZE, | |||||
| SLICE_PATTERN_SCATTER, | |||||
| SLICE_PATTERN_SEGMENT, | |||||
| PATTERN_RESERVED | |||||
| }; | |||||
| enum OpImplType { | enum OpImplType { | ||||
| EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op | EN_IMPL_CUSTOM_CONSTANT_CCE = 0, // custom constant op | ||||
| EN_IMPL_CUSTOM_TIK, // custom tik op | EN_IMPL_CUSTOM_TIK, // custom tik op | ||||
| @@ -99,6 +113,10 @@ enum OpImplType { | |||||
| EN_RESERVED // reserved value | EN_RESERVED // reserved value | ||||
| }; | }; | ||||
| // Dont change the order, only add new mode in the end | |||||
| enum L2Mode { EN_L2_CLOSE = 0, EN_L2_BUFFER_OPTIMIZE, EN_L2_CACHE_NORMAL, EN_L2_CACHE_RC }; | |||||
| enum BufferFusionMode { EN_OPTIMIZE_DISABLE = 0, EN_L2_BUFFER, EN_L2_FUSION }; | |||||
| static const std::map<ge::DataType, uint32_t> DATATYPE_SIZE_MAP{{ge::DT_FLOAT, sizeof(float)}, | static const std::map<ge::DataType, uint32_t> DATATYPE_SIZE_MAP{{ge::DT_FLOAT, sizeof(float)}, | ||||
| {ge::DT_FLOAT16, sizeof(int16_t)}, | {ge::DT_FLOAT16, sizeof(int16_t)}, | ||||
| {ge::DT_INT8, sizeof(int8_t)}, | {ge::DT_INT8, sizeof(int8_t)}, | ||||
| @@ -114,5 +132,13 @@ static const std::map<ge::DataType, uint32_t> DATATYPE_SIZE_MAP{{ge::DT_FLOAT, s | |||||
| {ge::DT_DUAL, sizeof(float) + sizeof(int8_t)}, | {ge::DT_DUAL, sizeof(float) + sizeof(int8_t)}, | ||||
| {ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)}, | {ge::DT_DUAL_SUB_UINT8, sizeof(int8_t)}, | ||||
| {ge::DT_DUAL_SUB_INT8, sizeof(int8_t)}}; | {ge::DT_DUAL_SUB_INT8, sizeof(int8_t)}}; | ||||
| enum OpReduceType { | |||||
| REDUCE_MEAN = 0, | |||||
| REDUCE_ADD, | |||||
| REDUCE_MAX, | |||||
| REDUCE_MIN, | |||||
| }; | |||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -28,33 +28,34 @@ | |||||
| namespace fe { | namespace fe { | ||||
| using kScopeNodeMap_t = std::map<int64_t, std::vector<ge::NodePtr>>; | |||||
| using kScopeNodePair_t = std::pair<int64_t, std::vector<ge::NodePtr>>; | |||||
| using k_scope_node_map_t = std::map<int64_t, std::vector<ge::NodePtr>>; | |||||
| using k_scope_node_pair_t = std::pair<int64_t, std::vector<ge::NodePtr>>; | |||||
| class GraphCommImpl; | class GraphCommImpl; | ||||
| using GraphCommImplPtr = std::unique_ptr<GraphCommImpl>; | using GraphCommImplPtr = std::unique_ptr<GraphCommImpl>; | ||||
| class GraphComm { | class GraphComm { | ||||
| public: | public: | ||||
| GraphComm(const string &engineName); | |||||
| GraphComm(const string &engine_name); | |||||
| virtual ~GraphComm(); | virtual ~GraphComm(); | ||||
| GraphComm(const GraphComm &in) = delete; | GraphComm(const GraphComm &in) = delete; | ||||
| GraphComm &operator=(const GraphComm &in) = delete; | GraphComm &operator=(const GraphComm &in) = delete; | ||||
| Status GetscopeNodeMap(ge::ComputeGraph &graph, kScopeNodeMap_t &fusionMap); | |||||
| Status GetscopeNodeMap(ge::ComputeGraph &graph, k_scope_node_map_t &fusion_map); | |||||
| Status CopyFusionOpNodes(vector<FusionDataFlow> &fusInputEdgeList, vector<FusionDataFlow> &fusOutputEdgeList, | |||||
| vector<ge::NodePtr> &fusNodelist, ge::OpDescPtr fusionOpDesc, | |||||
| ge::ComputeGraphPtr fusionGraph); | |||||
| Status CopyFusionOpNodes(vector<FusionDataFlow> &fus_input_edge_list, vector<FusionDataFlow> &fus_output_edge_list, | |||||
| vector<ge::NodePtr> &fus_nodelist, ge::OpDescPtr fusion_op_desc, | |||||
| ge::ComputeGraphPtr fusion_graph); | |||||
| Status CopyFusionOpEdges(ge::OpDescPtr fusionOpDesc, ge::ComputeGraph &origGraph, ge::ComputeGraphPtr fusionGraph); | |||||
| Status CopyFusionOpEdges(ge::OpDescPtr fusion_op_desc, ge::ComputeGraph &orig_graph, | |||||
| ge::ComputeGraphPtr fusion_graph); | |||||
| Status GetNodeDataFlowMap(const ge::NodePtr &fusNode, | |||||
| std::map<ge::NodePtr, std::map<ge::AnchorPtr, ge::AnchorPtr>> &fusionOpAnchorsMap, | |||||
| ge::kFusionDataFlowVec_t &fusDataflowList, const int &mapType); | |||||
| Status GetNodeDataFlowMap(const ge::NodePtr &fus_node, | |||||
| std::map<ge::NodePtr, std::map<ge::AnchorPtr, ge::AnchorPtr>> &fusion_op_anchors_map, | |||||
| ge::kFusionDataFlowVec_t &fus_dataflow_list, const int &map_type); | |||||
| Status GetFusionNodeEdgeList(std::vector<ge::NodePtr> &fusNodelist, std::vector<FusionDataFlow> &fusInputEdgeList, | |||||
| std::vector<FusionDataFlow> &fusOutputEdgeList); | |||||
| Status GetFusionNodeEdgeList(std::vector<ge::NodePtr> &fus_nodelist, std::vector<FusionDataFlow> &fus_input_edge_list, | |||||
| std::vector<FusionDataFlow> &fus_output_edge_list); | |||||
| void ClearFusionSrc(); | void ClearFusionSrc(); | ||||
| void ClearFusionDst(); | void ClearFusionDst(); | ||||
| @@ -72,25 +73,26 @@ class GraphComm { | |||||
| bool GetFusionSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, int32_t &fusion_src_index, | bool GetFusionSrc(const uint32_t &src_op_id, const ge::AnchorPtr &src_anchor, int32_t &fusion_src_index, | ||||
| int32_t &fusion_dst_index); | int32_t &fusion_dst_index); | ||||
| Status GetFusionNodeCtrlEdgeList(vector<ge::NodePtr> &fusNodelist, vector<FusionDataFlow> &fusInputCtrlEdgeList, | |||||
| vector<FusionDataFlow> &fusOutputCtrlEdgeList); | |||||
| Status GetFusionNodeCtrlEdgeList(vector<ge::NodePtr> &fus_nodelist, vector<FusionDataFlow> &fus_input_ctrl_edge_list, | |||||
| vector<FusionDataFlow> &fus_output_ctrl_edge_list); | |||||
| Status MergeFusionNodeEdgeList(ge::NodePtr &fusNode, vector<ge::NodePtr> &fusNodelist, | |||||
| vector<FusionDataFlow> &fusInputEdgeList, vector<FusionDataFlow> &fusOutputEdgeList); | |||||
| Status MergeFusionNodeEdgeList(ge::NodePtr &fus_node, vector<ge::NodePtr> &fus_nodelist, | |||||
| vector<FusionDataFlow> &fus_input_edge_list, | |||||
| vector<FusionDataFlow> &fus_output_edge_list); | |||||
| Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fusNode, vector<ge::NodePtr> &fusNodelist, | |||||
| vector<FusionDataFlow> &fusInputEdgeList, | |||||
| vector<FusionDataFlow> &fusOutputEdgeList); | |||||
| Status MergeFusionNodeCtrlEdgeList(ge::NodePtr &fus_node, vector<ge::NodePtr> &fus_nodelist, | |||||
| vector<FusionDataFlow> &fus_input_edge_list, | |||||
| vector<FusionDataFlow> &fus_output_edge_list); | |||||
| string GetEngineName(); | string GetEngineName(); | ||||
| private: | private: | ||||
| Status MergeFusionNodeInputEdgeList(ge::NodePtr fusNode, std::vector<ge::NodePtr> &fusNodelist, | |||||
| std::vector<FusionDataFlow> &fusInputEdgeList); | |||||
| Status MergeFusionNodeOutputEdgeList(ge::NodePtr fusNode, std::vector<ge::NodePtr> &fusNodelist, | |||||
| std::vector<FusionDataFlow> &fusOutputEdgeList); | |||||
| Status MergeFusionNodeInputEdgeList(ge::NodePtr fus_node, std::vector<ge::NodePtr> &fus_nodelist, | |||||
| std::vector<FusionDataFlow> &fus_input_edge_list); | |||||
| Status MergeFusionNodeOutputEdgeList(ge::NodePtr fus_node, std::vector<ge::NodePtr> &fus_nodelist, | |||||
| std::vector<FusionDataFlow> &fus_output_edge_list); | |||||
| string engineName_; | |||||
| string engine_name_; | |||||
| std::vector<FusionOpSrc> exist_fusion_src_list_; | std::vector<FusionOpSrc> exist_fusion_src_list_; | ||||
| std::vector<FusionOpDst> exist_fusion_dst_list_; | std::vector<FusionOpDst> exist_fusion_dst_list_; | ||||
| @@ -101,7 +103,7 @@ class GraphComm { | |||||
| // std::vector<std::multimap<std::string, ge::AnchorPtr>> | // std::vector<std::multimap<std::string, ge::AnchorPtr>> | ||||
| ge::kFusionDataFlowVec_t fusion_output_dataflow_list_; | ge::kFusionDataFlowVec_t fusion_output_dataflow_list_; | ||||
| GraphCommImplPtr graphCommImplPtr_; | |||||
| GraphCommImplPtr graph_comm_impl_ptr_; | |||||
| }; | }; | ||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,54 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef PROJECT_JSON_UTIL_H | |||||
| #define PROJECT_JSON_UTIL_H | |||||
| #include "graph/compute_graph.h" | |||||
| #include "common/aicore_util_types.h" | |||||
| #include "fusion_engine/graph_tuner/graph_tuner_errorcode.h" | |||||
| const std::string L1_FUSION_EXTEND_CONTENT = "_l1_fusion_extend_content"; | |||||
| const std::string L2_FUSION_EXTEND_CONTENT = "l2_fusion_extend_content"; | |||||
| const std::string TASK_L2_FUSION_INFO_EXTEND_CONTENT = "task_l2_fusion_info_extend_content"; | |||||
| const std::string L1_FUSION_TO_OP_STRUCT = "_l1fusion_ToOpStruct"; | |||||
| const std::string L2_FUSION_TO_OP_STRUCT = "_l2fusion_ToOpStruct"; | |||||
| const std::string TASK_L2_FUSION_INFO = "_task_L2FusionInfo"; | |||||
| namespace tune { | |||||
| using ToOpStructPtr = std::shared_ptr<fe::ToOpStruct_t>; | |||||
| using L2FusionInfoPtr = std::shared_ptr<fe::TaskL2FusionInfo_t>; | |||||
| Status GetL1InfoFromJson(ge::OpDescPtr opDescPtr); | |||||
| Status GetL2InfoFromJson(ge::OpDescPtr opDescPtr); | |||||
| Status GetTaskL2FusionInfoFromJson(ge::OpDescPtr opDescPtr); | |||||
| Status ReadGraphInfoFromJson(ge::ComputeGraph &graph); | |||||
| Status WriteGraphInfoToJson(ge::ComputeGraph &graph); | |||||
| void GetL2ToOpStructFromJson(ge::OpDescPtr &opDescPtr, ToOpStructPtr &l2InfoPtr); | |||||
| void GetL1ToOpStructFromJson(ge::OpDescPtr &opDescPtr, ToOpStructPtr &l1InfoPtr); | |||||
| L2FusionInfoPtr GetL2FusionInfoFromJson(ge::OpDescPtr &opDescPtr); | |||||
| void SetL2FusionInfoToNode(ge::OpDescPtr &opDescPtr, L2FusionInfoPtr &l2FusionInfoPtr); | |||||
| } // namespace tune | |||||
| #endif // PROJECT_JSON_UTIL_H | |||||
| @@ -0,0 +1,44 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef L2_STREAM_INFO_H_ | |||||
| #define L2_STREAM_INFO_H_ | |||||
| #include <map> | |||||
| #include <string> | |||||
| #include <mutex> | |||||
| #include "register/graph_optimizer/graph_optimize_register_error_codes.h" | |||||
| #include "runtime/base.h" | |||||
| #include "cce/l2fusion_struct.hpp" | |||||
| namespace fe { | |||||
| class StreamL2Info { | |||||
| public: | |||||
| StreamL2Info(const StreamL2Info &) = delete; | |||||
| StreamL2Info &operator=(const StreamL2Info &) = delete; | |||||
| static StreamL2Info &Instance(); | |||||
| Status GetStreamL2Info(rtStream_t stream_id, string node_name, fusion::TaskL2Info_t *&l2_data); | |||||
| Status SetStreamL2Info(const rtStream_t &stream_id, fusion::TaskL2InfoFEMap_t &l2_alloc_res); | |||||
| private: | |||||
| StreamL2Info(); | |||||
| ~StreamL2Info(); | |||||
| mutable std::mutex stream_l2_mutex_; | |||||
| std::map<rtStream_t, fusion::TaskL2InfoFEMap_t> stream_l2_map_; | |||||
| }; | |||||
| } // namespace fe | |||||
| #endif // L2_STREAM_INFO_H_ | |||||
| @@ -32,12 +32,12 @@ class ScopeAllocator { | |||||
| int64_t GetCurrentScopeId(); | int64_t GetCurrentScopeId(); | ||||
| int64_t AllocateScopeId(void); | int64_t AllocateScopeId(void); | ||||
| bool HasScopeAttr(ge::ConstOpDescPtr opdef); | bool HasScopeAttr(ge::ConstOpDescPtr opdef); | ||||
| bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scopeId); | |||||
| bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scopeId); | |||||
| bool ResetScopeId(int64_t scopeId); | |||||
| bool GetScopeAttr(ge::ConstOpDescPtr opdef, int64_t& scope_id); | |||||
| bool SetScopeAttr(ge::OpDescPtr opdef, int64_t scope_id); | |||||
| bool ResetScopeId(int64_t scope_id); | |||||
| private: | private: | ||||
| int64_t scopeId; | |||||
| int64_t scope_id; | |||||
| }; | }; | ||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -29,16 +29,16 @@ class TensorSizeCalculator { | |||||
| public: | public: | ||||
| /** | /** | ||||
| * Calculate the tensor size of input and output of each opdesc | * Calculate the tensor size of input and output of each opdesc | ||||
| * @param opDesc opdesc object | |||||
| * @param opImplType op impl type | |||||
| * @param op_desc opdesc object | |||||
| * @param op_impl_type op impl type | |||||
| * @return status SUCCESS or FAILED | * @return status SUCCESS or FAILED | ||||
| */ | */ | ||||
| static Status CalculateOpTensorSize(ge::OpDesc &opDesc); | |||||
| static Status CalculateOpTensorSize(ge::OpDesc &op_desc); | |||||
| private: | private: | ||||
| static Status CalcInputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); | |||||
| static Status CalcInputOpTensorSize(ge::OpDesc &op_desc, int32_t &output_real_calc_flag); | |||||
| static Status CalcOutputOpTensorSize(ge::OpDesc &opDesc, int32_t &outputRealCalcFlag); | |||||
| static Status CalcOutputOpTensorSize(ge::OpDesc &op_desc, int32_t &output_real_calc_flag); | |||||
| }; | }; | ||||
| } // namespace fe | } // namespace fe | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <mutex> | |||||
| class ErrorManager { | class ErrorManager { | ||||
| public: | public: | ||||
| @@ -86,6 +87,7 @@ class ErrorManager { | |||||
| int ReadJsonFile(const std::string &file_path, void *handle); | int ReadJsonFile(const std::string &file_path, void *handle); | ||||
| bool is_init_ = false; | bool is_init_ = false; | ||||
| std::mutex mutex_; | |||||
| std::map<std::string, ErrorInfo> error_map_; | std::map<std::string, ErrorInfo> error_map_; | ||||
| std::vector<std::string> error_messages_; | std::vector<std::string> error_messages_; | ||||
| std::vector<std::string> warning_messages_; | std::vector<std::string> warning_messages_; | ||||
| @@ -36,66 +36,66 @@ class PlatformInfoManager { | |||||
| uint32_t InitializePlatformInfo(); | uint32_t InitializePlatformInfo(); | ||||
| uint32_t Finalize(); | uint32_t Finalize(); | ||||
| uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||||
| uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platform_info, OptionalInfo &opti_compilation_info); | |||||
| uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||||
| uint32_t GetPlatformInfoWithOutSocVersion(PlatformInfo &platform_info, OptionalInfo &opti_compilation_info); | |||||
| void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); | |||||
| void SetOptionalCompilationInfo(OptionalInfo &opti_compilation_info); | |||||
| private: | private: | ||||
| PlatformInfoManager(); | PlatformInfoManager(); | ||||
| ~PlatformInfoManager(); | ~PlatformInfoManager(); | ||||
| uint32_t LoadIniFile(string iniFileRealPath); | |||||
| uint32_t LoadIniFile(string ini_file_real_path); | |||||
| void Trim(string &str); | void Trim(string &str); | ||||
| uint32_t LoadConfigFile(string realPath); | |||||
| uint32_t LoadConfigFile(string real_path); | |||||
| string RealPath(const std::string &path); | string RealPath(const std::string &path); | ||||
| string GetSoFilePath(); | string GetSoFilePath(); | ||||
| void ParseVersion(map<string, string> &versionMap, string &socVersion, PlatformInfo &platformInfoTemp); | |||||
| void ParseVersion(map<string, string> &version_map, string &soc_version, PlatformInfo &platform_info_temp); | |||||
| void ParseSocInfo(map<string, string> &socInfoMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseSocInfo(map<string, string> &soc_info_map, PlatformInfo &platform_info_temp); | |||||
| void ParseCubeOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseCubeOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp); | |||||
| void ParseBufferOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseBufferOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp); | |||||
| void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseUBOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp); | |||||
| void ParseUnzipOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseUnzipOfAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp); | |||||
| void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseAICoreSpec(map<string, string> &ai_core_spec_map, PlatformInfo &platform_info_temp); | |||||
| void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseBufferOfAICoreMemoryRates(map<string, string> &ai_core_memory_rates_map, PlatformInfo &platform_info_temp); | |||||
| void ParseAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseAICoreMemoryRates(map<string, string> &ai_core_memory_rates_map, PlatformInfo &platform_info_temp); | |||||
| void ParseUBOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseUBOfAICoreMemoryRates(map<string, string> &ai_core_memory_rates_map, PlatformInfo &platform_info_temp); | |||||
| void ParseAICoreintrinsicDtypeMap(map<string, string> &aiCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseAICoreintrinsicDtypeMap(map<string, string> &ai_coreintrinsic_dtype_map, PlatformInfo &platform_info_temp); | |||||
| void ParseVectorCoreSpec(map<string, string> &vectorCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseVectorCoreSpec(map<string, string> &vector_core_spec_map, PlatformInfo &platform_info_temp); | |||||
| void ParseVectorCoreMemoryRates(map<string, string> &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseVectorCoreMemoryRates(map<string, string> &vector_core_memory_rates_map, PlatformInfo &platform_info_temp); | |||||
| void ParseCPUCache(map<string, string> &CPUCacheMap, PlatformInfo &platformInfoTemp); | |||||
| void ParseCPUCache(map<string, string> &CPUCacheMap, PlatformInfo &platform_info_temp); | |||||
| void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vectorCoreintrinsicDtypeMap, | |||||
| PlatformInfo &platformInfoTemp); | |||||
| void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vector_coreintrinsic_dtype_map, | |||||
| PlatformInfo &platform_info_temp); | |||||
| uint32_t ParsePlatformInfoFromStrToStruct(map<string, map<string, string>> &contentInfoMap, string &socVersion, | |||||
| PlatformInfo &platformInfoTemp); | |||||
| uint32_t ParsePlatformInfoFromStrToStruct(map<string, map<string, string>> &content_info_map, string &soc_version, | |||||
| PlatformInfo &platform_info_temp); | |||||
| uint32_t AssemblePlatformInfoVector(map<string, map<string, string>> &contentInfoMap); | |||||
| uint32_t AssemblePlatformInfoVector(map<string, map<string, string>> &content_info_map); | |||||
| private: | private: | ||||
| bool initFlag_; | |||||
| map<string, PlatformInfo> platformInfoMap_; | |||||
| OptionalInfo optiCompilationInfo_; | |||||
| bool init_flag_; | |||||
| map<string, PlatformInfo> platform_info_map_; | |||||
| OptionalInfo opti_compilation_info_; | |||||
| }; | }; | ||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -30,111 +30,113 @@ enum MemoryType { DDR = 0, HBM }; | |||||
| enum L2Type { Cache = 0, Buff }; | enum L2Type { Cache = 0, Buff }; | ||||
| typedef struct tagStrInfo { | |||||
| string aicVersion; | |||||
| string ccecAICVersion; | |||||
| string ccecAIVVersion; | |||||
| string isSupportAIcpuCompiler; | |||||
| typedef struct tag_str_info { | |||||
| string aic_version; | |||||
| string ccec_aic_version; | |||||
| string ccec_aiv_version; | |||||
| string is_support_ai_cpu_compiler; | |||||
| } StrInfo; | } StrInfo; | ||||
| typedef struct tagSoCInfo { | |||||
| uint32_t aiCoreCnt; | |||||
| uint32_t vectorCoreCnt; | |||||
| uint32_t aiCpuCnt; | |||||
| MemoryType memoryType; | |||||
| uint64_t memorySize; | |||||
| L2Type l2Type; | |||||
| uint64_t l2Size; | |||||
| typedef struct tag_so_c_info { | |||||
| uint32_t ai_core_cnt; | |||||
| uint32_t vector_core_cnt; | |||||
| uint32_t ai_cpu_cnt; | |||||
| MemoryType memory_type; | |||||
| uint64_t memory_size; | |||||
| L2Type l2_type; | |||||
| uint64_t l2_size; | |||||
| uint32_t l2PageNum; | uint32_t l2PageNum; | ||||
| } SoCInfo; | } SoCInfo; | ||||
| typedef struct tagAiCoreSpec { | |||||
| double cubeFreq; | |||||
| uint64_t cubeMSize; | |||||
| uint64_t cubeNSize; | |||||
| uint64_t cubeKSize; | |||||
| uint64_t vecCalcSize; | |||||
| uint64_t l0ASize; | |||||
| uint64_t l0BSize; | |||||
| uint64_t l0CSize; | |||||
| uint64_t l1Size; | |||||
| uint64_t smaskBuffer; | |||||
| uint64_t ubSize; | |||||
| uint64_t ubblockSize; | |||||
| uint64_t ubbankSize; | |||||
| uint64_t ubbankNum; | |||||
| uint64_t ubburstInOneBlock; | |||||
| uint64_t ubbankGroupNum; | |||||
| uint32_t unzipEngines; | |||||
| uint32_t unzipMaxRatios; | |||||
| uint32_t unzipChannels; | |||||
| uint8_t unzipIsTight; | |||||
| typedef struct tag_ai_core_spec { | |||||
| double cube_freq; | |||||
| uint64_t cube_m_size; | |||||
| uint64_t cube_n_size; | |||||
| uint64_t cube_k_size; | |||||
| uint64_t vec_calc_size; | |||||
| uint64_t l0_a_size; | |||||
| uint64_t l0_b_size; | |||||
| uint64_t l0_c_size; | |||||
| uint64_t l1_size; | |||||
| uint64_t smask_buffer; | |||||
| uint64_t ub_size; | |||||
| uint64_t ubblock_size; | |||||
| uint64_t ubbank_size; | |||||
| uint64_t ubbank_num; | |||||
| uint64_t ubburst_in_one_block; | |||||
| uint64_t ubbank_group_num; | |||||
| uint32_t unzip_engines; | |||||
| uint32_t unzip_max_ratios; | |||||
| uint32_t unzip_channels; | |||||
| uint8_t unzip_is_tight; | |||||
| uint8_t cube_vector_split; | |||||
| } AiCoreSpec; | } AiCoreSpec; | ||||
| typedef struct tagAiCoreMemoryRates { | |||||
| double ddrRate; | |||||
| double ddrReadRate; | |||||
| double ddrWriteRate; | |||||
| double l2Rate; | |||||
| double l2ReadRate; | |||||
| double l2WriteRate; | |||||
| double l1ToL0ARate; | |||||
| double l1ToL0BRate; | |||||
| double l1ToUBRate; | |||||
| double l0CToUBRate; | |||||
| double ubToL2Rate; | |||||
| double ubToDdrRate; | |||||
| double ubToL1Rate; | |||||
| typedef struct tag_ai_core_memory_rates { | |||||
| double ddr_rate; | |||||
| double ddr_read_rate; | |||||
| double ddr_write_rate; | |||||
| double l2_rate; | |||||
| double l2_read_rate; | |||||
| double l2_write_rate; | |||||
| double l1_to_l0_a_rate; | |||||
| double l1_to_l0_b_rate; | |||||
| double l1_to_ub_rate; | |||||
| double l0_c_to_ub_rate; | |||||
| double ub_to_l2_rate; | |||||
| double ub_to_ddr_rate; | |||||
| double ub_to_l1_rate; | |||||
| } AiCoreMemoryRates; | } AiCoreMemoryRates; | ||||
| typedef struct tagVectorCoreSpec { | |||||
| double vecFreq; | |||||
| uint64_t vecCalcSize; | |||||
| uint64_t smaskBuffer; | |||||
| uint64_t ubSize; | |||||
| uint64_t ubblockSize; | |||||
| uint64_t ubbankSize; | |||||
| uint64_t ubbankNum; | |||||
| uint64_t ubburstInOneBlock; | |||||
| uint64_t ubbankGroupNum; | |||||
| uint64_t vectorRegSize; | |||||
| uint64_t predicateRegSize; | |||||
| uint64_t addressRegSize; | |||||
| typedef struct tag_vector_core_spec { | |||||
| double vec_freq; | |||||
| uint64_t vec_calc_size; | |||||
| uint64_t smask_buffer; | |||||
| uint64_t ub_size; | |||||
| uint64_t ubblock_size; | |||||
| uint64_t ubbank_size; | |||||
| uint64_t ubbank_num; | |||||
| uint64_t ubburst_in_one_block; | |||||
| uint64_t ubbank_group_num; | |||||
| uint64_t vector_reg_size; | |||||
| uint64_t predicate_reg_size; | |||||
| uint64_t address_reg_size; | |||||
| uint64_t alignment_reg_size; | |||||
| } VectorCoreSpec; | } VectorCoreSpec; | ||||
| typedef struct tagVectorCoreMemoryRates { | |||||
| double ddrRate; | |||||
| double ddrReadRate; | |||||
| double ddrWriteRate; | |||||
| double l2Rate; | |||||
| double l2ReadRate; | |||||
| double l2WriteRate; | |||||
| double ubToL2Rate; | |||||
| double ubToDdrRate; | |||||
| typedef struct tag_vector_core_memory_rates { | |||||
| double ddr_rate; | |||||
| double ddr_read_rate; | |||||
| double ddr_write_rate; | |||||
| double l2_rate; | |||||
| double l2_read_rate; | |||||
| double l2_write_rate; | |||||
| double ub_to_l2_rate; | |||||
| double ub_to_ddr_rate; | |||||
| } VectorCoreMemoryRates; | } VectorCoreMemoryRates; | ||||
| typedef struct tagCPUCache { | |||||
| typedef struct tag_cpu_cache { | |||||
| uint32_t AICPUSyncBySW; | uint32_t AICPUSyncBySW; | ||||
| uint32_t TSCPUSyncBySW; | uint32_t TSCPUSyncBySW; | ||||
| } CPUCache; | } CPUCache; | ||||
| typedef struct tagPlatformInfo { | |||||
| StrInfo strInfo; | |||||
| SoCInfo socInfo; | |||||
| AiCoreSpec aiCoreSpec; | |||||
| AiCoreMemoryRates aiCoreMemoryRates; | |||||
| map<string, vector<string>> aiCoreIntrinsicDtypeMap; | |||||
| VectorCoreSpec vectorCoreSpec; | |||||
| VectorCoreMemoryRates vectorCoreMemoryRates; | |||||
| typedef struct tag_platform_info { | |||||
| StrInfo str_info; | |||||
| SoCInfo soc_info; | |||||
| AiCoreSpec ai_core_spec; | |||||
| AiCoreMemoryRates ai_core_memory_rates; | |||||
| map<string, vector<string>> ai_core_intrinsic_dtype_map; | |||||
| VectorCoreSpec vector_core_spec; | |||||
| VectorCoreMemoryRates vector_core_memory_rates; | |||||
| CPUCache cpucache; | CPUCache cpucache; | ||||
| map<string, vector<string>> vectorCoreIntrinsicDtypeMap; | |||||
| map<string, vector<string>> vector_core_intrinsic_dtype_map; | |||||
| } PlatformInfo; | } PlatformInfo; | ||||
| typedef struct tagOptionalInfo { | |||||
| string socVersion; | |||||
| string coreType; | |||||
| uint32_t aiCoreNum; | |||||
| string l1FusionFlag; | |||||
| typedef struct tag_optional_info { | |||||
| string soc_version; | |||||
| string core_type; | |||||
| uint32_t ai_core_num; | |||||
| string l1_fusion_flag; | |||||
| } OptionalInfo; | } OptionalInfo; | ||||
| } // namespace fe | } // namespace fe | ||||
| #endif | #endif | ||||
| @@ -70,7 +70,7 @@ using Status = uint32_t; | |||||
| // General error code | // General error code | ||||
| GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success"); | GE_ERRORNO(0, 0, 0, 0, 0, SUCCESS, 0, "success"); | ||||
| GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); /*lint !e401*/ | |||||
| GE_ERRORNO(0b11, 0b11, 0b111, 0xFF, 0b11111, FAILED, 0xFFF, "failed"); | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ | #endif // INC_EXTERNAL_GE_GE_API_ERROR_CODES_H_ | ||||
| @@ -89,5 +89,26 @@ graphStatus aclgrphSaveModel(const string &output_file, const ModelBufferData &m | |||||
| */ | */ | ||||
| graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version); | graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *patch_version); | ||||
| /** | |||||
| * @ingroup AscendCL | |||||
| * @brief infer shape and data type | |||||
| * | |||||
| * @param graph[IN] the graph ready to build | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||||
| * @retval OtherValues Failure | |||||
| */ | |||||
| graphStatus aclgrphInferShapeAndType(ge::Graph &graph); | |||||
| /** | |||||
| * @ingroup AscendCL | |||||
| * @brief dump graph | |||||
| * | |||||
| * @param graph[IN] the graph ready to build | |||||
| * @param file[IN] file path | |||||
| * @param file[IN] file path string len | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | |||||
| * @retval OtherValues Failure | |||||
| */ | |||||
| graphStatus aclgrphDumpGraph(const ge::Graph &graph, const char *file, const size_t len); | |||||
| }; // namespace ge | }; // namespace ge | ||||
| #endif | #endif | ||||
| @@ -0,0 +1,38 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ | |||||
| #define INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ | |||||
| #include <string> | |||||
| #include <memory> | |||||
| namespace ge { | |||||
| class AscendString { | |||||
| public: | |||||
| AscendString() = default; | |||||
| ~AscendString() = default; | |||||
| explicit AscendString(const char* name); | |||||
| const char* GetString() const; | |||||
| private: | |||||
| std::shared_ptr<std::string> name_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_ASCEND_STRING_H_ | |||||
| @@ -34,7 +34,6 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| class AttrValueImpl; | class AttrValueImpl; | ||||
| /*lint -e148*/ | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { | ||||
| public: | public: | ||||
| using INT = int64_t; | using INT = int64_t; | ||||
| @@ -70,6 +69,5 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrValue { | |||||
| VALUE_SET_GET_DEC(AttrValue::FLOAT) | VALUE_SET_GET_DEC(AttrValue::FLOAT) | ||||
| #undef VALUE_SET_GET_DEC | #undef VALUE_SET_GET_DEC | ||||
| }; | }; | ||||
| /*lint +e148*/ | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ | #endif // INC_EXTERNAL_GRAPH_ATTR_VALUE_H_ | ||||
| @@ -33,6 +33,7 @@ using graphStatus = uint32_t; | |||||
| const graphStatus GRAPH_FAILED = 0xFFFFFFFF; | const graphStatus GRAPH_FAILED = 0xFFFFFFFF; | ||||
| const graphStatus GRAPH_SUCCESS = 0; | const graphStatus GRAPH_SUCCESS = 0; | ||||
| const graphStatus GRAPH_PARAM_INVALID = 50331649; | const graphStatus GRAPH_PARAM_INVALID = 50331649; | ||||
| const graphStatus GRAPH_NODE_WITHOUT_CONST_INPUT = 50331648; | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ | #endif // INC_EXTERNAL_GRAPH_GE_ERROR_CODES_H_ | ||||
| @@ -0,0 +1,129 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_EXTERNAL_GRAPH_NODE_H_ | |||||
| #define INC_EXTERNAL_GRAPH_NODE_H_ | |||||
| #include <vector> | |||||
| #include <cstdint> | |||||
| #include "./ge_error_codes.h" | |||||
| #include "./types.h" | |||||
| #include "./tensor.h" | |||||
| #include "./ascend_string.h" | |||||
| namespace ge { | |||||
| class AttrValue; | |||||
| class GNode; | |||||
| class OpDesc; | |||||
| class Graph; | |||||
| class ComputeGraph; | |||||
| using GNodePtr = std::shared_ptr<GNode>; | |||||
| using GraphPtr = std::shared_ptr<Graph>; | |||||
| using OpBytes = std::vector<uint8_t>; | |||||
| using OpDescPtr = std::shared_ptr<OpDesc>; | |||||
| using ComputeGraphPtr = std::shared_ptr<ComputeGraph>; | |||||
| class NodeImpl; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GNode { | |||||
| public: | |||||
| GNode(); | |||||
| ~GNode() = default; | |||||
| graphStatus GetType(ge::AscendString &type) const; | |||||
| graphStatus GetName(ge::AscendString &name) const; | |||||
| std::pair<GNodePtr, int32_t> GetInDataNodesAndPortIndexs(const int32_t index) const; | |||||
| std::vector<GNodePtr> GetInControlNodes() const; | |||||
| std::vector<std::pair<GNodePtr, int32_t>> GetOutDataNodesAndPortIndexs(const int32_t index) const; | |||||
| std::vector<GNodePtr> GetOutControlNodes() const; | |||||
| graphStatus GetInputConstData(const int32_t index, Tensor &data) const; | |||||
| graphStatus GetInputIndexByName(const ge::AscendString &name, int32_t &index); | |||||
| graphStatus GetOutputIndexByName(const ge::AscendString &name, int32_t &index); | |||||
| size_t GetInputsSize() const; | |||||
| size_t GetOutputsSize() const; | |||||
| graphStatus GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const; | |||||
| graphStatus UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc); | |||||
| graphStatus GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const; | |||||
| graphStatus UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc); | |||||
| graphStatus GetAttr(const ge::AscendString &name, int64_t &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, int32_t &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, uint32_t &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, float &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, bool &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, Tensor &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<int64_t> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<int32_t> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<uint32_t> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<float> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<bool> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<Tensor> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, OpBytes &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<std::vector<int64_t>> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, std::vector<ge::DataType> &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, ge::DataType &attr_value) const; | |||||
| graphStatus GetAttr(const ge::AscendString &name, AttrValue &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, int64_t &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, int32_t &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, uint32_t &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, float &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, bool &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, Tensor &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<int64_t> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<int32_t> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<uint32_t> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<float> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<bool> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<Tensor> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, OpBytes &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<std::vector<int64_t>> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, std::vector<ge::DataType> &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, ge::DataType &attr_value) const; | |||||
| graphStatus SetAttr(const ge::AscendString &name, AttrValue &attr_value) const; | |||||
| bool HasAttr(const ge::AscendString &name); | |||||
| graphStatus GetSubgraph(uint32_t index, GraphPtr graph) const; | |||||
| graphStatus GetALLSubgraphs(std::vector<GraphPtr> graph_list) const; | |||||
| private: | |||||
| std::shared_ptr<NodeImpl> impl_; | |||||
| friend class NodeAdapter; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_EXTERNAL_GRAPH_NODE_H_ | |||||
| @@ -23,11 +23,14 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "./operator.h" | #include "./operator.h" | ||||
| #include "./gnode.h" | |||||
| namespace ge { | namespace ge { | ||||
| class Graph; | |||||
| class GraphImpl; | class GraphImpl; | ||||
| using GraphImplPtr = std::shared_ptr<GraphImpl>; | using GraphImplPtr = std::shared_ptr<GraphImpl>; | ||||
| using GraphPtr = std::shared_ptr<Graph>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | ||||
| friend class GraphUtils; | friend class GraphUtils; | ||||
| @@ -53,15 +56,15 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | |||||
| graphStatus AddOp(const ge::Operator &op); | graphStatus AddOp(const ge::Operator &op); | ||||
| graphStatus FindOpByName(const string &name, ge::Operator &op) const; | |||||
| graphStatus FindOpByName(const std::string &name, ge::Operator &op) const; | |||||
| graphStatus FindOpByType(const string &type, std::vector<ge::Operator> &ops) const; | |||||
| graphStatus FindOpByType(const std::string &type, std::vector<ge::Operator> &ops) const; | |||||
| graphStatus GetAllOpName(std::vector<string> &op_name) const; | |||||
| graphStatus GetAllOpName(std::vector<std::string> &op_name) const; | |||||
| graphStatus SaveToFile(const string &file_name) const; | |||||
| graphStatus SaveToFile(const std::string &file_name) const; | |||||
| graphStatus LoadFromFile(const string &file_name); | |||||
| graphStatus LoadFromFile(const std::string &file_name); | |||||
| const std::string &GetName() const; | const std::string &GetName() const; | ||||
| @@ -73,6 +76,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Graph { | |||||
| /// | /// | ||||
| void SetNeedIteration(bool need_iteration); | void SetNeedIteration(bool need_iteration); | ||||
| std::vector<GNode> GetAllNodes() const; | |||||
| std::vector<GNode> GetDirectNode() const; | |||||
| graphStatus RemoveNode(GNode &node); | |||||
| graphStatus RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index); | |||||
| GNode AddNodeByOp(const Operator &op); | |||||
| graphStatus AddDataEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, const int32_t dst_port_index); | |||||
| graphStatus AddControlEdge(GNode &src_node, GNode &dst_node); | |||||
| static GraphPtr ConstructFromInputs(const std::vector<Operator> &inputs, const ge::AscendString &name); | |||||
| private: | private: | ||||
| GraphImplPtr impl_{nullptr}; | GraphImplPtr impl_{nullptr}; | ||||
| }; | }; | ||||
| @@ -63,7 +63,6 @@ using std::function; | |||||
| using std::shared_ptr; | using std::shared_ptr; | ||||
| using std::string; | using std::string; | ||||
| /*lint -e148*/ | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | ||||
| public: | public: | ||||
| friend class OperatorImpl; | friend class OperatorImpl; | ||||
| @@ -91,7 +90,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
| explicit Operator(const string &type); | explicit Operator(const string &type); | ||||
| Operator(const string &name, const string &type); // lint !e148 | |||||
| Operator(const string &name, const string &type); | |||||
| virtual ~Operator() = default; | virtual ~Operator() = default; | ||||
| @@ -104,7 +103,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
| // Only has one output index = 0 | // Only has one output index = 0 | ||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt); | Operator &SetInput(const string &dst_name, const Operator &src_oprt); | ||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); // lint !e148 | |||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, const string &name); | |||||
| Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); | Operator &SetInput(const string &dst_name, const Operator &src_oprt, uint32_t index); | ||||
| @@ -128,22 +127,22 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
| TensorDesc GetOutputDesc(uint32_t index) const; | TensorDesc GetOutputDesc(uint32_t index) const; | ||||
| graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); // lint !e148 | |||||
| graphStatus UpdateOutputDesc(const string &name, const TensorDesc &tensor_desc); | |||||
| TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; | TensorDesc GetDynamicInputDesc(const string &name, uint32_t index) const; | ||||
| graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 | |||||
| graphStatus UpdateDynamicInputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); | |||||
| TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; | TensorDesc GetDynamicOutputDesc(const string &name, uint32_t index) const; | ||||
| graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); // lint !e148 | |||||
| graphStatus UpdateDynamicOutputDesc(const string &name, uint32_t index, const TensorDesc &tensor_desc); | |||||
| graphStatus InferShapeAndType(); // lint !e148 | |||||
| graphStatus InferShapeAndType(); | |||||
| void SetInferenceContext(const InferenceContextPtr &inference_context); | void SetInferenceContext(const InferenceContextPtr &inference_context); | ||||
| InferenceContextPtr GetInferenceContext() const; | InferenceContextPtr GetInferenceContext() const; | ||||
| graphStatus VerifyAllAttr(bool disable_common_verifier = false); // lint !e148 | |||||
| graphStatus VerifyAllAttr(bool disable_common_verifier = false); | |||||
| size_t GetInputsSize() const; | size_t GetInputsSize() const; | ||||
| @@ -256,20 +255,19 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
| void RequiredAttrRegister(const string &name); | void RequiredAttrRegister(const string &name); | ||||
| graphStatus VerifyAll(); // lint !e148 | |||||
| graphStatus VerifyAll(); | |||||
| // Only has one output index = 0 | // Only has one output index = 0 | ||||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); | Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt); | ||||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, | |||||
| const string &name); // lint !e148 | |||||
| Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | |||||
| void SubgraphRegister(const string &ir_name, bool dynamic); | void SubgraphRegister(const string &ir_name, bool dynamic); | ||||
| void SubgraphCountRegister(const string &ir_name, uint32_t count); | void SubgraphCountRegister(const string &ir_name, uint32_t count); | ||||
| void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); | void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); | ||||
| private: | private: | ||||
| Operator &SetInput(const string &dst_name, const OutHandler &out_handler); // lint !e148 | |||||
| Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | |||||
| OutHandler GetOutput(const string &name) const; | OutHandler GetOutput(const string &name) const; | ||||
| @@ -283,7 +281,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
| std::shared_ptr<const Node> GetNode() const; | std::shared_ptr<const Node> GetNode() const; | ||||
| }; | }; | ||||
| /*lint +e148*/ | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ | #endif // INC_EXTERNAL_GRAPH_OPERATOR_H_ | ||||
| @@ -126,6 +126,5 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Tensor { | |||||
| friend class TensorAdapter; | friend class TensorAdapter; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| /*lint +e148*/ | |||||
| #endif // INC_EXTERNAL_GRAPH_TENSOR_H_ | #endif // INC_EXTERNAL_GRAPH_TENSOR_H_ | ||||
| @@ -0,0 +1,134 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /** | |||||
| * @file hccl.h | |||||
| * @brief HCCL API | |||||
| */ | |||||
| #ifndef HCCL_H_ | |||||
| #define HCCL_H_ | |||||
| #include <hccl/hccl_types.h> | |||||
| #include <acl/acl.h> | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif // __cplusplus | |||||
| /** | |||||
| * @brief Initialize HCCL. | |||||
| * | |||||
| * @param clusterInfo A string identifying the cluster info file path, include file name. | |||||
| * @param rank A integer identifying the identify for the rank. | |||||
| * @param comm A pointer identifying the initialized communication resource. | |||||
| * @return HcclResult | |||||
| * @see HcclCommDestroy() | |||||
| */ | |||||
| extern HcclResult HcclCommInitClusterInfo(const char *clusterInfo, uint32_t rank, HcclComm *comm); | |||||
| /** | |||||
| * @brief Get hccl root info. | |||||
| * | |||||
| * @param rootInfo A pointer identifying the hccl root info. | |||||
| * @return HcclResult | |||||
| */ | |||||
| extern HcclResult HcclGetRootInfo(HcclRootInfo *rootInfo); | |||||
| /** | |||||
| * @brief Initialize HCCL with root info. | |||||
| * | |||||
| * @param nRanks A integer identifying the rank size of the cluster. | |||||
| * @param rootInfo A struct identifying the hccl root info. | |||||
| * @param rank A integer identifying the identify for the rank. | |||||
| * @param comm A pointer identifying the initialized communication resource. | |||||
| * @return HcclResult | |||||
| * @see HcclCommDestroy() | |||||
| */ | |||||
| extern HcclResult HcclCommInitRootInfo(uint32_t nRanks, const HcclRootInfo *rootInfo, uint32_t rank, HcclComm *comm); | |||||
| /** | |||||
| * @brief AllReduce operator. | |||||
| * | |||||
| * @param sendBuf A pointer identifying the input data address of the operator. | |||||
| * @param recvBuf A pointer identifying the output data address of the operator. | |||||
| * @param count An integer(u64) identifying the number of the output data. | |||||
| * @param dataType The data type of the operator, must be one of the following types: int8, int16, int32, float16, | |||||
| * float32. | |||||
| * @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod. | |||||
| * @param comm A pointer identifying the communication resource based on. | |||||
| * @param stream A pointer identifying the stream information. | |||||
| * @return HcclResult | |||||
| */ | |||||
| extern HcclResult HcclAllReduce(void *sendBuf, void *recvBuf, uint64_t count, HcclDataType dataType, HcclReduceOp op, | |||||
| HcclComm comm, aclrtStream stream); | |||||
| /** | |||||
| * @brief Broadcast operator. | |||||
| * | |||||
| * @param buf A pointer identifying the data address of the operator. | |||||
| * @param count An integer(u64) identifying the number of the data. | |||||
| * @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32. | |||||
| * @param root An integer(u32) identifying the the root rank in the operator. | |||||
| * @param comm A pointer identifying the communication resource based on | |||||
| * @param stream A pointer identifying the stream information. | |||||
| * @return HcclResult | |||||
| */ | |||||
| extern HcclResult HcclBroadcast(void *buf, uint64_t count, HcclDataType dataType, uint32_t root, HcclComm comm, | |||||
| aclrtStream stream); | |||||
| /** | |||||
| * @brief ReduceScatter operator. | |||||
| * | |||||
| * @param sendBuf A pointer identifying the input data address of the operator. | |||||
| * @param recvBuf A pointer identifying the output data address of the operator. | |||||
| * @param recvCount An integer(u64) identifying the number of the output data. | |||||
| * @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32. | |||||
| * @param op The reduction type of the operator, must be one of the following types: sum, min, max, prod. | |||||
| * @param comm A pointer identifying the communication resource based on. | |||||
| * @param stream A pointer identifying the stream information. | |||||
| * @return HcclResult | |||||
| */ | |||||
| extern HcclResult HcclReduceScatter(void *sendBuf, void *recvBuf, uint64_t recvCount, HcclDataType dataType, | |||||
| HcclReduceOp op, HcclComm comm, aclrtStream stream); | |||||
| /** | |||||
| * @brief AllGather operator. | |||||
| * | |||||
| * @param sendBuf A pointer identifying the input data address of the operator. | |||||
| * @param recvBuf A pointer identifying the output data address of the operator. | |||||
| * @param sendCount An integer(u64) identifying the number of the input data. | |||||
| * @param dataType The data type of the operator, must be one of the following types: int8, int32, float16, float32. | |||||
| * @param comm A pointer identifying the communication resource based on. | |||||
| * @param stream A pointer identifying the stream information. | |||||
| * @return HcclResult | |||||
| */ | |||||
| extern HcclResult HcclAllGather(void *sendBuf, void *recvBuf, uint64_t sendCount, HcclDataType dataType, HcclComm comm, | |||||
| aclrtStream stream); | |||||
| /** | |||||
| * @brief Destroy HCCL comm | |||||
| * | |||||
| * @param comm A pointer identifying the communication resource targetting | |||||
| * @return HcclResult | |||||
| * @see HcclCommInitClusterInfo() | |||||
| */ | |||||
| extern HcclResult HcclCommDestroy(HcclComm comm); | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif // __cplusplus | |||||
| #endif // HCCL_H_ | |||||
| @@ -0,0 +1,101 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| /** | |||||
| * @file hccl_types.h | |||||
| * @brief HCCL data type definition | |||||
| * | |||||
| */ | |||||
| #ifndef HCCL_TYPES_H_ | |||||
| #define HCCL_TYPES_H_ | |||||
| #include <stdint.h> | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif // __cplusplus | |||||
| /** | |||||
| * @brief HCCL functions return value definition | |||||
| */ | |||||
| typedef enum { | |||||
| HCCL_SUCCESS = 0, /**< success */ | |||||
| HCCL_E_PARA = 1, /**< parameter error */ | |||||
| HCCL_E_PTR = 2, /**< empty pointer */ | |||||
| HCCL_E_MEMORY = 3, /**< memory error */ | |||||
| HCCL_E_INTERNAL = 4, /**< internal error */ | |||||
| HCCL_E_NOT_SUPPORT = 5, /**< not support feature */ | |||||
| HCCL_E_NOT_FOUND = 6, /**< not found specific resource */ | |||||
| HCCL_E_UNAVAIL = 7, /**< resource unavailable */ | |||||
| HCCL_E_SYSCALL = 8, /**< call system interface error */ | |||||
| HCCL_E_TIMEOUT = 9, /**< timeout */ | |||||
| HCCL_E_OPEN_FILE_FAILURE = 10, /**< open file fail */ | |||||
| HCCL_E_TCP_CONNECT = 11, /**< tcp connect fail */ | |||||
| HCCL_E_ROCE_CONNECT = 12, /**< roce connect fail */ | |||||
| HCCL_E_TCP_TRANSFER = 13, /**< tcp transfer fail */ | |||||
| HCCL_E_ROCE_TRANSFER = 14, /**< roce transfer fail */ | |||||
| HCCL_E_RUNTIME = 15, /**< call runtime api fail */ | |||||
| HCCL_E_DRV = 16, /**< call driver api fail */ | |||||
| HCCL_E_PROFILING = 17, /**< call profiling api fail */ | |||||
| HCCL_E_CCE = 18, /**< call cce api fail */ | |||||
| HCCL_E_NETWORK = 19, /**< call network api fail */ | |||||
| HCCL_E_RESERVED /**< reserved */ | |||||
| } HcclResult; | |||||
| /** | |||||
| * @brief handle to HCCL communicator | |||||
| */ | |||||
| typedef void *HcclComm; | |||||
| /** | |||||
| * @brief HCCL Reduction opperation | |||||
| */ | |||||
| typedef enum { | |||||
| HCCL_REDUCE_SUM = 0, /**< sum */ | |||||
| HCCL_REDUCE_PROD = 1, /**< prod */ | |||||
| HCCL_REDUCE_MAX = 2, /**< max */ | |||||
| HCCL_REDUCE_MIN = 3, /**< min */ | |||||
| HCCL_REDUCE_RESERVED /**< reserved */ | |||||
| } HcclReduceOp; | |||||
| /** | |||||
| * @brief HCCL data type | |||||
| */ | |||||
| typedef enum { | |||||
| HCCL_DATA_TYPE_INT8 = 0, /**< int8 */ | |||||
| HCCL_DATA_TYPE_INT16 = 1, /**< int16 */ | |||||
| HCCL_DATA_TYPE_INT32 = 2, /**< int32 */ | |||||
| HCCL_DATA_TYPE_FP16 = 3, /**< fp16 */ | |||||
| HCCL_DATA_TYPE_FP32 = 4, /**< fp32 */ | |||||
| HCCL_DATA_TYPE_INT64 = 5, /**< int64 */ | |||||
| HCCL_DATA_TYPE_UINT64 = 6, /**< uint64 */ | |||||
| HCCL_DATA_TYPE_RESERVED /**< reserved */ | |||||
| } HcclDataType; | |||||
| const uint32_t HCCL_ROOT_INFO_BYTES = 4108; // 4108: root info length | |||||
| /** | |||||
| * @brief HCCL root info | |||||
| */ | |||||
| typedef struct HcclRootInfoDef { | |||||
| char internal[HCCL_ROOT_INFO_BYTES]; | |||||
| } HcclRootInfo; | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif // __cplusplus | |||||
| #endif // HCCL_TYPES_H_ | |||||
| @@ -40,7 +40,6 @@ using std::to_string; | |||||
| using std::unique_ptr; | using std::unique_ptr; | ||||
| using std::vector; | using std::vector; | ||||
| /*lint -e148*/ | |||||
| namespace ge { | namespace ge { | ||||
| class Operator; | class Operator; | ||||
| class TensorDesc; | class TensorDesc; | ||||
| @@ -159,5 +158,4 @@ namespace ge { | |||||
| using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
| using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
| } // namespace ge | } // namespace ge | ||||
| /*lint +e148*/ | |||||
| #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | ||||
| @@ -301,7 +301,6 @@ class GE_FUNC_HOST_VISIBILITY GE_FUNC_DEV_VISIBILITY ScopeFusionPassRegistry { | |||||
| private: | private: | ||||
| ScopeFusionPassRegistry(); | ScopeFusionPassRegistry(); | ||||
| class ScopeFusionPassRegistryImpl; | class ScopeFusionPassRegistryImpl; | ||||
| /*lint -e148*/ | |||||
| std::unique_ptr<ScopeFusionPassRegistryImpl> impl_; | std::unique_ptr<ScopeFusionPassRegistryImpl> impl_; | ||||
| friend class TensorFlowModelParser; | friend class TensorFlowModelParser; | ||||
| }; | }; | ||||
| @@ -14,7 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| /*lint -e* */ | |||||
| #ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | #ifndef INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | ||||
| #define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | #define INC_FRAMEWORK_COMMON_GE_INNER_ERROR_CODES_H_ | ||||
| @@ -304,6 +303,7 @@ GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_WEIGHT_MEM_FAILED, 16, "Failed to allocate wei | |||||
| GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory."); | GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_VAR_MEM_FAILED, 17, "Failed to allocate variable memory."); | ||||
| GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist."); | GE_ERRORNO_EXECUTOR(GE_AIPP_NOT_EXIST, 18, "GE AIPP is not exist."); | ||||
| GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily."); | GE_ERRORNO_EXECUTOR(GE_DYNAMIC_AIPP_NOT_SUPPORT_QUERY, 19, "GE Dynamic AIPP is not support to query temporarily."); | ||||
| GE_ERRORNO_EXECUTOR(GE_EXEC_ALLOC_P2P_MEM_FAILED, 20, "Failed to allocate P2P memory"); | |||||
| // Generator module error code definition | // Generator module error code definition | ||||
| GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed."); | GE_ERRORNO_GENERATOR(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, 1, "Graph manager initialize failed."); | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <string> | #include <string> | ||||
| #include "common/types.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "proto/om.pb.h" | #include "proto/om.pb.h" | ||||
| @@ -22,7 +22,8 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "common/op/attr_value_util.h" | #include "common/op/attr_value_util.h" | ||||
| #include "common/types.h" | |||||
| #include "register/register_types.h" | |||||
| #include "register/register_error_codes.h" | |||||
| #include "common/util.h" | #include "common/util.h" | ||||
| #include "graph/attr_value.h" | #include "graph/attr_value.h" | ||||
| #include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
| @@ -36,8 +36,8 @@ class StringUtils { | |||||
| #endif | #endif | ||||
| return s; | return s; | ||||
| } | } | ||||
| // lint -esym(551,*) | |||||
| static std::string &Rtrim(std::string &s) { /*lint !e618*/ | |||||
| static std::string &Rtrim(std::string &s) { | |||||
| #if __cplusplus >= 201103L | #if __cplusplus >= 201103L | ||||
| (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); | (void)s.erase(s.begin(), std::find_if(s.begin(), s.end(), [](int c) { return !std::isspace(c); })); | ||||
| #else | #else | ||||
| @@ -45,7 +45,7 @@ class StringUtils { | |||||
| #endif | #endif | ||||
| return s; | return s; | ||||
| } | } | ||||
| // lint -esym(551,*) | |||||
| /// | /// | ||||
| /// @ingroup domi_common | /// @ingroup domi_common | ||||
| /// @brief delete spaces at the beginning and end of a string | /// @brief delete spaces at the beginning and end of a string | ||||
| @@ -61,10 +61,8 @@ class StringUtils { | |||||
| /// @param [in] delim separator | /// @param [in] delim separator | ||||
| /// @return string array after segmentation | /// @return string array after segmentation | ||||
| /// | /// | ||||
| /*lint -e1077*/ | |||||
| static std::vector<std::string> Split(const std::string &str, char delim) { | static std::vector<std::string> Split(const std::string &str, char delim) { | ||||
| std::vector<std::string> elems; | std::vector<std::string> elems; | ||||
| /*lint +e1077*/ | |||||
| if (str.empty()) { | if (str.empty()) { | ||||
| elems.emplace_back(""); | elems.emplace_back(""); | ||||
| @@ -434,6 +434,7 @@ REGISTER_OPTYPE_DECLARE(HCOMREDUCESCATTER, "HcomReduceScatter"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); | REGISTER_OPTYPE_DECLARE(HCOMSEND, "HcomSend"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | REGISTER_OPTYPE_DECLARE(HCOMRECEIVE, "HcomReceive"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEREAD, "HcomRemoteRead"); | ||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | |||||
| REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DECLARE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
| REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DECLARE(VARASSIGN, "VarAssign"); | ||||
| @@ -345,7 +345,7 @@ std::string ToString(const google::protobuf::RepeatedField<T> &rpd_field) { | |||||
| /// @return Timestamp, in microseconds (US) | /// @return Timestamp, in microseconds (US) | ||||
| /// | /// | ||||
| /// | /// | ||||
| uint64_t GetCurrentTimestap(); | |||||
| uint64_t GetCurrentTimestamp(); | |||||
| /// | /// | ||||
| /// @ingroup domi_common | /// @ingroup domi_common | ||||
| @@ -30,6 +30,7 @@ enum PriorityEnum { | |||||
| COST_0 = 0, | COST_0 = 0, | ||||
| COST_1, | COST_1, | ||||
| COST_2, | COST_2, | ||||
| COST_3, | |||||
| COST_9 = 9, | COST_9 = 9, | ||||
| COST_10 = 10, | COST_10 = 10, | ||||
| }; | }; | ||||
| @@ -86,6 +86,7 @@ class GeGenerator { | |||||
| Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | ||||
| const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | ||||
| bool is_offline = true); | bool is_offline = true); | ||||
| Status CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | |||||
| class Impl; | class Impl; | ||||
| @@ -21,6 +21,7 @@ | |||||
| #include <vector> | #include <vector> | ||||
| #include "ge/ge_api_error_codes.h" | #include "ge/ge_api_error_codes.h" | ||||
| #include "graph//types.h" | |||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -35,6 +36,12 @@ struct HostVarInfo { | |||||
| uint64_t var_size; | uint64_t var_size; | ||||
| }; | }; | ||||
| struct TensorInfo { | |||||
| std::string var_name; | |||||
| std::vector<int64_t> dims; | |||||
| DataType data_type; | |||||
| }; | |||||
| /// | /// | ||||
| /// \param size [in] rdma pool memory size to be allocated. | /// \param size [in] rdma pool memory size to be allocated. | ||||
| /// \param mem_type [in] memory type for rdma pool. | /// \param mem_type [in] memory type for rdma pool. | ||||
| @@ -47,6 +54,13 @@ Status InitRdmaPool(size_t size, rtMemType_t mem_type = RT_MEMORY_HBM); | |||||
| /// \return Status result of function | /// \return Status result of function | ||||
| Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t mem_type = RT_MEMORY_HBM); | Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t mem_type = RT_MEMORY_HBM); | ||||
| /// | |||||
| /// \param tensor_info [in] description for tensor stored shared memory. | |||||
| /// \param dev_addr [out] malloced shared memory addr. | |||||
| /// \param memory_size [out] malloced shared memory size. | |||||
| /// \return Status result of function | |||||
| Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uint64_t &memory_size); | |||||
| /// | /// | ||||
| /// \param var_name [in] var_name name of host variable. | /// \param var_name [in] var_name name of host variable. | ||||
| /// \param base_addr [out] base_addr vase addr of host variable. | /// \param base_addr [out] base_addr vase addr of host variable. | ||||
| @@ -33,7 +33,7 @@ class MemoryAssigner { | |||||
| MemoryAssigner &operator=(const MemoryAssigner &) = delete; | MemoryAssigner &operator=(const MemoryAssigner &) = delete; | ||||
| Status AssignMemory(bool is_loop_graph, size_t &mem_offset, size_t &zero_copy_mem_size); | |||||
| Status AssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_offset, size_t &zero_copy_mem_size); | |||||
| private: | private: | ||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| @@ -21,7 +21,6 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/types.h" | |||||
| #include "framework/omg/omg_inner_types.h" | #include "framework/omg/omg_inner_types.h" | ||||
| #include "framework/omg/parser/parser_inner_ctx.h" | #include "framework/omg/parser/parser_inner_ctx.h" | ||||
| #include "proto/ge_ir.pb.h" | #include "proto/ge_ir.pb.h" | ||||
| @@ -92,8 +91,6 @@ void GetGroupName(ge::proto::ModelDef &model); | |||||
| void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path); | void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path); | ||||
| Status CheckCustomAiCpuOpLib(); | |||||
| Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | ||||
| Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); | Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); | ||||
| @@ -25,7 +25,6 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include <vector> | #include <vector> | ||||
| #include "framework/common/fmk_error_codes.h" | #include "framework/common/fmk_error_codes.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
| using domi::DOMI_TENSOR_ND; | using domi::DOMI_TENSOR_ND; | ||||
| @@ -92,6 +91,8 @@ struct OmgContext { | |||||
| std::map<std::string, std::vector<int32_t>> out_nodes_map; | std::map<std::string, std::vector<int32_t>> out_nodes_map; | ||||
| // user-designate out nodes (this is used for determing the orders) | // user-designate out nodes (this is used for determing the orders) | ||||
| std::vector<std::pair<std::string, int32_t>> user_out_nodes; | std::vector<std::pair<std::string, int32_t>> user_out_nodes; | ||||
| // default out nodes (this is used for determing the orders) | |||||
| std::vector<std::pair<std::string, int32_t>> default_out_nodes; | |||||
| // save the output node of the network, value = topName, | // save the output node of the network, value = topName, | ||||
| // topName indicates the output name of the operator. | // topName indicates the output name of the operator. | ||||
| std::vector<std::string> user_out_nodes_top_vec; | std::vector<std::string> user_out_nodes_top_vec; | ||||
| @@ -99,8 +100,6 @@ struct OmgContext { | |||||
| std::vector<std::string> net_out_nodes; | std::vector<std::string> net_out_nodes; | ||||
| // net out nodes top names(only caffe has top) | // net out nodes top names(only caffe has top) | ||||
| std::vector<std::string> out_top_names; | std::vector<std::string> out_top_names; | ||||
| // path for the aicpu custom operator so_file | |||||
| std::vector<std::string> aicpu_op_run_paths; | |||||
| // preferential format used by the entire network | // preferential format used by the entire network | ||||
| domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | domiTensorFormat_t net_format = DOMI_TENSOR_RESERVED; | ||||
| domi::FrameworkType type = domi::FRAMEWORK_RESERVED; | domi::FrameworkType type = domi::FRAMEWORK_RESERVED; | ||||
| @@ -57,11 +57,11 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Buffer { | |||||
| // For compatibility | // For compatibility | ||||
| inline const std::uint8_t *data() const { return GetData(); } | inline const std::uint8_t *data() const { return GetData(); } | ||||
| inline std::uint8_t *data() { return GetData(); } // lint !e659 | |||||
| inline std::uint8_t *data() { return GetData(); } | |||||
| inline std::size_t size() const { return GetSize(); } | inline std::size_t size() const { return GetSize(); } | ||||
| inline void clear() { return ClearBuffer(); } | inline void clear() { return ClearBuffer(); } | ||||
| uint8_t operator[](size_t index) const { // lint !e1022 !e1042 | |||||
| if (buffer_ != nullptr && index < buffer_->size()) { // lint !e574 | |||||
| uint8_t operator[](size_t index) const { | |||||
| if (buffer_ != nullptr && index < buffer_->size()) { | |||||
| return (uint8_t)(*buffer_)[index]; | return (uint8_t)(*buffer_)[index]; | ||||
| } | } | ||||
| return 0xff; | return 0xff; | ||||
| @@ -84,7 +84,6 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| NodePtr FindNode(const std::string &name) const; | NodePtr FindNode(const std::string &name) const; | ||||
| NodePtr FindFirstNodeMatchType(const std::string &name) const; | NodePtr FindFirstNodeMatchType(const std::string &name) const; | ||||
| /*lint -e504*/ | |||||
| // AddNode with NodePtr | // AddNode with NodePtr | ||||
| NodePtr AddNode(NodePtr node); | NodePtr AddNode(NodePtr node); | ||||
| NodePtr AddNode(OpDescPtr op); | NodePtr AddNode(OpDescPtr op); | ||||
| @@ -152,7 +151,6 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
| graphStatus InsertEventNodes(); | graphStatus InsertEventNodes(); | ||||
| bool operator==(const ComputeGraph &r_compute_graph) const; | bool operator==(const ComputeGraph &r_compute_graph) const; | ||||
| /*lint +e504*/ | |||||
| const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const { | const std::map<std::vector<std::string>, std::vector<std::string>> &GetShareParamLayer() const { | ||||
| return params_share_map_; | return params_share_map_; | ||||
| } | } | ||||
| @@ -14,7 +14,6 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| /*lint -e618*/ | |||||
| #ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #ifndef INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | ||||
| #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #define INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | ||||
| @@ -33,6 +32,8 @@ namespace ge { | |||||
| #define GE_FUNC_DEV_VISIBILITY | #define GE_FUNC_DEV_VISIBILITY | ||||
| #endif | #endif | ||||
| // Public attribute | // Public attribute | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IS_UNKNOWN_SHAPE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED; | ||||
| @@ -1021,8 +1022,6 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_KEY; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_VIRTUAL_OP; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FUSION_GROUP_TYPE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION; | ||||
| @@ -1044,6 +1043,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_NAME; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TBE_KERNEL_BUFFER; | ||||
| // used for memory allocate | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WORKSPACE_TYPE_LIST; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_TENSOR_MEM_TYPE; | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_P2P_MEMORY_SIZE; | |||||
| // for unregistered op | // for unregistered op | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_OPPATH; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_UNREGST_ATTRLIST; | ||||
| @@ -1121,10 +1127,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_VAR | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INPUT_MEMORY_TYPE; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_OUTPUT_MEMORY_TYPE; | ||||
| // stage | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_STAGE_LEVEL; | |||||
| // input_output_offset | // input_output_offset | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_BASIC_OFFSET; | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | ||||
| /*lint +e618*/ | |||||
| @@ -38,7 +38,7 @@ class TypeID { | |||||
| bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } | bool operator==(const TypeID &__arg) const { return type_ == __arg.type_; } | ||||
| private: | private: | ||||
| explicit TypeID(string type) : type_(std::move(type)) {} // lint !e30 !e32 | |||||
| explicit TypeID(string type) : type_(std::move(type)) {} | |||||
| string type_; | string type_; | ||||
| }; | }; | ||||
| @@ -50,7 +50,7 @@ class OpDef; | |||||
| class GraphDef; | class GraphDef; | ||||
| } // namespace proto | } // namespace proto | ||||
| using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; // lint !e1073 | |||||
| using ProtoAttrMap = ::google::protobuf::Map<::std::string, ::ge::proto::AttrDef>; | |||||
| using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; | using ProtoMsgOwner = std::shared_ptr<::google::protobuf::Message>; | ||||
| template <class ProtoType> | template <class ProtoType> | ||||
| @@ -147,7 +147,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY AttrHolder { | |||||
| protected: | protected: | ||||
| graphStatus AddRequiredAttr(const std::string &name); | graphStatus AddRequiredAttr(const std::string &name); | ||||
| const std::unordered_set<string> GetAllAttrNames() const; | const std::unordered_set<string> GetAllAttrNames() const; | ||||
| const std::map<string, GeAttrValue> GetAllAttrs() const; // lint !e1073 | |||||
| const std::map<string, GeAttrValue> GetAllAttrs() const; | |||||
| virtual ProtoAttrMapHelper MutableAttrMap() = 0; | virtual ProtoAttrMapHelper MutableAttrMap() = 0; | ||||
| virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; | virtual ConstProtoAttrMapHelper GetAttrMap() const = 0; | ||||
| @@ -310,7 +310,7 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
| VALUE_SET_GET_DEC(GeAttrValue::GRAPH) | VALUE_SET_GET_DEC(GeAttrValue::GRAPH) | ||||
| VALUE_SET_GET_DEC(BYTES) | VALUE_SET_GET_DEC(BYTES) | ||||
| VALUE_SET_GET_DEC(NamedAttrs) | VALUE_SET_GET_DEC(NamedAttrs) | ||||
| VALUE_SET_GET_DEC(ge::DataType) // lint !e665 | |||||
| VALUE_SET_GET_DEC(ge::DataType) | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::STR>) | VALUE_SET_GET_DEC(vector<GeAttrValue::STR>) | ||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::INT>) | VALUE_SET_GET_DEC(vector<GeAttrValue::INT>) | ||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>) | VALUE_SET_GET_DEC(vector<GeAttrValue::FLOAT>) | ||||
| @@ -320,8 +320,8 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY GeAttrValue { | |||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>) | VALUE_SET_GET_DEC(vector<GeAttrValue::GRAPH>) | ||||
| VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>) | VALUE_SET_GET_DEC(vector<GeAttrValue::BYTES>) | ||||
| VALUE_SET_GET_DEC(vector<NamedAttrs>) | VALUE_SET_GET_DEC(vector<NamedAttrs>) | ||||
| VALUE_SET_GET_DEC(vector<vector<int64_t>>) // lint !e665 | |||||
| VALUE_SET_GET_DEC(vector<ge::DataType>) // lint !e665 | |||||
| VALUE_SET_GET_DEC(vector<vector<int64_t>>) | |||||
| VALUE_SET_GET_DEC(vector<ge::DataType>) | |||||
| #undef VALUE_SET_GET_DEC | #undef VALUE_SET_GET_DEC | ||||
| GeIrProtoHelper<proto::AttrDef> value_; | GeIrProtoHelper<proto::AttrDef> value_; | ||||
| @@ -33,7 +33,7 @@ class GEContext { | |||||
| void SetCtxDeviceId(uint32_t device_id); | void SetCtxDeviceId(uint32_t device_id); | ||||
| private: | private: | ||||
| uint64_t session_id_ = 0; | |||||
| thread_local static uint64_t session_id_; | |||||
| uint32_t device_id_ = 0; | uint32_t device_id_ = 0; | ||||
| uint64_t trace_id_ = 0; | uint64_t trace_id_ = 0; | ||||
| }; // class GEContext | }; // class GEContext | ||||
| @@ -33,6 +33,11 @@ class GEThreadLocalContext { | |||||
| void SetSessionOption(map<std::string, string> options_map); | void SetSessionOption(map<std::string, string> options_map); | ||||
| void SetGlobalOption(map<std::string, string> options_map); | void SetGlobalOption(map<std::string, string> options_map); | ||||
| map<string, string> GetAllGraphOptions() const; | |||||
| map<string, string> GetAllSessionOptions() const; | |||||
| map<string, string> GetAllGlobalOptions() const; | |||||
| map<string, string> GetAllOptions() const; | |||||
| private: | private: | ||||
| map<string, string> graph_options_; | map<string, string> graph_options_; | ||||
| map<string, string> session_options_; | map<string, string> session_options_; | ||||
| @@ -193,7 +193,7 @@ class Node : public std::enable_shared_from_this<Node> { | |||||
| vector<OutDataAnchorPtr> out_data_anchors_; | vector<OutDataAnchorPtr> out_data_anchors_; | ||||
| InControlAnchorPtr in_control_anchor_; | InControlAnchorPtr in_control_anchor_; | ||||
| OutControlAnchorPtr out_control_anchor_; | OutControlAnchorPtr out_control_anchor_; | ||||
| map<string, GeAttrValue> attrs_; // lint !e1073 | |||||
| map<string, GeAttrValue> attrs_; | |||||
| bool has_init_{false}; | bool has_init_{false}; | ||||
| bool host_node_{false}; | bool host_node_{false}; | ||||
| bool anchor_status_updated_{false}; | bool anchor_status_updated_{false}; | ||||
| @@ -22,10 +22,8 @@ | |||||
| template <class E, class O> | template <class E, class O> | ||||
| class RangeVistor { | class RangeVistor { | ||||
| public: | public: | ||||
| /*lint -e151*/ | |||||
| using Iterator = typename std::vector<E>::iterator; | using Iterator = typename std::vector<E>::iterator; | ||||
| using ConstIterator = typename std::vector<E>::const_iterator; | using ConstIterator = typename std::vector<E>::const_iterator; | ||||
| /*lint +e151*/ | |||||
| RangeVistor(O owner, const std::vector<E> &vs) : owner_(owner), elements_(vs) {} | RangeVistor(O owner, const std::vector<E> &vs) : owner_(owner), elements_(vs) {} | ||||
| @@ -43,9 +41,7 @@ class RangeVistor { | |||||
| bool empty() const { return elements_.empty(); } | bool empty() const { return elements_.empty(); } | ||||
| /*lint -e659*/ | |||||
| E &at(std::size_t index) { return elements_.at(index); } | E &at(std::size_t index) { return elements_.at(index); } | ||||
| /*lint +e659*/ | |||||
| const E &at(std::size_t index) const { return elements_.at(index); } | const E &at(std::size_t index) const { return elements_.at(index); } | ||||
| @@ -19,18 +19,18 @@ | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <list> | |||||
| #include <map> | #include <map> | ||||
| #include <string> | #include <string> | ||||
| #include <vector> | |||||
| #include <list> | |||||
| #include <unordered_map> | #include <unordered_map> | ||||
| #include <vector> | |||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/node.h" | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/utils/anchor_utils.h" | |||||
| #include "graph/graph.h" | #include "graph/graph.h" | ||||
| #include "graph/model.h" | #include "graph/model.h" | ||||
| #include "graph/node.h" | |||||
| #include "graph/utils/anchor_utils.h" | |||||
| #define GE_DUMP(compute_graph, name) \ | #define GE_DUMP(compute_graph, name) \ | ||||
| do { \ | do { \ | ||||
| @@ -206,6 +206,8 @@ class GraphUtils { | |||||
| static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, | static void DumpGEGraph(const ge::ComputeGraphPtr &graph, const std::string &suffix, bool is_always_dump = false, | ||||
| const std::string &user_graph_name = ""); | const std::string &user_graph_name = ""); | ||||
| static void DumpGEGrph(const ge::ComputeGraphPtr &graph, const std::string &path, const std::string &suffix); | |||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | static bool LoadGEGraph(const char *file, ge::ComputeGraph &compute_graph); | ||||
| static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); | static bool LoadGEGraph(const char *file, ge::ComputeGraphPtr &compute_graph); | ||||
| @@ -214,6 +216,8 @@ class GraphUtils { | |||||
| static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | static void DumpGEGraphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &suffix); | ||||
| static void DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, const std::string &path, const std::string &suffix); | |||||
| static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); | static bool LoadGEGraphFromOnnx(const char *file, ge::ComputeGraph &compute_graph); | ||||
| static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); | static bool ReadProtoFromTextFile(const char *file, google::protobuf::Message *message); | ||||
| @@ -559,7 +563,8 @@ class ComputeGraphBuilder { | |||||
| class CompleteGraphBuilder : public ComputeGraphBuilder { | class CompleteGraphBuilder : public ComputeGraphBuilder { | ||||
| public: | public: | ||||
| explicit CompleteGraphBuilder(std::string name) : name_(std::move(name)), parent_node_(nullptr) {} | |||||
| explicit CompleteGraphBuilder(std::string name, bool retval_flag = true) | |||||
| : name_(std::move(name)), parent_node_(nullptr), retval_flag_(retval_flag) {} | |||||
| CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | CompleteGraphBuilder(const CompleteGraphBuilder &) = delete; | ||||
| CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | CompleteGraphBuilder &operator=(const CompleteGraphBuilder &) = delete; | ||||
| CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | CompleteGraphBuilder(const CompleteGraphBuilder &&) = delete; | ||||
| @@ -687,8 +692,37 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
| /// | /// | ||||
| void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); | void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); | ||||
| /// | |||||
| /// @brief Add NetOutput node | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void AddNetOutputNode(graphStatus &error_code, std::string &error_msg); | |||||
| /// | |||||
| /// @brief Build NetOutput nodes with data & ctrl edges | |||||
| /// @param [in] net_output_desc | |||||
| /// @param [in] peer_out_anchors | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc, | |||||
| const std::vector<OutDataAnchorPtr> &peer_out_anchors, graphStatus &error_code, | |||||
| std::string &error_msg); | |||||
| /// | |||||
| /// @brief process after build | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void PostProcess(graphStatus &error_code, std::string &error_msg); | |||||
| std::string name_; | std::string name_; | ||||
| NodePtr parent_node_; | NodePtr parent_node_; | ||||
| bool retval_flag_; | |||||
| std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | ||||
| std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | ||||
| std::vector<std::string> graph_targets_; | std::vector<std::string> graph_targets_; | ||||
| @@ -0,0 +1,32 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef INC_GRAPH_UTILS_NODE_ADAPTER_H_ | |||||
| #define INC_GRAPH_UTILS_NODE_ADAPTER_H_ | |||||
| #include "graph/gnode.h" | |||||
| #include "graph/node.h" | |||||
| namespace ge { | |||||
| using NodePtr = std::shared_ptr<Node>; | |||||
| class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodeAdapter { | |||||
| public: | |||||
| static GNode Node2GNode(const NodePtr &node); | |||||
| static NodePtr GNode2Node(const GNode &node); | |||||
| static GNodePtr Node2GNodePtr(const NodePtr &node); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // INC_GRAPH_UTILS_NODE_ADAPTER_H_ | |||||
| @@ -83,6 +83,7 @@ class NodeUtils { | |||||
| static std::string GetNodeType(const Node &node); | static std::string GetNodeType(const Node &node); | ||||
| static std::string GetNodeType(const NodePtr &node); | static std::string GetNodeType(const NodePtr &node); | ||||
| static std::vector<ComputeGraphPtr> GetAllSubgraphs(const Node &node); | |||||
| static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | static ComputeGraphPtr GetSubgraph(const Node &node, uint32_t index); | ||||
| static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); | static graphStatus SetSubgraph(Node &node, uint32_t index, const ComputeGraphPtr &subgraph); | ||||
| @@ -162,6 +163,13 @@ class NodeUtils { | |||||
| static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); | static graphStatus GetInputConstData(const Node &node, const string &dst_name, GeTensorPtr &ge_tensor); | ||||
| /// | |||||
| /// @brief Get node type in cross subgragh. | |||||
| /// @param [in] node | |||||
| /// @return type | |||||
| /// | |||||
| static std::string GetInConstNodeTypeCrossSubgraph(const ge::NodePtr &node); | |||||
| private: | private: | ||||
| static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | ||||
| static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | ||||
| @@ -14,20 +14,20 @@ | |||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef AICORE_PARAM_CALCULATOR | |||||
| #define AICORE_PARAM_CALCULATOR | |||||
| #include "external/graph/ascend_string.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph_optimizer/graph_optimize_register_error_codes.h" | |||||
| namespace ge { | |||||
| AscendString::AscendString(const char* name) { | |||||
| if (name != nullptr) { | |||||
| name_ = std::shared_ptr<std::string>(new (std::nothrow) std::string(name)); | |||||
| } | |||||
| } | |||||
| namespace fe { | |||||
| class AICoreParamCalculator { | |||||
| public: | |||||
| AICoreParamCalculator(); | |||||
| const char* AscendString::GetString() const { | |||||
| if (name_ == nullptr) { | |||||
| return nullptr; | |||||
| } | |||||
| ~AICoreParamCalculator(); | |||||
| Status CalcOpRunningParam(ge::Node &node); | |||||
| }; | |||||
| } // namespace fe | |||||
| #endif // AICORE_PARAM_CALCULATOR | |||||
| return (*name_).c_str(); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -41,6 +41,7 @@ using namespace ge; | |||||
| using namespace std; | using namespace std; | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t kDimSize4d = 4; | |||||
| const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | const std::unordered_set<string> kChangeDimNodes = {PERMUTE, EXPANDDIMS, SQUEEZE}; | ||||
| const string kIsGraphInferred = "_is_graph_inferred"; | const string kIsGraphInferred = "_is_graph_inferred"; | ||||
| thread_local RefRelations reflection_builder; | thread_local RefRelations reflection_builder; | ||||
| @@ -410,28 +411,26 @@ graphStatus FormatRefiner::DataNodeFormatProcess(const ComputeGraphPtr &graph, s | |||||
| GE_CHECK_NOTNULL(data_node); | GE_CHECK_NOTNULL(data_node); | ||||
| auto op_desc = data_node->GetOpDesc(); | auto op_desc = data_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(0)); | |||||
| auto curr_format = op_desc->GetOutputDescPtr(0)->GetOriginFormat(); | |||||
| auto input_desc = op_desc->MutableInputDesc(0); | |||||
| auto output_desc = op_desc->MutableOutputDesc(0); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| auto curr_format = output_desc->GetOriginFormat(); | |||||
| if (curr_format != FORMAT_ND) { | if (curr_format != FORMAT_ND) { | ||||
| // Data format has been infered , continue | // Data format has been infered , continue | ||||
| continue; | continue; | ||||
| } | } | ||||
| // Set format for un-infered data node | |||||
| auto input_descs = op_desc->GetAllInputsDescPtr(); | |||||
| auto output_descs = op_desc->GetAllOutputsDescPtr(); | |||||
| for (const auto &input_desc : input_descs) { | |||||
| if (input_desc != nullptr) { | |||||
| input_desc->SetOriginFormat(data_format); | |||||
| input_desc->SetFormat(data_format); | |||||
| } | |||||
| } | |||||
| for (const auto &output_desc : output_descs) { | |||||
| if (output_desc != nullptr) { | |||||
| output_desc->SetOriginFormat(data_format); | |||||
| output_desc->SetFormat(data_format); | |||||
| } | |||||
| // keep data format be ND because lacking of defination when input shape num is smaller than 4 | |||||
| if (input_desc->MutableShape().GetDimNum() < kDimSize4d) { | |||||
| continue; | |||||
| } | } | ||||
| // Set format for un-infered data node | |||||
| input_desc->SetOriginFormat(data_format); | |||||
| input_desc->SetFormat(data_format); | |||||
| output_desc->SetOriginFormat(data_format); | |||||
| output_desc->SetFormat(data_format); | |||||
| uninfered_data_nodes.push_back(data_node); | uninfered_data_nodes.push_back(data_node); | ||||
| } | } | ||||
| // Reinfer format from uninfered data nodes | // Reinfer format from uninfered data nodes | ||||
| @@ -18,6 +18,8 @@ | |||||
| namespace ge { | namespace ge { | ||||
| // Public attribute | // Public attribute | ||||
| const std::string ATTR_NAME_FORCE_UNKNOWN_SHAPE = "_force_unknown_shape"; | |||||
| const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; | const std::string ATTR_NAME_IS_UNKNOWN_SHAPE = "_is_unknown_shape"; | ||||
| const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; | const std::string ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED = "_dynamic_shape_partitioned"; | ||||
| @@ -718,6 +720,8 @@ const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | |||||
| const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | ||||
| const std::string ATTR_MODEL_P2P_MEMORY_SIZE = "p2p_memory_size"; | |||||
| const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; | const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; | ||||
| const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | ||||
| @@ -957,8 +961,6 @@ const std::string ATTR_NAME_FUSION_GROUP_KEY = "_fusion_group_key"; | |||||
| const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | const std::string ATTR_NAME_L1_FUSION_GROUP_KEY = "_l1_fusion_group_key"; | ||||
| const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | const std::string ATTR_NAME_FUSION_VIRTUAL_OP = "_fusion_virtual_op"; | ||||
| const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | const std::string ATTR_NAME_FUSION_GROUP_TYPE = "_fusion_group_type"; | ||||
| const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; | |||||
| const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; | |||||
| const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; | const std::string ATTR_NAME_L1_FUSION_EXTEND_PTR = "_l1_fusion_extend_content"; | ||||
| const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; | const std::string ATTR_NAME_GET_TENSOR_ACTUAL_SIZE = "_tensor_actual_size"; | ||||
| const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; | const std::string ATTR_NAME_OUTPUT_OFFSET_FOR_L1_FUSION = "_output_offset_for_l1_fuison"; | ||||
| @@ -980,6 +982,12 @@ const std::string ATTR_NAME_OP_COMPILE_STRATEGY = "_op_compile_strategy"; | |||||
| const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; | const std::string ATTR_NAME_TBE_KERNEL_NAME = "_tbe_kernel_name"; | ||||
| const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; | const std::string ATTR_NAME_TBE_KERNEL_BUFFER = "_tbe_kernel_buffer"; | ||||
| // used for memory allocate | |||||
| const std::string ATTR_NAME_INPUT_MEM_TYPE_LIST = "_input_memory_type"; | |||||
| const std::string ATTR_NAME_OUTPUT_MEM_TYPE_LIST = "_output_memory_type"; | |||||
| const std::string ATTR_NAME_WORKSPACE_TYPE_LIST = "_workspace_type"; | |||||
| const std::string ATTR_NAME_TENSOR_MEM_TYPE = "_tensor_memory_type"; | |||||
| // Op debug attrs | // Op debug attrs | ||||
| const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | const std::string ATTR_OP_DEBUG_FLAG = "_op_debug_flag"; | ||||
| const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; | const std::string ATTR_OP_DEBUG_MODE = "_op_debug_mode"; | ||||
| @@ -1080,6 +1088,9 @@ const std::string ATTR_VARIABLE_PLACEMENT = "_variable_placement"; | |||||
| const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; | const std::string ATTR_INPUT_MEMORY_TYPE = "_input_memory_type"; | ||||
| const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; | const std::string ATTR_OUTPUT_MEMORY_TYPE = "_output_memory_type"; | ||||
| // stage | |||||
| const std::string ATTR_STAGE_LEVEL = "_stage_level"; | |||||
| // input_output_offset | // input_output_offset | ||||
| const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; | const std::string ATTR_ZERO_COPY_BASIC_OFFSET = "_zero_copy_basic_offset"; | ||||
| const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; | const std::string ATTR_ZERO_COPY_RELATIVE_OFFSET = "_zero_copy_relative_offset"; | ||||
| @@ -33,8 +33,7 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | NamedAttrs::NamedAttrs() { named_attrs_.InitDefault(); } | ||||
| NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) | |||||
| : named_attrs_(owner, proto_msg) {} // lint !e1744 | |||||
| NamedAttrs::NamedAttrs(const ProtoMsgOwner &owner, proto::NamedAttrs *proto_msg) : named_attrs_(owner, proto_msg) {} | |||||
| void NamedAttrs::SetName(const std::string &name) { | void NamedAttrs::SetName(const std::string &name) { | ||||
| auto proto_msg = named_attrs_.GetProtoMsg(); | auto proto_msg = named_attrs_.GetProtoMsg(); | ||||
| @@ -239,7 +238,7 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::STR) | |||||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::STR>) | ||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) | ATTR_VALUE_SET_GET_IMP(GeAttrValue::INT) | ||||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::INT>) | ||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) // lint !e524 | |||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::FLOAT) | |||||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::FLOAT>) | ||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) | ATTR_VALUE_SET_GET_IMP(GeAttrValue::BOOL) | ||||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BOOL>) | ||||
| @@ -253,11 +252,9 @@ ATTR_VALUE_SET_GET_IMP(GeAttrValue::BYTES) | |||||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::BYTES>) | ||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) | ATTR_VALUE_SET_GET_IMP(GeAttrValue::NAMED_ATTRS) | ||||
| ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>) | ATTR_VALUE_SET_GET_IMP(vector<GeAttrValue::NAMED_ATTRS>) | ||||
| /*lint -e665*/ | |||||
| ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>) | ATTR_VALUE_SET_GET_IMP(vector<vector<int64_t>>) | ||||
| /*lint +e665*/ | |||||
| ATTR_VALUE_SET_GET_IMP(vector<DataType>) // lint !e665 | |||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) // lint !e665 | |||||
| ATTR_VALUE_SET_GET_IMP(vector<DataType>) | |||||
| ATTR_VALUE_SET_GET_IMP(GeAttrValue::DATA_TYPE) | |||||
| #undef ATTR_VALUE_SET_GET_IMP | #undef ATTR_VALUE_SET_GET_IMP | ||||
| @@ -785,14 +782,14 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
| if (graph_def == nullptr) { | if (graph_def == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | ||||
| graph_def = nullptr; | graph_def = nullptr; | ||||
| return false; // lint !e665 | |||||
| return false; | |||||
| } else { | } else { | ||||
| ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
| imp.SetProtobufOwner(graph_def); | imp.SetProtobufOwner(graph_def); | ||||
| if (!imp.UnserializeGraph(graph, *graph_def)) { | if (!imp.UnserializeGraph(graph, *graph_def)) { | ||||
| GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); | GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); | ||||
| return false; | return false; | ||||
| } // lint !e514 | |||||
| } | |||||
| value = graph; | value = graph; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -812,7 +809,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
| if (graph_def == nullptr) { | if (graph_def == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | GELOGE(GRAPH_FAILED, "proto::GraphDef make shared failed"); | ||||
| graph_def = nullptr; | graph_def = nullptr; | ||||
| return false; // lint !e665 | |||||
| return false; | |||||
| } else { | } else { | ||||
| ComputeGraphPtr graph = nullptr; | ComputeGraphPtr graph = nullptr; | ||||
| ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
| @@ -820,7 +817,7 @@ bool GeAttrValueImp::GetValue(const proto::AttrDef &proto_attr_val, const ProtoM | |||||
| if (!imp.UnserializeGraph(graph, *graph_def)) { | if (!imp.UnserializeGraph(graph, *graph_def)) { | ||||
| GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); | GELOGE(GRAPH_FAILED, "UnserializeGraph Failed"); | ||||
| return false; | return false; | ||||
| } // lint !e514 | |||||
| } | |||||
| value.push_back(graph); | value.push_back(graph); | ||||
| } | } | ||||
| } | } | ||||
| @@ -972,9 +969,7 @@ ATTR_UTILS_SET_IMP(Tensor, GeTensor) | |||||
| ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | ATTR_UTILS_SET_GET_IMP(NamedAttrs, GeAttrValue::NAMED_ATTRS) | ||||
| ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | ATTR_UTILS_SET_GET_IMP(Bytes, Buffer) | ||||
| ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | ATTR_UTILS_SET_GET_IMP(Graph, ComputeGraphPtr) | ||||
| /*lint -e665*/ | |||||
| ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | ATTR_UTILS_SET_GET_IMP(ListListInt, vector<vector<int64_t>>) | ||||
| /*lint +e665*/ | |||||
| ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>) | ATTR_UTILS_SET_GET_IMP(ListInt, vector<int64_t>) | ||||
| ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>) | ATTR_UTILS_SET_IMP(ListInt, vector<int32_t>) | ||||
| @@ -989,8 +984,8 @@ ATTR_UTILS_SET_IMP(ListTensor, vector<GeTensor>) | |||||
| ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | ATTR_UTILS_SET_GET_IMP(ListNamedAttrs, vector<GeAttrValue::NAMED_ATTRS>) | ||||
| ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | ATTR_UTILS_SET_GET_IMP(ListBytes, vector<Buffer>) | ||||
| ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | ATTR_UTILS_SET_GET_IMP(ListGraph, vector<ComputeGraphPtr>) | ||||
| ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) // lint !e665 | |||||
| ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) // lint !e665 | |||||
| ATTR_UTILS_SET_GET_IMP(ListDataType, vector<ge::DataType>) | |||||
| ATTR_UTILS_SET_GET_IMP(DataType, ge::DataType) | |||||
| bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, | bool AttrUtils::SetListTensor(AttrHolderAdapter &&obj, const string &name, | ||||
| std::initializer_list<ConstGeTensorPtr> &&value) { | std::initializer_list<ConstGeTensorPtr> &&value) { | ||||
| @@ -1159,7 +1154,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool AttrUtils::GetListOpDesc(Con | |||||
| } | } | ||||
| for (const auto &item : bytes_vals) { | for (const auto &item : bytes_vals) { | ||||
| ModelSerialize serialize; | ModelSerialize serialize; | ||||
| auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); // lint !e732 | |||||
| auto op_desc = serialize.UnserializeOpDesc(item.GetData(), item.GetSize()); | |||||
| value.push_back(op_desc); | value.push_back(op_desc); | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -1211,7 +1206,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY OpDescPtr AttrUtils::CloneOpDesc( | |||||
| op_def = ComGraphMakeShared<proto::OpDef>(); | op_def = ComGraphMakeShared<proto::OpDef>(); | ||||
| if (op_def == nullptr) { | if (op_def == nullptr) { | ||||
| GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); | GELOGE(GRAPH_FAILED, "proto::OpDef make shared failed"); | ||||
| return nullptr; // lint !e665 | |||||
| return nullptr; | |||||
| } | } | ||||
| ModelSerializeImp imp; | ModelSerializeImp imp; | ||||
| (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); | (void)imp.SerializeOpDesc(org_op_desc, op_def.get()); | ||||
| @@ -0,0 +1,857 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/gnode.h" | |||||
| #include <utility> | |||||
| #include "debug/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/anchor.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph/utils/node_adapter.h" | |||||
| #include "graph/utils/tensor_adapter.h" | |||||
| #include <graph/utils/graph_utils.h> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "utils/node_utils.h" | |||||
| #include "utils/op_desc_utils.h" | |||||
| namespace ge { | |||||
| class NodeImpl { | |||||
| public: | |||||
| NodeImpl() = default; | |||||
| ~NodeImpl() = default; | |||||
| NodeImpl(NodeImpl &) = delete; | |||||
| NodeImpl &operator=(const NodeImpl &) = delete; | |||||
| std::weak_ptr<Node> node_ptr_; | |||||
| }; | |||||
| NodePtr NodeAdapter::GNode2Node(const ge::GNode &graph_node) { | |||||
| if (graph_node.impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GNode2Node: gnode impl is nullptr."); | |||||
| return nullptr; | |||||
| } | |||||
| return graph_node.impl_->node_ptr_.lock(); | |||||
| } | |||||
| GNode NodeAdapter::Node2GNode(const ge::NodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Node2GNode: node is nullptr"); | |||||
| return GNode(); | |||||
| } | |||||
| GNode graph_node; | |||||
| if (graph_node.impl_ == nullptr) { | |||||
| GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str()); | |||||
| return graph_node; | |||||
| } | |||||
| graph_node.impl_->node_ptr_ = node; | |||||
| return graph_node; | |||||
| } | |||||
| GNodePtr NodeAdapter::Node2GNodePtr(const ge::NodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Node2GNodePtr: node is nullptr"); | |||||
| return nullptr; | |||||
| } | |||||
| GNodePtr gnode = std::shared_ptr<GNode>(new (std::nothrow) GNode()); | |||||
| if (gnode == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Node2GNodePtr: gnode is nullptr, node[%s].", node->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| if (gnode->impl_ == nullptr) { | |||||
| GELOGW("Node2GNode: gnode impl is nullptr, node[%s].", node->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| gnode->impl_->node_ptr_ = node; | |||||
| return gnode; | |||||
| } | |||||
| GNode::GNode() { impl_ = ComGraphMakeShared<NodeImpl>(); } | |||||
| graphStatus GNode::GetType(ge::AscendString &type) const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetType: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetType: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_type = node_ptr->GetType(); | |||||
| AscendString ascend_type(node_type.c_str()); | |||||
| type = ascend_type; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetName(ge::AscendString &name) const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetName: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetName: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = node_ptr->GetName(); | |||||
| AscendString ascend_name(node_name.c_str()); | |||||
| name = ascend_name; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::pair<GNodePtr, int32_t> GNode::GetInDataNodesAndPortIndexs(const int32_t index) const { | |||||
| pair<GNodePtr, int32_t> gnode_idx = {nullptr, 0xFF}; | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr."); | |||||
| return gnode_idx; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid."); | |||||
| return gnode_idx; | |||||
| } | |||||
| auto in_anchor = node_ptr->GetInDataAnchor(index); | |||||
| if (in_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node[%s], the anchor does not exist", index, | |||||
| node_ptr->GetName().c_str()); | |||||
| return gnode_idx; | |||||
| } | |||||
| auto out_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (out_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to get in data node of index[%d] from node [%s], the data input does not exist", index, | |||||
| node_ptr->GetName().c_str()); | |||||
| return gnode_idx; | |||||
| } | |||||
| NodePtr peer_node_ptr = out_anchor->GetOwnerNode(); | |||||
| GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr); | |||||
| if (gnode == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); | |||||
| return gnode_idx; | |||||
| } | |||||
| return {gnode, out_anchor->GetIdx()}; | |||||
| } | |||||
| std::vector<GNodePtr> GNode::GetInControlNodes() const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr."); | |||||
| return {}; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid."); | |||||
| return {}; | |||||
| } | |||||
| std::vector<GNodePtr> gnodes; | |||||
| auto in_control_nodes = node_ptr->GetInControlNodes(); | |||||
| for (auto &in_control_node : in_control_nodes) { | |||||
| GNodePtr gnode = NodeAdapter::Node2GNodePtr(in_control_node); | |||||
| if (gnode == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| gnodes.emplace_back(gnode); | |||||
| } | |||||
| return gnodes; | |||||
| } | |||||
| std::vector<std::pair<GNodePtr, int32_t>> GNode::GetOutDataNodesAndPortIndexs(const int32_t index) const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Gnode: node impl is nullptr."); | |||||
| return {}; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Gnode: the shared ptr is not valid."); | |||||
| return {}; | |||||
| } | |||||
| auto out_anchor = node_ptr->GetOutDataAnchor(index); | |||||
| if (out_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to get out data node of index %d from node %s, the anchor does not exists", index, | |||||
| node_ptr->GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| vector<std::pair<GNodePtr, int32_t>> gnode_index; | |||||
| auto in_data_anchors = out_anchor->GetPeerInDataAnchors(); | |||||
| for (auto &in_data_anchor : in_data_anchors) { | |||||
| if (in_data_anchor == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "In data anchor of node[%s] is nullptr.", node_ptr->GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| NodePtr peer_node_ptr = in_data_anchor->GetOwnerNode(); | |||||
| GNodePtr gnode = NodeAdapter::Node2GNodePtr(peer_node_ptr); | |||||
| if (gnode == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Peer node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| gnode_index.emplace_back(std::pair<GNodePtr, int32_t>(gnode, in_data_anchor->GetIdx())); | |||||
| } | |||||
| return gnode_index; | |||||
| } | |||||
| std::vector<GNodePtr> GNode::GetOutControlNodes() const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutControlNodes: node impl is nullptr."); | |||||
| return {}; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutControlNodes: the node shared ptr is not valid."); | |||||
| return {}; | |||||
| } | |||||
| std::vector<GNodePtr> gnodes; | |||||
| auto out_control_nodes = node_ptr->GetOutControlNodes(); | |||||
| for (auto &out_control_node : out_control_nodes) { | |||||
| GNodePtr gnode = NodeAdapter::Node2GNodePtr(out_control_node); | |||||
| if (gnode == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "In control_node of node[%s] to gnode faild.", node_ptr->GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| gnodes.emplace_back(gnode); | |||||
| } | |||||
| return gnodes; | |||||
| } | |||||
| graphStatus GNode::GetInputConstData(const int32_t index, Tensor &data) const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputConstData: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputConstData: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr input_data_node = NodeUtils::GetInDataNodeByIndex(*node_ptr, index); | |||||
| bool is_const = NodeUtils::IsConst(*input_data_node); | |||||
| if (!is_const) { | |||||
| GELOGE(GRAPH_NODE_WITHOUT_CONST_INPUT, "Node[%s] has no const input.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_NODE_WITHOUT_CONST_INPUT; | |||||
| } | |||||
| Operator const_op = OpDescUtils::CreateOperatorFromNode(input_data_node); | |||||
| if (const_op.GetAttr(ATTR_NAME_WEIGHTS, data) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Input data node[%s] of node[%s] get data failed.", input_data_node->GetName().c_str(), | |||||
| node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetInputIndexByName(const ge::AscendString &name, int32_t &index) { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetInputIndexByName: ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputIndexByName: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputIndexByName: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| index = op_desc->GetInputIndexByName(node_name); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetOutputIndexByName(const ge::AscendString &name, int32_t &index) { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetOutputIndexByName: ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutputIndexByName: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutputIndexByName: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| index = op_desc->GetOutputIndexByName(node_name); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| size_t GNode::GetInputsSize() const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputsSize: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputsSize: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return op_desc->GetInputsSize(); | |||||
| } | |||||
| size_t GNode::GetOutputsSize() const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutputsSize: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutputsSize: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return op_desc->GetOutputsSize(); | |||||
| } | |||||
| graphStatus GNode::GetInputDesc(const int32_t index, TensorDesc &tensor_desc) const { | |||||
| if (index < 0) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetInputDesc: index[%d] cannot be less than zero.", index); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputDesc: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetInputDesc: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetInputDescPtr(static_cast<uint32_t>(index)); | |||||
| if (ge_tensor_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::UpdateInputDesc(const int32_t index, const TensorDesc &tensor_desc) { | |||||
| if (index < 0) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "UpdateInputDesc: index[%d] cannot be less than zero.", index); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "UpdateInputDesc: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "UpdateInputDesc: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); | |||||
| if (op_desc->UpdateInputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetOutputDesc(const int32_t index, TensorDesc &tensor_desc) const { | |||||
| if (index < 0) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetOutputDesc: index[%d] cannot be less than zero.", index); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutputDesc: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetOutputDesc: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ConstGeTensorDescPtr ge_tensor_desc = op_desc->GetOutputDescPtr(static_cast<uint32_t>(index)); | |||||
| if (ge_tensor_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get tensor desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| tensor_desc = TensorAdapter::GeTensorDesc2TensorDesc(*ge_tensor_desc); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::UpdateOutputDesc(const int32_t index, const TensorDesc &tensor_desc) { | |||||
| if (index < 0) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Gnode: index[%d] cannot be less than zero.", index); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "UpdateOutputDesc: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "UpdateOutputDesc: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| GeTensorDesc ge_tensor_desc = TensorAdapter::TensorDesc2GeTensorDesc(tensor_desc); | |||||
| if (op_desc->UpdateOutputDesc(static_cast<uint32_t>(index), ge_tensor_desc) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Update input desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| #define NODE_ATTR_GET_IMP(ArgType) \ | |||||
| graphStatus GNode::GetAttr(const ge::AscendString &name, ArgType &attr_value) const { \ | |||||
| const char *ascend_name = name.GetString(); \ | |||||
| if (ascend_name == nullptr) { \ | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetAttr: ascend string error."); \ | |||||
| return GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| \ | |||||
| if (impl_ == nullptr) { \ | |||||
| GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); \ | |||||
| return GRAPH_FAILED; \ | |||||
| } \ | |||||
| \ | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); \ | |||||
| if (node_ptr == nullptr) { \ | |||||
| GELOGE(GRAPH_FAILED, "GetAttr: the shared ptr is not valid."); \ | |||||
| return GRAPH_FAILED; \ | |||||
| } \ | |||||
| \ | |||||
| std::string node_name = ascend_name; \ | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \ | |||||
| if (op.GetAttr(node_name, attr_value) != GRAPH_SUCCESS) { \ | |||||
| GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); \ | |||||
| return GRAPH_FAILED; \ | |||||
| } \ | |||||
| \ | |||||
| return GRAPH_SUCCESS; \ | |||||
| } | |||||
| #define NODE_ATTR_SET_IMP(ArgType) \ | |||||
| graphStatus GNode::SetAttr(const ge::AscendString &name, ArgType &attr_value) const { \ | |||||
| const char *ascend_name = name.GetString(); \ | |||||
| if (ascend_name == nullptr) { \ | |||||
| GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error."); \ | |||||
| return GRAPH_PARAM_INVALID; \ | |||||
| } \ | |||||
| \ | |||||
| if (impl_ == nullptr) { \ | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); \ | |||||
| return GRAPH_FAILED; \ | |||||
| } \ | |||||
| \ | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); \ | |||||
| if (node_ptr == nullptr) { \ | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); \ | |||||
| return GRAPH_FAILED; \ | |||||
| } \ | |||||
| \ | |||||
| std::string node_name = ascend_name; \ | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); \ | |||||
| (void)op.SetAttr(node_name, attr_value); \ | |||||
| return GRAPH_SUCCESS; \ | |||||
| } | |||||
| NODE_ATTR_GET_IMP(int64_t) | |||||
| NODE_ATTR_GET_IMP(int32_t) | |||||
| NODE_ATTR_GET_IMP(uint32_t) | |||||
| NODE_ATTR_GET_IMP(float) | |||||
| NODE_ATTR_GET_IMP(bool) | |||||
| NODE_ATTR_GET_IMP(Tensor) | |||||
| NODE_ATTR_GET_IMP(std::vector<int64_t>) | |||||
| NODE_ATTR_GET_IMP(std::vector<int32_t>) | |||||
| NODE_ATTR_GET_IMP(std::vector<uint32_t>) | |||||
| NODE_ATTR_GET_IMP(std::vector<float>) | |||||
| NODE_ATTR_GET_IMP(std::vector<bool>) | |||||
| NODE_ATTR_GET_IMP(std::vector<Tensor>) | |||||
| NODE_ATTR_GET_IMP(OpBytes) | |||||
| NODE_ATTR_GET_IMP(std::vector<std::vector<int64_t>>) | |||||
| NODE_ATTR_GET_IMP(std::vector<ge::DataType>) | |||||
| NODE_ATTR_GET_IMP(ge::DataType) | |||||
| NODE_ATTR_GET_IMP(AttrValue) | |||||
| NODE_ATTR_SET_IMP(int64_t) | |||||
| NODE_ATTR_SET_IMP(int32_t) | |||||
| NODE_ATTR_SET_IMP(uint32_t) | |||||
| NODE_ATTR_SET_IMP(float) | |||||
| NODE_ATTR_SET_IMP(bool) | |||||
| NODE_ATTR_SET_IMP(Tensor) | |||||
| NODE_ATTR_SET_IMP(std::vector<int64_t>) | |||||
| NODE_ATTR_SET_IMP(std::vector<int32_t>) | |||||
| NODE_ATTR_SET_IMP(std::vector<uint32_t>) | |||||
| NODE_ATTR_SET_IMP(std::vector<float>) | |||||
| NODE_ATTR_SET_IMP(std::vector<bool>) | |||||
| NODE_ATTR_SET_IMP(std::vector<Tensor>) | |||||
| NODE_ATTR_SET_IMP(OpBytes) | |||||
| NODE_ATTR_SET_IMP(std::vector<std::vector<int64_t>>) | |||||
| NODE_ATTR_SET_IMP(std::vector<ge::DataType>) | |||||
| NODE_ATTR_SET_IMP(ge::DataType) | |||||
| graphStatus GNode::SetAttr(const ge::AscendString &name, AttrValue &attr_value) const { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "SetAttr: ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); | |||||
| (void)op.SetAttr(node_name, std::move(attr_value)); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::SetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| const char *ascend_attr_value = attr_value.GetString(); | |||||
| if (ascend_attr_value == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr value ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| std::string node_attr_value = ascend_attr_value; | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); | |||||
| (void)op.SetAttr(node_name, node_attr_value); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::SetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "SetAttr: name ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| for (auto &attr_val : attr_values) { | |||||
| const char *ascend_attr_value = attr_val.GetString(); | |||||
| if (ascend_attr_value == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "SetAttr: attr val error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "SetAttr: the shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| vector<std::string> node_attr_vals; | |||||
| for (auto attr_val : attr_values) { | |||||
| if (attr_val.GetString() != nullptr) { | |||||
| std::string node_attr_val = attr_val.GetString(); | |||||
| node_attr_vals.emplace_back(node_attr_val); | |||||
| } | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); | |||||
| (void)op.SetAttr(node_name, node_attr_vals); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetAttr(const ge::AscendString &name, ge::AscendString &attr_value) const { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); | |||||
| std::string op_name; | |||||
| if (op.GetAttr(node_name, op_name) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ge::AscendString attr_value_get(op_name.c_str()); | |||||
| attr_value = attr_value_get; | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetAttr(const ge::AscendString &name, std::vector<ge::AscendString> &attr_values) const { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "GetAttr: name ascend string error."); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetAttr: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetAttr: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string node_name = ascend_name; | |||||
| Operator op = OpDescUtils::CreateOperatorFromNode(node_ptr); | |||||
| vector<std::string> attr_names; | |||||
| if (op.GetAttr(node_name, attr_names) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "Get attr of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (auto &attr_name : attr_names) { | |||||
| AscendString ascend_attr_name(attr_name.c_str()); | |||||
| attr_values.push_back(ascend_attr_name); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| bool GNode::HasAttr(const ge::AscendString &name) { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "HasAttr: ascend string error."); | |||||
| return false; | |||||
| } | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "HasAttr: node impl is nullptr."); | |||||
| return false; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "HasAttr: the node shared ptr is not valid."); | |||||
| return false; | |||||
| } | |||||
| OpDescPtr op_desc = node_ptr->GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get op desc of node[%s] failed.", node_ptr->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| std::string attr_name = ascend_name; | |||||
| if (!op_desc->HasAttr(attr_name)) { | |||||
| GELOGE(GRAPH_FAILED, "Node[%s] has no attr name[%s]", node_ptr->GetName().c_str(), attr_name.c_str()); | |||||
| return false; | |||||
| } | |||||
| return true; | |||||
| } | |||||
| graphStatus GNode::GetSubgraph(uint32_t index, GraphPtr graph) const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetSubgraph: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetSubgraph: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ComputeGraphPtr compute_graph_ptr = NodeUtils::GetSubgraph(*node_ptr, index); | |||||
| if (compute_graph_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetSubgraph: get subgraph[%u] failed form node[%s].", index, node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Graph create_graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph_ptr); | |||||
| graph = std::make_shared<Graph>(create_graph); | |||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetSubgraph: graph make shared failed form node[%s].", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus GNode::GetALLSubgraphs(std::vector<GraphPtr> graph_list) const { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetALLSubgraphs: node impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::shared_ptr<Node> node_ptr = impl_->node_ptr_.lock(); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetALLSubgraphs: the node shared ptr is not valid."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::vector<ComputeGraphPtr> sub_graphs = NodeUtils::GetAllSubgraphs(*node_ptr); | |||||
| if (sub_graphs.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "GetALLSubgraphs: get all subgraphs failed form node[%s].", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| for (auto &sub_graph : sub_graphs) { | |||||
| if (sub_graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Get subgraph failed form node[%s].", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| Graph create_graph = GraphUtils::CreateGraphFromComputeGraph(sub_graph); | |||||
| GraphPtr graph = std::make_shared<Graph>(create_graph); | |||||
| if (graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Subgraph make shared failed form node[%s].", node_ptr->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graph_list.emplace_back(graph); | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
| #include <cstring> | |||||
| #include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| @@ -22,6 +23,7 @@ | |||||
| #include "graph/model.h" | #include "graph/model.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/node_adapter.h" | |||||
| using std::map; | using std::map; | ||||
| using std::pair; | using std::pair; | ||||
| @@ -242,6 +244,8 @@ class GraphImpl { | |||||
| const std::string &GetName() const { return name_; } | const std::string &GetName() const { return name_; } | ||||
| ComputeGraphPtr GetComputeGraph() const { return compute_graph_; } | |||||
| private: | private: | ||||
| std::string name_; | std::string name_; | ||||
| std::string output_name_; | std::string output_name_; | ||||
| @@ -261,7 +265,7 @@ graphStatus Graph::AddOp(const ge::Operator &op) { | |||||
| return impl_->AddOp(op); | return impl_->AddOp(op); | ||||
| } | } | ||||
| graphStatus Graph::GetAllOpName(std::vector<string> &op_name) const { | |||||
| graphStatus Graph::GetAllOpName(std::vector<std::string> &op_name) const { | |||||
| GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, | GE_CHK_BOOL_EXEC(impl_ != nullptr, return GRAPH_FAILED, | ||||
| "GetAllOpName failed: graph can not be used, impl is nullptr."); | "GetAllOpName failed: graph can not be used, impl is nullptr."); | ||||
| return impl_->GetAllOpName(op_name); | return impl_->GetAllOpName(op_name); | ||||
| @@ -335,6 +339,235 @@ void Graph::SetNeedIteration(bool need_iteration) { | |||||
| impl_->SetNeedIteration(need_iteration); | impl_->SetNeedIteration(need_iteration); | ||||
| } | } | ||||
| std::vector<GNode> Graph::GetAllNodes() const { | |||||
| std::vector<GNode> graph_nodes; | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetAllNodes: graph can not be used, impl is nullptr."); | |||||
| return graph_nodes; | |||||
| } | |||||
| ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); | |||||
| if (compute_graph_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetAllNodes: compute graph ptr is nullptr."); | |||||
| return graph_nodes; | |||||
| } | |||||
| for (auto &node : compute_graph_ptr->GetAllNodes()) { | |||||
| GNode gnode = NodeAdapter::Node2GNode(node); | |||||
| graph_nodes.emplace_back(gnode); | |||||
| } | |||||
| return graph_nodes; | |||||
| } | |||||
| std::vector<GNode> Graph::GetDirectNode() const { | |||||
| std::vector<GNode> graph_nodes; | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetDirectNode: graph can not be used, impl is nullptr."); | |||||
| return graph_nodes; | |||||
| } | |||||
| ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); | |||||
| if (compute_graph_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "GetDirectNode: compute graph ptr is nullptr."); | |||||
| return graph_nodes; | |||||
| } | |||||
| for (auto &node : compute_graph_ptr->GetDirectNode()) { | |||||
| GNode gnode = NodeAdapter::Node2GNode(node); | |||||
| graph_nodes.emplace_back(gnode); | |||||
| } | |||||
| return graph_nodes; | |||||
| } | |||||
| graphStatus Graph::RemoveNode(GNode &node) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveNode: graph can not be used, impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr node_ptr = NodeAdapter::GNode2Node(node); | |||||
| if (node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveNode: gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); | |||||
| if (compute_graph_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveNde: compute graph ptr is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (compute_graph_ptr->RemoveNode(node_ptr) != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveNde: remove node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Graph::RemoveEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, | |||||
| const int32_t dst_port_index) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge: graph can not be used, impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if ((src_port_index == -1) && (dst_port_index != -1)) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge:src control anchor link to dst data anchor not exists."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); | |||||
| if (src_node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge: src gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); | |||||
| if (dst_node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge: dst gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus res = GRAPH_FAILED; | |||||
| if ((src_port_index == -1) && (dst_port_index == -1)) { | |||||
| res = GraphUtils::RemoveEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); | |||||
| if (res != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge: remove control edge failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| if (src_port_index != -1 && dst_port_index == -1) { | |||||
| res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInControlAnchor()); | |||||
| if (res != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge: remove data-control edge failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| res = GraphUtils::RemoveEdge(src_node_ptr->GetOutDataAnchor(src_port_index), | |||||
| dst_node_ptr->GetInDataAnchor(dst_port_index)); | |||||
| if (res != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "RemoveEdge: remove data edge failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| GNode Graph::AddNodeByOp(const Operator &op) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddNodeByOp: graph can not be used, impl is nullptr."); | |||||
| return GNode(); | |||||
| } | |||||
| std::shared_ptr<ge::OpDesc> op_desc = ge::OpDescUtils::GetOpDescFromOperator(op); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddNodeByOp: get op desc from op[%s] failed.", op.GetName().c_str()); | |||||
| return GNode(); | |||||
| } | |||||
| ComputeGraphPtr compute_graph_ptr = impl_->GetComputeGraph(); | |||||
| if (compute_graph_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddNodeByOp: compute graph ptr is nullptr."); | |||||
| return GNode(); | |||||
| } | |||||
| NodePtr node_ptr = compute_graph_ptr->AddNode(op_desc); | |||||
| GNode gnode = NodeAdapter::Node2GNode(node_ptr); | |||||
| return gnode; | |||||
| } | |||||
| graphStatus Graph::AddDataEdge(GNode &src_node, const int32_t src_port_index, GNode &dst_node, | |||||
| const int32_t dst_port_index) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddDataEdge: graph can not be used, impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); | |||||
| if (src_node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddDataEdge: src gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); | |||||
| if (dst_node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddDataEdge: dst gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus res = | |||||
| GraphUtils::AddEdge(src_node_ptr->GetOutDataAnchor(src_port_index), dst_node_ptr->GetInDataAnchor(dst_port_index)); | |||||
| if (res != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "AddDataEdge: Add data edge failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Graph::AddControlEdge(GNode &src_node, GNode &dst_node) { | |||||
| if (impl_ == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddControlEdge: graph can not be used, impl is nullptr."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr src_node_ptr = NodeAdapter::GNode2Node(src_node); | |||||
| if (src_node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddControlEdge: src gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| NodePtr dst_node_ptr = NodeAdapter::GNode2Node(dst_node); | |||||
| if (dst_node_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "AddControlEdge: dst gnode to node failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| graphStatus res = GraphUtils::AddEdge(src_node_ptr->GetOutControlAnchor(), dst_node_ptr->GetInControlAnchor()); | |||||
| if (res != GRAPH_SUCCESS) { | |||||
| GELOGE(GRAPH_FAILED, "AddControlEdge: Add control edge failed."); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| GraphPtr Graph::ConstructFromInputs(const std::vector<Operator> &inputs, const ge::AscendString &name) { | |||||
| const char *ascend_name = name.GetString(); | |||||
| if (ascend_name == nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "ConstructFromInputs: ascend string error."); | |||||
| return nullptr; | |||||
| } | |||||
| if (inputs.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "ConstructFromInputs: inputs size can not be 0."); | |||||
| return nullptr; | |||||
| } | |||||
| std::string graph_name = ascend_name; | |||||
| ComputeGraphPtr compute_graph = GraphUtils::CreateGraphFromOperator(graph_name, inputs); | |||||
| if (compute_graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "ConstructFromInputs: create compute graph failed."); | |||||
| return nullptr; | |||||
| } | |||||
| compute_graph->SetInputSize(static_cast<uint32_t>(inputs.size())); | |||||
| Graph graph = GraphUtils::CreateGraphFromComputeGraph(compute_graph); | |||||
| GraphPtr graph_ptr = std::make_shared<Graph>(graph); | |||||
| if (graph_ptr == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "ConstructFromInputs: graph make shared failed."); | |||||
| return nullptr; | |||||
| } | |||||
| return graph_ptr; | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraphPtr GraphUtils::GetComputeGraph(const ge::Graph &graph) { | ||||
| GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); | GE_CHK_BOOL_EXEC_NOLOG(graph.IsValid(), return nullptr); | ||||
| return graph.impl_->compute_graph_; | return graph.impl_->compute_graph_; | ||||
| @@ -14,6 +14,8 @@ COMMON_LOCAL_SRC_FILES := \ | |||||
| ./attr_value.cc \ | ./attr_value.cc \ | ||||
| ./buffer.cc \ | ./buffer.cc \ | ||||
| ./compute_graph.cc \ | ./compute_graph.cc \ | ||||
| ./ascend_string.cc \ | |||||
| ./gnode.cc \ | |||||
| ./graph.cc \ | ./graph.cc \ | ||||
| ./inference_context.cc \ | ./inference_context.cc \ | ||||
| ./shape_refiner.cc \ | ./shape_refiner.cc \ | ||||
| @@ -98,11 +100,13 @@ LOCAL_CPPFLAGS += -fexceptions | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | ||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||||
| ../../out/graph/lib64/stub/graph.cc \ | ../../out/graph/lib64/stub/graph.cc \ | ||||
| ../../out/graph/lib64/stub/operator.cc \ | ../../out/graph/lib64/stub/operator.cc \ | ||||
| ../../out/graph/lib64/stub/tensor.cc \ | ../../out/graph/lib64/stub/tensor.cc \ | ||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | ../../out/graph/lib64/stub/operator_factory.cc \ | ||||
| ../../out/graph/lib64/stub/ascend_string.cc \ | |||||
| ../../out/graph/lib64/stub/gnode.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
| @@ -128,7 +132,8 @@ LOCAL_SRC_FILES := \ | |||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | ../../out/graph/lib64/stub/operator_factory.cc \ | ||||
| ../../out/graph/lib64/stub/tensor.cc \ | ../../out/graph/lib64/stub/tensor.cc \ | ||||
| ../../out/graph/lib64/stub/inference_context.cc \ | ../../out/graph/lib64/stub/inference_context.cc \ | ||||
| ../../out/graph/lib64/stub/ascend_string.cc \ | |||||
| ../../out/graph/lib64/stub/gnode.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
| @@ -173,11 +178,13 @@ LOCAL_CFLAGS += -O2 | |||||
| LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | LOCAL_C_INCLUDES := $(COMMON_LOCAL_C_INCLUDES) | ||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| ../../out/graph/lib64/stub/attr_value.cc \ | |||||
| ../../out/graph/lib64/stub/graph.cc \ | ../../out/graph/lib64/stub/graph.cc \ | ||||
| ../../out/graph/lib64/stub/operator.cc \ | ../../out/graph/lib64/stub/operator.cc \ | ||||
| ../../out/graph/lib64/stub/tensor.cc \ | ../../out/graph/lib64/stub/tensor.cc \ | ||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | ../../out/graph/lib64/stub/operator_factory.cc \ | ||||
| ../../out/graph/lib64/stub/ascend_string.cc \ | |||||
| ../../out/graph/lib64/stub/gnode.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
| @@ -206,6 +213,8 @@ LOCAL_SRC_FILES := \ | |||||
| ../../out/graph/lib64/stub/operator_factory.cc \ | ../../out/graph/lib64/stub/operator_factory.cc \ | ||||
| ../../out/graph/lib64/stub/tensor.cc \ | ../../out/graph/lib64/stub/tensor.cc \ | ||||
| ../../out/graph/lib64/stub/inference_context.cc \ | ../../out/graph/lib64/stub/inference_context.cc \ | ||||
| ../../out/graph/lib64/stub/ascend_string.cc \ | |||||
| ../../out/graph/lib64/stub/gnode.cc \ | |||||
| LOCAL_SHARED_LIBRARIES := | LOCAL_SHARED_LIBRARIES := | ||||
| @@ -47,6 +47,7 @@ const int ACCESS_PERMISSION_BITS = 0400; | |||||
| namespace ge { | namespace ge { | ||||
| void Model::Init() { | void Model::Init() { | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_MEMORY_SIZE, 0); | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_P2P_MEMORY_SIZE, 0); | |||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_STREAM_NUM, 0); | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_EVENT_NUM, 0); | ||||
| (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | (void)AttrUtils::SetInt(this, ATTR_MODEL_LABEL_NUM, 0); | ||||
| @@ -409,13 +409,13 @@ bool ModelSerializeImp::HandleNodeNameRef() { | |||||
| item.dst_node_name.c_str(), item.dst_in_index); | item.dst_node_name.c_str(), item.dst_in_index); | ||||
| return false; | return false; | ||||
| } | } | ||||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); | |||||
| } else { | } else { | ||||
| // Control edge | // Control edge | ||||
| auto src_anchor = src_node_it->second->GetOutControlAnchor(); | auto src_anchor = src_node_it->second->GetOutControlAnchor(); | ||||
| auto dst_anchor = item.dst_node->GetInControlAnchor(); | auto dst_anchor = item.dst_node->GetInControlAnchor(); | ||||
| if (src_anchor != nullptr && dst_anchor != nullptr) { | if (src_anchor != nullptr && dst_anchor != nullptr) { | ||||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); // lint !e737 | |||||
| GE_CHK_BOOL_ONLY_LOG((src_anchor->LinkTo(dst_anchor) == GRAPH_SUCCESS), " linkTo failed."); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -33,7 +33,6 @@ using std::shared_ptr; | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| /*lint -save -e521 -e681 -e732 -e737*/ | |||||
| namespace ge { | namespace ge { | ||||
| const std::string ATTR_NAME_ID = "id"; | const std::string ATTR_NAME_ID = "id"; | ||||
| @@ -56,9 +56,6 @@ using std::string; | |||||
| using std::to_string; | using std::to_string; | ||||
| using std::vector; | using std::vector; | ||||
| /*lint -save -e529 -e728*/ | |||||
| /*lint -e446 -e732*/ | |||||
| /*lint -e665*/ | |||||
| namespace ge { | namespace ge { | ||||
| class OpIO { | class OpIO { | ||||
| public: | public: | ||||
| @@ -768,6 +765,8 @@ const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = { | |||||
| {GeAttrValue::VT_BYTES, "VT_BYTES"}, | {GeAttrValue::VT_BYTES, "VT_BYTES"}, | ||||
| {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | ||||
| {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, | {GeAttrValue::VT_NAMED_ATTRS, "VT_NAMED_ATTRS"}, | ||||
| {GeAttrValue::VT_LIST_LIST_INT, "VT_LIST_LIST_INT"}, | |||||
| {GeAttrValue::VT_DATA_TYPE, "VT_DATA_TYPE"}, | |||||
| {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, | {GeAttrValue::VT_LIST_BASE, "VT_LIST_BASE"}, | ||||
| {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, | {GeAttrValue::VT_LIST_STRING, "VT_LIST_STRING"}, | ||||
| {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, | {GeAttrValue::VT_LIST_FLOAT, "VT_LIST_FLOAT"}, | ||||
| @@ -778,6 +777,7 @@ const std::map<GeAttrValue::ValueType, std::string> kAttrTypesMap = { | |||||
| {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, | {GeAttrValue::VT_LIST_BYTES, "VT_LIST_BYTES"}, | ||||
| {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | {GeAttrValue::VT_GRAPH, "VT_GRAPH"}, | ||||
| {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, | {GeAttrValue::VT_LIST_NAMED_ATTRS, "VT_LIST_NAMED_ATTRS"}, | ||||
| {GeAttrValue::VT_LIST_DATA_TYPE, "VT_LIST_DATA_TYPE"}, | |||||
| }; | }; | ||||
| } // namespace | } // namespace | ||||
| const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const { | const std::map<std::string, std::string> Operator::GetAllAttrNamesAndTypes() const { | ||||
| @@ -943,7 +943,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } | |||||
| GELOGW("set attr name %s failed.", name.c_str()); \ | GELOGW("set attr name %s failed.", name.c_str()); \ | ||||
| } \ | } \ | ||||
| return *this; \ | return *this; \ | ||||
| } // lint !e665 | |||||
| } | |||||
| #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ | #define OP_ATTR_GET_IMP(ArgType, AttrUtilsFun) \ | ||||
| graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ | graphStatus Operator::GetAttr(const string &name, ArgType attr_value) const { \ | ||||
| @@ -956,7 +956,7 @@ OperatorImplPtr Operator::GetOperatorImplPtr() const { return operator_impl_; } | |||||
| return GRAPH_FAILED; \ | return GRAPH_FAILED; \ | ||||
| } \ | } \ | ||||
| return GRAPH_SUCCESS; \ | return GRAPH_SUCCESS; \ | ||||
| } // lint !e665 | |||||
| } | |||||
| void Operator::BreakConnect() const { | void Operator::BreakConnect() const { | ||||
| if (operator_impl_ == nullptr) { | if (operator_impl_ == nullptr) { | ||||
| @@ -977,7 +977,7 @@ void Operator::BreakConnect() const { | |||||
| if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ | if (!AttrUtils::Set##AttrUtilsFun(operator_impl_->GetOpDescImpl(), name, attr_value)) { \ | ||||
| GELOGW("reg attr name %s failed.", name.c_str()); \ | GELOGW("reg attr name %s failed.", name.c_str()); \ | ||||
| } \ | } \ | ||||
| } // lint !e665 | |||||
| } | |||||
| OP_ATTR_SET_IMP(int64_t, Int) | OP_ATTR_SET_IMP(int64_t, Int) | ||||
| OP_ATTR_SET_IMP(int32_t, Int) | OP_ATTR_SET_IMP(int32_t, Int) | ||||
| @@ -998,22 +998,22 @@ OP_ATTR_SET_IMP(const vector<vector<int64_t>> &, ListListInt) | |||||
| OP_ATTR_SET_IMP(float, Float) | OP_ATTR_SET_IMP(float, Float) | ||||
| OP_ATTR_GET_IMP(float &, Float) | OP_ATTR_GET_IMP(float &, Float) | ||||
| OP_ATTR_SET_IMP(const vector<float> &, ListFloat) | OP_ATTR_SET_IMP(const vector<float> &, ListFloat) | ||||
| OP_ATTR_GET_IMP(vector<float> &, ListFloat) // lint !e665 | |||||
| OP_ATTR_GET_IMP(vector<float> &, ListFloat) | |||||
| OP_ATTR_SET_IMP(bool, Bool) | OP_ATTR_SET_IMP(bool, Bool) | ||||
| OP_ATTR_GET_IMP(bool &, Bool) | OP_ATTR_GET_IMP(bool &, Bool) | ||||
| OP_ATTR_SET_IMP(const vector<bool> &, ListBool) | OP_ATTR_SET_IMP(const vector<bool> &, ListBool) | ||||
| OP_ATTR_GET_IMP(vector<bool> &, ListBool) // lint !e665 | |||||
| OP_ATTR_GET_IMP(vector<bool> &, ListBool) | |||||
| OP_ATTR_SET_IMP(const string &, Str) | OP_ATTR_SET_IMP(const string &, Str) | ||||
| OP_ATTR_GET_IMP(string &, Str) | OP_ATTR_GET_IMP(string &, Str) | ||||
| OP_ATTR_SET_IMP(const vector<string> &, ListStr) | OP_ATTR_SET_IMP(const vector<string> &, ListStr) | ||||
| OP_ATTR_GET_IMP(vector<string> &, ListStr) // lint !e665 | |||||
| OP_ATTR_GET_IMP(vector<string> &, ListStr) | |||||
| OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | OP_ATTR_SET_IMP(const GeAttrValue::NAMED_ATTRS &, NamedAttrs) | ||||
| OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) | OP_ATTR_GET_IMP(GeAttrValue::NAMED_ATTRS &, NamedAttrs) | ||||
| OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | OP_ATTR_SET_IMP(const vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | ||||
| OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) // lint !e665 | |||||
| OP_ATTR_GET_IMP(vector<GeAttrValue::NAMED_ATTRS> &, ListNamedAttrs) | |||||
| OP_ATTR_REG_IMP(int64_t, Int) | OP_ATTR_REG_IMP(int64_t, Int) | ||||
| OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | OP_ATTR_REG_IMP(const vector<int64_t> &, ListInt) | ||||
| @@ -1583,5 +1583,3 @@ void GraphUtils::BreakConnect(const std::map<OperatorImplPtr, NodePtr> &all_node | |||||
| } | } | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| /*lint +e446 +e732*/ | |||||
| /*lint +e665*/ | |||||
| @@ -38,9 +38,7 @@ bool OpsProtoManager::Initialize(const std::map<std::string, std::string> &optio | |||||
| return true; | return true; | ||||
| } | } | ||||
| /*lint -e1561*/ | |||||
| auto proto_iter = options.find("ge.opsProtoLibPath"); | auto proto_iter = options.find("ge.opsProtoLibPath"); | ||||
| /*lint +e1561*/ | |||||
| if (proto_iter == options.end()) { | if (proto_iter == options.end()) { | ||||
| GELOGW("ge.opsProtoLibPath option not set, return."); | GELOGW("ge.opsProtoLibPath option not set, return."); | ||||
| return false; | return false; | ||||
| @@ -31,6 +31,8 @@ GEContext &GetContext() { | |||||
| return ge_context; | return ge_context; | ||||
| } | } | ||||
| thread_local uint64_t GEContext::session_id_; | |||||
| graphStatus GEContext::GetOption(const std::string &key, std::string &option) { | graphStatus GEContext::GetOption(const std::string &key, std::string &option) { | ||||
| return GetThreadLocalContext().GetOption(key, option); | return GetThreadLocalContext().GetOption(key, option); | ||||
| } | } | ||||
| @@ -57,4 +57,18 @@ void GEThreadLocalContext::SetGraphOption(map<std::string, string> options_map) | |||||
| graph_options_.clear(); | graph_options_.clear(); | ||||
| graph_options_ = std::move(options_map); | graph_options_ = std::move(options_map); | ||||
| } | } | ||||
| map<string, string> GEThreadLocalContext::GetAllGraphOptions() const { return graph_options_; } | |||||
| map<string, string> GEThreadLocalContext::GetAllSessionOptions() const { return session_options_; } | |||||
| map<string, string> GEThreadLocalContext::GetAllGlobalOptions() const { return global_options_; } | |||||
| map<string, string> GEThreadLocalContext::GetAllOptions() const { | |||||
| map<string, string> options_all; | |||||
| options_all.insert(graph_options_.begin(), graph_options_.end()); | |||||
| options_all.insert(session_options_.begin(), session_options_.end()); | |||||
| options_all.insert(global_options_.begin(), global_options_.end()); | |||||
| return options_all; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -365,6 +365,37 @@ string Serial(const vector<int64_t> &dims) { | |||||
| return serial_string; | return serial_string; | ||||
| } | } | ||||
| void SerialShapeRange(const GeTensorDescPtr &desc, std::string &desc_str) { | |||||
| desc_str += "["; | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)desc->GetShapeRange(shape_range); | |||||
| for (const auto &pair : shape_range) { | |||||
| desc_str += "{"; | |||||
| desc_str += std::to_string(pair.first) + "," + std::to_string(pair.second); | |||||
| desc_str += "},"; | |||||
| } | |||||
| desc_str += "] "; | |||||
| } | |||||
| void SerialShapeAndDtype(const GeTensorDescPtr &desc, bool is_origin_info, std::string &desc_str) { | |||||
| desc_str += "["; | |||||
| if (!is_origin_info) { | |||||
| for (int64_t dim : desc->GetShape().GetDims()) { | |||||
| desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| desc_str += "]"; | |||||
| desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(desc->GetFormat()) + " "; | |||||
| } else { | |||||
| for (int64_t dim : desc->GetOriginShape().GetDims()) { | |||||
| desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| desc_str += "]"; | |||||
| desc_str += ":" + TypeUtils::DataTypeToSerialString(desc->GetOriginDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(desc->GetOriginFormat()) + " "; | |||||
| } | |||||
| } | |||||
| graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { | graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { | ||||
| GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node_ptr == nullptr, GELOGE(GRAPH_FAILED, "node is null."); return GRAPH_FAILED); | ||||
| GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, GELOGE(GRAPH_FAILED, "op_desc is null."); return GRAPH_FAILED); | ||||
| @@ -386,9 +417,9 @@ graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { | |||||
| if (in_desc == nullptr) { | if (in_desc == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| auto in_shape = in_desc->GetShape().GetDims(); | |||||
| auto in_shape = in_desc->MutableShape().GetDims(); | |||||
| auto in_dtype = in_desc->GetDataType(); | auto in_dtype = in_desc->GetDataType(); | ||||
| auto peer_out_shape = peer_out_desc->GetShape().GetDims(); | |||||
| auto peer_out_shape = peer_out_desc->MutableShape().GetDims(); | |||||
| auto peer_out_dtype = peer_out_desc->GetDataType(); | auto peer_out_dtype = peer_out_desc->GetDataType(); | ||||
| if (peer_out_dtype != in_dtype) { | if (peer_out_dtype != in_dtype) { | ||||
| GELOGW( | GELOGW( | ||||
| @@ -407,13 +438,15 @@ graphStatus UpdateOpInputDesc(const ConstNodePtr &node_ptr) { | |||||
| } | } | ||||
| // refresh current node input desc | // refresh current node input desc | ||||
| in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); | in_desc->SetOriginShape(peer_out_desc->GetOriginShape()); | ||||
| in_desc->SetShape(peer_out_desc->GetShape()); | |||||
| in_desc->SetShape(peer_out_desc->MutableShape()); | |||||
| in_desc->SetDataType(peer_out_desc->GetDataType()); | in_desc->SetDataType(peer_out_desc->GetDataType()); | ||||
| in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); | in_desc->SetOriginDataType(peer_out_desc->GetOriginDataType()); | ||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)peer_out_desc->GetShapeRange(shape_range); | |||||
| in_desc->SetShapeRange(shape_range); | |||||
| ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->GetShape().GetDims().size())); | |||||
| if (peer_out_desc->MutableShape().GetDims() != UNKNOWN_RANK) { | |||||
| std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
| (void)peer_out_desc->GetShapeRange(shape_range); | |||||
| in_desc->SetShapeRange(shape_range); | |||||
| } | |||||
| ge::TensorUtils::SetRealDimCnt(*in_desc, static_cast<uint32_t>(peer_out_desc->MutableShape().GetDims().size())); | |||||
| } | } | ||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| @@ -432,25 +465,19 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
| if (op_desc->GetInputsSize() != 0) { | if (op_desc->GetInputsSize() != 0) { | ||||
| std::string input_desc_str = "input shape: "; | std::string input_desc_str = "input shape: "; | ||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | ||||
| input_desc_str += "["; | |||||
| for (int64_t dim : input_desc->GetShape().GetDims()) { | |||||
| input_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| input_desc_str += "]"; | |||||
| input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(input_desc->GetFormat()) + " "; | |||||
| SerialShapeAndDtype(input_desc, false, input_desc_str); | |||||
| } | } | ||||
| str += input_desc_str; | str += input_desc_str; | ||||
| input_desc_str = "input origin shape: "; | input_desc_str = "input origin shape: "; | ||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | ||||
| input_desc_str += "["; | |||||
| for (int64_t dim : input_desc->GetOriginShape().GetDims()) { | |||||
| input_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| input_desc_str += "]"; | |||||
| input_desc_str += ":" + TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(input_desc->GetOriginFormat()) + " "; | |||||
| SerialShapeAndDtype(input_desc, true, input_desc_str); | |||||
| } | |||||
| str += input_desc_str; | |||||
| input_desc_str = "input shape range: "; | |||||
| for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | |||||
| SerialShapeRange(input_desc, input_desc_str); | |||||
| } | } | ||||
| str += input_desc_str; | str += input_desc_str; | ||||
| } | } | ||||
| @@ -461,13 +488,7 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
| if (output_desc == nullptr) { | if (output_desc == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| output_desc_str += "["; | |||||
| for (int64_t dim : output_desc->GetShape().GetDims()) { | |||||
| output_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| output_desc_str += "]"; | |||||
| output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(output_desc->GetFormat()) + " "; | |||||
| SerialShapeAndDtype(output_desc, false, output_desc_str); | |||||
| } | } | ||||
| str += output_desc_str; | str += output_desc_str; | ||||
| @@ -476,13 +497,13 @@ void ShapeRefiner::PrintInOutTensorShape(const ge::NodePtr &node, const std::str | |||||
| if (output_desc == nullptr) { | if (output_desc == nullptr) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| output_desc_str += "["; | |||||
| for (int64_t dim : output_desc->GetOriginShape().GetDims()) { | |||||
| output_desc_str += std::to_string(dim) + " "; | |||||
| } | |||||
| output_desc_str += "]"; | |||||
| output_desc_str += ":" + TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()) + ":" + | |||||
| TypeUtils::FormatToSerialString(output_desc->GetOriginFormat()) + " "; | |||||
| SerialShapeAndDtype(output_desc, true, output_desc_str); | |||||
| } | |||||
| str += output_desc_str; | |||||
| output_desc_str = "output shape range: "; | |||||
| for (const auto &output_desc : op_desc->GetAllOutputsDescPtr()) { | |||||
| SerialShapeRange(output_desc, output_desc_str); | |||||
| } | } | ||||
| str += output_desc_str; | str += output_desc_str; | ||||
| } | } | ||||
| @@ -1,6 +0,0 @@ | |||||
| inc_path := $(shell pwd)/metadef/inc/external/ | |||||
| out_path := $(shell pwd)/out/graph/lib64/stub/ | |||||
| stub_path := $(shell pwd)/metadef/graph/stub/ | |||||
| mkdir_stub := $(shell mkdir -p $(out_path)) | |||||
| graph_local_stub := $(shell $(HI_PYTHON) $(stub_path)/gen_stubapi.py $(inc_path) $(out_path)) | |||||
| @@ -1,578 +0,0 @@ | |||||
| import os | |||||
| import re | |||||
| import sys | |||||
| import logging | |||||
| logging.basicConfig(stream=sys.stdout, format='[%(asctime)s] [%(lineno)s] %(levelname)s: %(message)s', | |||||
| level=logging.INFO) | |||||
| """ | |||||
| this attr is used for symbol table visible | |||||
| """ | |||||
| GE_ATTR = 'GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY' | |||||
| """ | |||||
| generate stub func body by return type | |||||
| """ | |||||
| RETURN_STATEMENTS = { | |||||
| 'graphStatus': ' std::cout << "[ERROR]: stub library libgraph or libge_compiler cannot be used for execution, please check your "\n ' | |||||
| ' << "environment variables and compilation options to make sure you use the correct library."\n' | |||||
| ' << std::endl;\n' | |||||
| ' return ACL_ERROR_COMPILING_STUB_MODE;', | |||||
| 'Status': ' return SUCCESS;', | |||||
| 'Graph': ' return Graph();', | |||||
| 'Graph&': ' return *this;', | |||||
| 'Format': ' return Format();', | |||||
| 'Format&': ' return *this;', | |||||
| 'Shape': ' return Shape();', | |||||
| 'Shape&': ' return *this;', | |||||
| 'TensorDesc': ' return TensorDesc();', | |||||
| 'TensorDesc&': ' return *this;', | |||||
| 'Tensor': ' return Tensor();', | |||||
| 'Tensor&': ' return *this;', | |||||
| 'Operator': ' return Operator();', | |||||
| 'Operator&': ' return *this;', | |||||
| 'Ptr': ' return nullptr;', | |||||
| 'std::string': ' return "";', | |||||
| 'std::string&': ' return "";', | |||||
| 'string': ' return "";', | |||||
| 'int': ' return 0;', | |||||
| 'DataType': ' return DT_FLOAT;', | |||||
| 'InferenceContextPtr': ' return nullptr;', | |||||
| 'SubgraphBuilder': ' return nullptr;', | |||||
| 'OperatorImplPtr': ' return nullptr;', | |||||
| 'OutHandler': ' return nullptr;', | |||||
| 'std::vector<std::string>': ' return {};', | |||||
| 'std::vector<int64_t>': ' return {};', | |||||
| 'std::map': ' return {};', | |||||
| 'uint32_t': ' return 0;', | |||||
| 'int64_t': ' return 0;', | |||||
| 'uint64_t': ' return 0;', | |||||
| 'size_t': ' return 0;', | |||||
| 'float': ' return 0.0f;', | |||||
| 'bool': ' return false;', | |||||
| } | |||||
| """ | |||||
| max code len per line in hua_wei software programming specifications | |||||
| """ | |||||
| max_code_len_per_line = 100 | |||||
| """ | |||||
| white_list_for_debug, include_dir_key_words is to | |||||
| determines which header files to generate cc files from | |||||
| when DEBUG on | |||||
| """ | |||||
| white_list_for_debug = ["attr_value.h", "operator.h", "tensor.h", "graph.h", "operator_factory.h", "inference_context.h", | |||||
| "ge_ir_build.h", "ge_api.h", "ascend_string.h", "gnode.h"] | |||||
| include_dir_key_words = ["ge", "graph"] | |||||
| DEBUG = True | |||||
| def need_generate_func(func_line): | |||||
| """ | |||||
| :param func_line: | |||||
| :return: | |||||
| """ | |||||
| if func_line.strip().endswith("default") or func_line.strip().endswith("delete") \ | |||||
| or func_line.strip().startswith("typedef") or func_line.strip().startswith("using"): | |||||
| return False | |||||
| return True | |||||
| def file_endswith_white_list_suffix(file): | |||||
| """ | |||||
| :param file: | |||||
| :return: | |||||
| """ | |||||
| if DEBUG: | |||||
| for suffix in white_list_for_debug: | |||||
| if file.endswith(suffix): | |||||
| return True | |||||
| return False | |||||
| else: | |||||
| return True | |||||
| """ | |||||
| belows are patterns used for analyse .h file | |||||
| """ | |||||
| # pattern function | |||||
| pattern_func = re.compile(r"""(^[\s]*) #leading with space,we will find and delete after | |||||
| ([a-zA-Z~_] # void int likely | |||||
| .* | |||||
| [)] #we find ) | |||||
| (?!.*{) # we do not want the case int abc() const | |||||
| .*) | |||||
| (;.*) #we want to find ; and after for we will replace these later | |||||
| \n$ | |||||
| """, re.VERBOSE | re.MULTILINE | re.DOTALL) | |||||
| # pattern comment | |||||
| pattern_comment = re.compile(r'^\s*//') | |||||
| pattern_comment_2_start = re.compile(r'^\s*/[*]') | |||||
| pattern_comment_2_end = re.compile(r'[*]/\s*$') | |||||
| # pattern define | |||||
| pattern_define = re.compile(r'^\s*#define') | |||||
| pattern_define_return = re.compile(r'\\\s*$') | |||||
| # blank line | |||||
| pattern_blank_line = re.compile(r'^\s*$') | |||||
| # virtual,explicit,friend,static | |||||
| pattern_keyword = re.compile(r'(virtual\s+|explicit\s+|friend\s+|static\s+)') | |||||
| # lead space | |||||
| pattern_leading_space = re.compile(r'(^[\s]*)[a-zA-Z~_]') | |||||
| # functions will have patterns such as func ( or func( | |||||
| # but operator is an exception; the class name is preceded by an operator, and the above mode does not exist | |||||
| # format like :"operator = ()" | |||||
| pattern_func_name = re.compile(r'([a-zA-Z0-9~_\-]+\s*|operator?.*)[(]') | |||||
| # template | |||||
| pattern_template = re.compile(r'^\s*template') | |||||
| pattern_template_end = re.compile(r'>\s*$') | |||||
| # namespace | |||||
| pattern_namespace = re.compile(r'namespace.*{') | |||||
| # class : which can handle classA a and {not on the same line, but if found ';' after class,then don't deal with | |||||
| pattern_class = re.compile(r'^[\s]*(class|struct)\s+(%s\s+)?([a-zA-Z0-9_\-]+<?)(?!.*;)' % GE_ATTR) | |||||
| # {} | |||||
| pattern_start = re.compile('{') | |||||
| pattern_end = re.compile('}') | |||||
| line_index = 0 | |||||
| class H2CC(object): | |||||
| def __init__(self, input_file, output_file, shared_includes_content): | |||||
| """ | |||||
| :param input_file: | |||||
| :param output_file: | |||||
| :param shared_includes_content: | |||||
| """ | |||||
| self.input_file = input_file | |||||
| self.output_file = output_file | |||||
| self.shared_includes_content = shared_includes_content | |||||
| self.line_index = 0 | |||||
| self.input_fd = open(self.input_file, 'r') | |||||
| self.input_content = self.input_fd.readlines() | |||||
| self.output_fd = open(self.output_file, 'w') | |||||
| # The state may be normal_now(in the middle of {}),class_now,namespace_now | |||||
| self.stack = [] | |||||
| self.stack_class = [] | |||||
| self.stack_template = [] | |||||
| # record funcs generated by h2cc func | |||||
| self.func_list_exist = [] | |||||
| def __del__(self): | |||||
| self.input_fd.close() | |||||
| self.output_fd.close() | |||||
| del self.stack | |||||
| del self.stack_class | |||||
| del self.stack_template | |||||
| del self.func_list_exist | |||||
| def just_skip(self): | |||||
| # skip blank line or comment | |||||
| if pattern_blank_line.search(self.input_content[self.line_index]) or pattern_comment.search( | |||||
| self.input_content[self.line_index]): # /n or comment using // | |||||
| self.line_index += 1 | |||||
| if pattern_comment_2_start.search(self.input_content[self.line_index]): # comment using /* | |||||
| while not pattern_comment_2_end.search(self.input_content[self.line_index]): # */ | |||||
| self.line_index += 1 | |||||
| self.line_index += 1 | |||||
| # skip define | |||||
| if pattern_define.search(self.input_content[self.line_index]): | |||||
| while pattern_blank_line.search(self.input_content[self.line_index]) or pattern_define_return.search( | |||||
| self.input_content[self.line_index]): | |||||
| self.line_index += 1 | |||||
| self.line_index += 1 | |||||
| def write_inc_content(self): | |||||
| for shared_include_content in self.shared_includes_content: | |||||
| self.output_fd.write(shared_include_content) | |||||
| def h2cc(self): | |||||
| """ | |||||
| :return: | |||||
| """ | |||||
| logging.info("start generate cc_file[%s] from h_file[%s]", self.output_file, self.input_file) | |||||
| global pattern_comment | |||||
| global pattern_comment_2_start | |||||
| global pattern_comment_2_end | |||||
| global pattern_blank_line | |||||
| global pattern_func | |||||
| global pattern_keyword | |||||
| global pattern_leading_space | |||||
| global pattern_func_name | |||||
| global pattern_template | |||||
| global pattern_template_end | |||||
| global pattern_namespace | |||||
| global pattern_class | |||||
| global pattern_start | |||||
| global pattern_end | |||||
| global line_index | |||||
| # write inc content | |||||
| self.write_inc_content() | |||||
| # core processing cycle, process the input .h file by line | |||||
| while self.line_index < len(self.input_content): | |||||
| # handle comment and blank line | |||||
| self.just_skip() | |||||
| # match namespace | |||||
| self.handle_namespace() | |||||
| # match template | |||||
| template_string = self.handle_template() | |||||
| # match class | |||||
| line = self.input_content[self.line_index] | |||||
| match_class = pattern_class.search(line) | |||||
| match_start = pattern_start.search(line) | |||||
| handle_class_result = self.handle_class(template_string, line, match_start, match_class) | |||||
| if handle_class_result == "continue": | |||||
| continue | |||||
| # match "}" | |||||
| handle_stack_result = self.handle_stack(match_start) | |||||
| if handle_stack_result == "continue": | |||||
| continue | |||||
| # handle func | |||||
| handle_func1_result, line, start_i = self.handle_func1(line) | |||||
| if handle_func1_result == "continue": | |||||
| continue | |||||
| # here means func is found | |||||
| # delete key word | |||||
| line = pattern_keyword.sub('', line) | |||||
| logging.info("line[%s]", line) | |||||
| # Class member function | |||||
| # if friend we will not add class name | |||||
| friend_match = re.search('friend ', line) | |||||
| if len(self.stack_class) > 0 and not friend_match: | |||||
| line, func_name = self.handle_class_member_func(line, template_string) | |||||
| # Normal functions | |||||
| else: | |||||
| line, func_name = self.handle_normal_func(line, template_string) | |||||
| need_generate = need_generate_func(line) | |||||
| # func body | |||||
| line += self.implement_function(line) | |||||
| # comment | |||||
| line = self.gen_comment(start_i) + line | |||||
| # write to out file | |||||
| self.write_func_content(line, func_name, need_generate) | |||||
| # next loop | |||||
| self.line_index += 1 | |||||
| logging.info('Added %s functions', len(self.func_list_exist)) | |||||
| logging.info('Successfully converted,please see ' + self.output_file) | |||||
| def handle_func1(self, line): | |||||
| """ | |||||
| :param line: | |||||
| :return: | |||||
| """ | |||||
| find1 = re.search('[(]', line) | |||||
| if not find1: | |||||
| self.line_index += 1 | |||||
| return "continue", line, None | |||||
| find2 = re.search('[)]', line) | |||||
| start_i = self.line_index | |||||
| space_match = pattern_leading_space.search(line) | |||||
| # deal with | |||||
| # int abc(int a, | |||||
| # int b) | |||||
| if find1 and (not find2): | |||||
| self.line_index += 1 | |||||
| line2 = self.input_content[self.line_index] | |||||
| if space_match: | |||||
| line2 = re.sub('^' + space_match.group(1), '', line2) | |||||
| line += line2 | |||||
| while self.line_index < len(self.input_content) and (not re.search('[)]', line2)): | |||||
| self.line_index += 1 | |||||
| line2 = self.input_content[self.line_index] | |||||
| line2 = re.sub('^' + space_match.group(1), '', line2) | |||||
| line += line2 | |||||
| match_start = pattern_start.search(self.input_content[self.line_index]) | |||||
| match_end = pattern_end.search(self.input_content[self.line_index]) | |||||
| if match_start: # like ) { or ) {} int the last line | |||||
| if not match_end: | |||||
| self.stack.append('normal_now') | |||||
| ii = start_i | |||||
| while ii <= self.line_index: | |||||
| ii += 1 | |||||
| self.line_index += 1 | |||||
| return "continue", line, start_i | |||||
| logging.info("line[%s]", line) | |||||
| # ' int abc();'->'int abc()' | |||||
| (line, match) = pattern_func.subn(r'\2\n', line) | |||||
| logging.info("line[%s]", line) | |||||
| # deal with case: | |||||
| # 'int \n abc(int a, int b)' | |||||
| if re.search(r'^\s*(inline)?\s*[a-zA-Z0-9_]+\s*$', self.input_content[start_i - 1]): | |||||
| line = self.input_content[start_i - 1] + line | |||||
| line = line.lstrip() | |||||
| if not match: | |||||
| self.line_index += 1 | |||||
| return "continue", line, start_i | |||||
| return "pass", line, start_i | |||||
| def handle_stack(self, match_start): | |||||
| """ | |||||
| :param match_start: | |||||
| :return: | |||||
| """ | |||||
| line = self.input_content[self.line_index] | |||||
| match_end = pattern_end.search(line) | |||||
| if match_start: | |||||
| self.stack.append('normal_now') | |||||
| if match_end: | |||||
| top_status = self.stack.pop() | |||||
| if top_status == 'namespace_now': | |||||
| self.output_fd.write(line + '\n') | |||||
| elif top_status == 'class_now': | |||||
| self.stack_class.pop() | |||||
| self.stack_template.pop() | |||||
| if match_start or match_end: | |||||
| self.line_index += 1 | |||||
| return "continue" | |||||
| if len(self.stack) > 0 and self.stack[-1] == 'normal_now': | |||||
| self.line_index += 1 | |||||
| return "continue" | |||||
| return "pass" | |||||
| def handle_class(self, template_string, line, match_start, match_class): | |||||
| """ | |||||
| :param template_string: | |||||
| :param line: | |||||
| :param match_start: | |||||
| :param match_class: | |||||
| :return: | |||||
| """ | |||||
| if match_class: # we face a class | |||||
| self.stack_template.append(template_string) | |||||
| self.stack.append('class_now') | |||||
| class_name = match_class.group(3) | |||||
| # class template specializations: class A<u,Node<u> > | |||||
| if '<' in class_name: | |||||
| k = line.index('<') | |||||
| fit = 1 | |||||
| for ii in range(k + 1, len(line)): | |||||
| if line[ii] == '<': | |||||
| fit += 1 | |||||
| if line[ii] == '>': | |||||
| fit -= 1 | |||||
| if fit == 0: | |||||
| break | |||||
| class_name += line[k + 1:ii + 1] | |||||
| logging.info('class_name[%s]', class_name) | |||||
| self.stack_class.append(class_name) | |||||
| while not match_start: | |||||
| self.line_index += 1 | |||||
| line = self.input_content[self.line_index] | |||||
| match_start = pattern_start.search(line) | |||||
| self.line_index += 1 | |||||
| return "continue" | |||||
| return "pass" | |||||
| def handle_template(self): | |||||
| line = self.input_content[self.line_index] | |||||
| match_template = pattern_template.search(line) | |||||
| template_string = '' | |||||
| if match_template: | |||||
| match_template_end = pattern_template_end.search(line) | |||||
| template_string = line | |||||
| while not match_template_end: | |||||
| self.line_index += 1 | |||||
| line = self.input_content[self.line_index] | |||||
| template_string += line | |||||
| match_template_end = pattern_template_end.search(line) | |||||
| self.line_index += 1 | |||||
| return template_string | |||||
| def handle_namespace(self): | |||||
| line = self.input_content[self.line_index] | |||||
| match_namespace = pattern_namespace.search(line) | |||||
| if match_namespace: # we face namespace | |||||
| self.output_fd.write(line + '\n') | |||||
| self.stack.append('namespace_now') | |||||
| self.line_index += 1 | |||||
| def handle_normal_func(self, line, template_string): | |||||
| template_line = '' | |||||
| self.stack_template.append(template_string) | |||||
| if self.stack_template[-1] != '': | |||||
| template_line = re.sub(r'\s*template', 'template', self.stack_template[-1]) | |||||
| # change '< class T = a, class U = A(3)>' to '<class T, class U>' | |||||
| template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) | |||||
| template_line = re.sub(r'\s*=.*,', ',', template_line) | |||||
| template_line = re.sub(r'\s*=.*', '', template_line) | |||||
| line = re.sub(r'\s*=.*,', ',', line) | |||||
| line = re.sub(r'\s*=.*\)', ')', line) | |||||
| line = template_line + line | |||||
| self.stack_template.pop() | |||||
| func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() | |||||
| logging.info("line[%s]", line) | |||||
| logging.info("func_name[%s]", func_name) | |||||
| return line, func_name | |||||
| def handle_class_member_func(self, line, template_string): | |||||
| template_line = '' | |||||
| x = '' | |||||
| if template_string != '': | |||||
| template_string = re.sub(r'\s*template', 'template', template_string) | |||||
| template_string = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_string) | |||||
| template_string = re.sub(r'\s*=.*,', ',', template_string) | |||||
| template_string = re.sub(r'\s*=.*', '', template_string) | |||||
| if self.stack_template[-1] != '': | |||||
| if not (re.search(r'<\s*>', stack_template[-1])): | |||||
| template_line = re.sub(r'^\s*template', 'template', stack_template[-1]) | |||||
| if not (re.search(r'<.*>', self.stack_class[-1])): | |||||
| # for x we get like template<class T, typename U> -> <T,U> | |||||
| x = re.sub(r'template\s*<', '<', template_line) # remove template -> <class T, typename U> | |||||
| x = re.sub(r'\n', '', x) | |||||
| x = re.sub(r'\s*=.*,', ',', x) | |||||
| x = re.sub(r'\s*=.*\>', '>', x) | |||||
| x = x.rstrip() # remove \n | |||||
| x = re.sub(r'(class|typename)\s+|(<class>|<typename>\s*class)', '', | |||||
| x) # remove class,typename -> <T, U> | |||||
| x = re.sub(r'<\s+', '<', x) | |||||
| x = re.sub(r'\s+>', '>', x) | |||||
| x = re.sub(r'\s+,', ',', x) | |||||
| x = re.sub(r',\s+', ', ', x) | |||||
| line = re.sub(r'\s*=\s+0', '', line) | |||||
| line = re.sub(r'\s*=\s+.*,', ',', line) | |||||
| line = re.sub(r'\s*=\s+.*\)', ')', line) | |||||
| logging.info("x[%s]\nline[%s]", x, line) | |||||
| # if the function is long, void ABC::foo() | |||||
| # breaks into two lines void ABC::\n foo() | |||||
| temp_line = pattern_func_name.sub(self.stack_class[-1] + x + '::' + r'\1(', line, count=1) | |||||
| if len(temp_line) > max_code_len_per_line: | |||||
| line = pattern_func_name.sub(self.stack_class[-1] + x + '::\n' + r'\1(', line, count=1) | |||||
| else: | |||||
| line = temp_line | |||||
| logging.info("line[%s]", line) | |||||
| # add template as the above if there is one | |||||
| template_line = re.sub(r'\s*=.*>(\s*)$', r'>\1', template_line) | |||||
| template_line = re.sub(r'\s*=.*,', ',', template_line) | |||||
| template_line = re.sub(r'\s*=.*', '', template_line) | |||||
| line = template_line + template_string + line | |||||
| func_name = re.search(r'^.*\)', line, re.MULTILINE | re.DOTALL).group() | |||||
| logging.info("line[%s]", line) | |||||
| logging.info("func_name[%s]", func_name) | |||||
| return line, func_name | |||||
| def write_func_content(self, content, func_name, need_generate): | |||||
| if not (func_name in self.func_list_exist) and need_generate: | |||||
| self.output_fd.write(content) | |||||
| self.func_list_exist.append(func_name) | |||||
| logging.info('add func:[%s]', func_name) | |||||
| def gen_comment(self, start_i): | |||||
| comment_line = '' | |||||
| # Function comments are on top of function declarations, copy them over | |||||
| k = start_i - 1 # one line before this func start | |||||
| if pattern_template.search(self.input_content[k]): | |||||
| k -= 1 | |||||
| if pattern_comment_2_end.search(self.input_content[k]): | |||||
| comment_line = self.input_content[k].lstrip() | |||||
| while not pattern_comment_2_start.search(self.input_content[k]): | |||||
| k -= 1 | |||||
| comment_line = self.input_content[k].lstrip() + comment_line | |||||
| else: | |||||
| for j in range(k, 0, -1): | |||||
| c_line = self.input_content[j] | |||||
| if pattern_comment.search(c_line): | |||||
| c_line = re.sub(r'\s*//', '//', c_line) | |||||
| comment_line = c_line + comment_line | |||||
| else: | |||||
| break | |||||
| return comment_line | |||||
| @staticmethod | |||||
| def implement_function(func): | |||||
| function_def = '' | |||||
| function_def += '{\n' | |||||
| all_items = func.split() | |||||
| start = 0 | |||||
| return_type = all_items[start] | |||||
| if return_type == "const": | |||||
| start += 1 | |||||
| return_type = all_items[start] | |||||
| if return_type.startswith(('std::map', 'std::set', 'std::vector')): | |||||
| return_type = "std::map" | |||||
| if return_type.endswith('*') or (len(all_items) > start + 1 and all_items[start + 1].startswith('*')): | |||||
| return_type = "Ptr" | |||||
| if len(all_items) > start + 1 and all_items[start + 1].startswith('&'): | |||||
| return_type += "&" | |||||
| if RETURN_STATEMENTS.__contains__(return_type): | |||||
| function_def += RETURN_STATEMENTS[return_type] | |||||
| else: | |||||
| logging.warning("Unhandled return type[%s]", return_type) | |||||
| function_def += '\n' | |||||
| function_def += '}\n' | |||||
| function_def += '\n' | |||||
| return function_def | |||||
| def collect_header_files(path): | |||||
| """ | |||||
| :param path: | |||||
| :return: | |||||
| """ | |||||
| header_files = [] | |||||
| shared_includes_content = [] | |||||
| for root, dirs, files in os.walk(path): | |||||
| files.sort() | |||||
| for file in files: | |||||
| if file.find("git") >= 0: | |||||
| continue | |||||
| if not file.endswith('.h'): | |||||
| continue | |||||
| file_path = os.path.join(root, file) | |||||
| file_path = file_path.replace('\\', '/') | |||||
| header_files.append(file_path) | |||||
| include_str = '#include "{}"\n'.format(file_path[path.rindex('/') + 1:]) | |||||
| shared_includes_content.append(include_str) | |||||
| # for acl error code | |||||
| shared_includes_content.append('#include <iostream>\n') | |||||
| shared_includes_content.append('const int ACL_ERROR_COMPILING_STUB_MODE = 100039;\n') | |||||
| return header_files, shared_includes_content | |||||
| def generate_stub_file(inc_dir, out_cc_dir): | |||||
| """ | |||||
| :param inc_dir: | |||||
| :param out_cc_dir: | |||||
| :return: | |||||
| """ | |||||
| target_header_files, shared_includes_content = collect_header_files(inc_dir) | |||||
| for header_file in target_header_files: | |||||
| if not file_endswith_white_list_suffix(header_file): | |||||
| continue | |||||
| cc_file = re.sub('.h*$', '.cc', header_file) | |||||
| h_2_cc = H2CC(header_file, out_cc_dir + cc_file[cc_file.rindex('/') + 1:], shared_includes_content) | |||||
| h_2_cc.h2cc() | |||||
| def gen_code(inc_dir, out_cc_dir): | |||||
| """ | |||||
| :param inc_dir: | |||||
| :param out_cc_dir: | |||||
| :return: | |||||
| """ | |||||
| if not inc_dir.endswith('/'): | |||||
| inc_dir += '/' | |||||
| if not out_cc_dir.endswith('/'): | |||||
| out_cc_dir += '/' | |||||
| for include_dir_key_word in include_dir_key_words: | |||||
| generate_stub_file(inc_dir + include_dir_key_word, out_cc_dir) | |||||
| if __name__ == '__main__': | |||||
| inc_dir = sys.argv[1] | |||||
| out_cc_dir = sys.argv[2] | |||||
| gen_code(inc_dir, out_cc_dir) | |||||
| @@ -178,18 +178,16 @@ int64_t Shape::GetShapeSize() const { | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| TensorDesc::TensorDesc() { | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665 | |||||
| } | |||||
| TensorDesc::TensorDesc() { impl = ComGraphMakeShared<TensorDescImpl>(); } | |||||
| TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { | TensorDesc::TensorDesc(Shape shape, Format format, DataType dt) { | ||||
| impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); // lint !e665 | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(shape, format, dt); | |||||
| SetRealDimCnt(shape.GetDimNum()); | SetRealDimCnt(shape.GetDimNum()); | ||||
| } | } | ||||
| TensorDesc::TensorDesc(const TensorDesc &desc) { | TensorDesc::TensorDesc(const TensorDesc &desc) { | ||||
| // Copy | // Copy | ||||
| impl = ComGraphMakeShared<TensorDescImpl>(); // lint !e665 | |||||
| impl = ComGraphMakeShared<TensorDescImpl>(); | |||||
| if (desc.impl != nullptr && impl != nullptr) { | if (desc.impl != nullptr && impl != nullptr) { | ||||
| *impl = *desc.impl; | *impl = *desc.impl; | ||||
| } | } | ||||
| @@ -360,9 +358,7 @@ void TensorDesc::SetName(const std::string &name) { | |||||
| Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); } | Tensor::Tensor() { impl = ComGraphMakeShared<TensorImpl>(); } | ||||
| Tensor::Tensor(const TensorDesc &tensor_desc) { | |||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc); // lint !e665 | |||||
| } | |||||
| Tensor::Tensor(const TensorDesc &tensor_desc) { impl = ComGraphMakeShared<TensorImpl>(tensor_desc); } | |||||
| Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) { | Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) { | ||||
| uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); | uint64_t shape_size = tensor_desc.GetShape().GetShapeSize(); | ||||
| @@ -384,7 +380,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const std::vector<uint8_t> &data) | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); // lint !e665 | |||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data); | |||||
| } | } | ||||
| Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { | Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) { | ||||
| @@ -406,7 +402,7 @@ Tensor::Tensor(const TensorDesc &tensor_desc, const uint8_t *data, size_t size) | |||||
| } | } | ||||
| } | } | ||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); // lint !e665 | |||||
| impl = ComGraphMakeShared<TensorImpl>(tensor_desc, data, size); | |||||
| } | } | ||||
| Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) { | Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) { | ||||
| @@ -429,7 +425,7 @@ Tensor::Tensor(TensorDesc &&tensor_desc, std::vector<uint8_t> &&data) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); // lint !e665 | |||||
| impl = ComGraphMakeShared<TensorImpl>(std::move(tensor_desc), std::move(data)); | |||||
| } | } | ||||
| TensorDesc Tensor::GetTensorDesc() const { | TensorDesc Tensor::GetTensorDesc() const { | ||||
| @@ -643,7 +639,7 @@ TensorDesc TensorAdapter::GeTensorDesc2TensorDesc(const GeTensorDesc &ge_tensor_ | |||||
| GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { | GeTensorPtr TensorAdapter::Tensor2GeTensor(const Tensor &tensor) { | ||||
| GeTensorPtr ge_tensor; | GeTensorPtr ge_tensor; | ||||
| if (tensor.impl != nullptr) { | if (tensor.impl != nullptr) { | ||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); // lint !e665 | |||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor.Clone()); | |||||
| } | } | ||||
| return ge_tensor; | return ge_tensor; | ||||
| } | } | ||||
| @@ -659,7 +655,7 @@ Tensor TensorAdapter::GeTensor2Tensor(const ConstGeTensorPtr &ge_tensor) { | |||||
| ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { | ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { | ||||
| GeTensorPtr ge_tensor; | GeTensorPtr ge_tensor; | ||||
| if (tensor.impl != nullptr) { | if (tensor.impl != nullptr) { | ||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665 | |||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); | |||||
| } | } | ||||
| return ge_tensor; | return ge_tensor; | ||||
| } | } | ||||
| @@ -667,7 +663,7 @@ ConstGeTensorPtr TensorAdapter::AsGeTensorPtr(const Tensor &tensor) { | |||||
| GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { | GeTensorPtr TensorAdapter::AsGeTensorPtr(Tensor &tensor) { | ||||
| GeTensorPtr ge_tensor; | GeTensorPtr ge_tensor; | ||||
| if (tensor.impl != nullptr) { | if (tensor.impl != nullptr) { | ||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); // lint !e665 | |||||
| ge_tensor = ComGraphMakeShared<GeTensor>(tensor.impl->ge_tensor); | |||||
| } | } | ||||
| return ge_tensor; | return ge_tensor; | ||||
| } | } | ||||
| @@ -58,8 +58,10 @@ namespace { | |||||
| const int32_t kBaseOfIntegerValue = 10; | const int32_t kBaseOfIntegerValue = 10; | ||||
| #ifdef FMK_SUPPORT_DUMP | #ifdef FMK_SUPPORT_DUMP | ||||
| const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; | const char *const kDumpGeGraph = "DUMP_GE_GRAPH"; | ||||
| const int kDumpGraphIndexWidth = 5; | |||||
| const int kDumpGraphIndexWidth = 8; | |||||
| #endif | #endif | ||||
| const char *const kDumpGraphPath = "DUMP_GRAPH_PATH"; | |||||
| const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; | const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; | ||||
| const char *const kDumpStrBuild = "Build"; | const char *const kDumpStrBuild = "Build"; | ||||
| const char *const kDumpStrPartition = "partition"; | const char *const kDumpStrPartition = "partition"; | ||||
| @@ -588,6 +590,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||||
| } | } | ||||
| std::stringstream stream_file_name; | std::stringstream stream_file_name; | ||||
| char *dump_graph_path = std::getenv(kDumpGraphPath); | |||||
| if (dump_graph_path != nullptr) { | |||||
| std::string dump_graph_path_str(dump_graph_path); | |||||
| stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/"); | |||||
| } | |||||
| stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; | stream_file_name << "ge_proto_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; | ||||
| stream_file_name << "_" << suffix << ".txt"; | stream_file_name << "_" << suffix << ".txt"; | ||||
| std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name; | std::string proto_file = user_graph_name.empty() ? stream_file_name.str() : user_graph_name; | ||||
| @@ -598,7 +605,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||||
| Buffer buffer; | Buffer buffer; | ||||
| const int64_t kDumpLevel = | const int64_t kDumpLevel = | ||||
| (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; | (dump_ge_graph != nullptr) ? std::strtol(dump_ge_graph, nullptr, kBaseOfIntegerValue) : ge::OnnxUtils::NO_DUMP; | ||||
| model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); | |||||
| model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL && !is_always_dump); | |||||
| // Write file | // Write file | ||||
| ge::proto::ModelDef ge_proto; | ge::proto::ModelDef ge_proto; | ||||
| @@ -620,6 +627,54 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraph(cons | |||||
| #endif | #endif | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGrph(const ge::ComputeGraphPtr &graph, | |||||
| const std::string &path, | |||||
| const std::string &suffix) { | |||||
| // file name | |||||
| static std::atomic_long atomic_file_index(0); | |||||
| auto file_index = atomic_file_index.fetch_add(1); | |||||
| GELOGD("Start to dump om txt: %ld", file_index); | |||||
| thread_local long max_dump_file_num = 0; | |||||
| if (max_dump_file_num == 0) { | |||||
| string opt = "0"; | |||||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||||
| max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| } | |||||
| if (max_dump_file_num != 0 && file_index > max_dump_file_num) { | |||||
| GELOGW("Dump graph file cnt > maxDumpFileNum, maxDumpFileCnt=%ld.", max_dump_file_num); | |||||
| return; | |||||
| } | |||||
| std::stringstream stream_file_name; | |||||
| stream_file_name << path.c_str() << "/ge_proto_" << std::setw(5) << std::setfill('0') << file_index; | |||||
| stream_file_name << "_" << suffix << ".txt"; | |||||
| std::string proto_file = stream_file_name.str(); | |||||
| // Create buffer | |||||
| ge::Model model("", ""); | |||||
| model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast<ComputeGraph>(graph))); | |||||
| Buffer buffer; | |||||
| const int64_t kDumpLevel = ge::OnnxUtils::NO_DUMP; | |||||
| model.Save(buffer, kDumpLevel != ge::OnnxUtils::DUMP_ALL); | |||||
| // Write file | |||||
| ge::proto::ModelDef ge_proto; | |||||
| if (buffer.GetData() != nullptr) { | |||||
| std::string str(reinterpret_cast<const char *>(buffer.GetData()), buffer.GetSize()); | |||||
| if (!ge_proto.ParseFromString(str)) { | |||||
| GELOGE(GRAPH_FAILED, "parse from string failed."); | |||||
| return; | |||||
| } | |||||
| char real_path[PATH_MAX] = {0x00}; | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(proto_file.c_str()) >= PATH_MAX, return, "file path is too longer!"); | |||||
| GE_IF_BOOL_EXEC(realpath(proto_file.c_str(), real_path) == nullptr, | |||||
| GELOGI("file %s does not exist, it will be created.", proto_file.c_str())); | |||||
| GraphUtils::WriteProtoToTextFile(ge_proto, real_path); | |||||
| } | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraph(const char *file, | ||||
| ge::ComputeGraph &compute_graph) { | ge::ComputeGraph &compute_graph) { | ||||
| ge::proto::ModelDef model_def; | ge::proto::ModelDef model_def; | ||||
| @@ -722,7 +777,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::WriteProtoToText | |||||
| } | } | ||||
| GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed"); | GE_CHK_BOOL_EXEC(fclose(file) == 0, return, "Fclose fileoutputstream failed"); | ||||
| #else | #else | ||||
| GELOGW("need to define FMK_SUPPORT_DUMP for dump graph."); | |||||
| GELOGW("Need to define FMK_SUPPORT_DUMP for dump graph."); | |||||
| #endif | #endif | ||||
| } | } | ||||
| @@ -789,6 +844,11 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||||
| } | } | ||||
| std::stringstream stream_file_name; | std::stringstream stream_file_name; | ||||
| char *dump_graph_path = std::getenv(kDumpGraphPath); | |||||
| if (dump_graph_path != nullptr) { | |||||
| std::string dump_graph_path_str(dump_graph_path); | |||||
| stream_file_name << (dump_graph_path_str.empty() ? "" : dump_graph_path_str + "/"); | |||||
| } | |||||
| stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; | stream_file_name << "ge_onnx_" << std::setw(kDumpGraphIndexWidth) << std::setfill('0') << file_index; | ||||
| stream_file_name << "_graph_" << compute_graph.GetGraphID(); | stream_file_name << "_graph_" << compute_graph.GetGraphID(); | ||||
| stream_file_name << "_" << suffix << ".pbtxt"; | stream_file_name << "_" << suffix << ".pbtxt"; | ||||
| @@ -822,6 +882,66 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGEGraphToOnn | |||||
| #endif | #endif | ||||
| } | } | ||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void GraphUtils::DumpGrphToOnnx(const ge::ComputeGraph &compute_graph, | |||||
| const std::string &path, | |||||
| const std::string &suffix) { | |||||
| // 1.Get ge::onnx::ModelProto from ge::Model | |||||
| ge::Model model("GE", ""); | |||||
| std::shared_ptr<ge::ComputeGraph> compute_graph_ptr = ComGraphMakeShared<ge::ComputeGraph>(compute_graph); | |||||
| model.SetGraph(GraphUtils::CreateGraphFromComputeGraph(std::const_pointer_cast<ComputeGraph>(compute_graph_ptr))); | |||||
| onnx::ModelProto model_proto; | |||||
| if (!OnnxUtils::ConvertGeModelToModelProto(model, model_proto)) { | |||||
| GELOGE(GRAPH_FAILED, "DumpGEGraphToOnnx failed."); | |||||
| return; | |||||
| } | |||||
| // 2.Set file name | |||||
| static std::atomic_long atomic_file_index(0); | |||||
| auto file_index = atomic_file_index.fetch_add(1); | |||||
| GELOGD("Start to dump ge onnx file: %ld", file_index); | |||||
| thread_local long max_dump_file_num = 0; | |||||
| if (max_dump_file_num == 0) { | |||||
| string opt = "0"; | |||||
| (void)GetContext().GetOption(OPTION_GE_MAX_DUMP_FILE_NUM, opt); | |||||
| max_dump_file_num = std::strtol(opt.c_str(), nullptr, kBaseOfIntegerValue); | |||||
| } | |||||
| if (max_dump_file_num != 0 && file_index > max_dump_file_num) { | |||||
| GELOGW("Dump graph file cnt > maxDumpFileNum, maxDumpFileNum=%ld.", max_dump_file_num); | |||||
| return; | |||||
| } | |||||
| std::stringstream stream_file_name; | |||||
| stream_file_name << path.c_str() << "/ge_onnx_" << std::setw(5) << std::setfill('0') << file_index; | |||||
| stream_file_name << "_graph_" << compute_graph.GetGraphID(); | |||||
| stream_file_name << "_" << suffix << ".pbtxt"; | |||||
| std::string proto_file = stream_file_name.str(); | |||||
| if ((proto_file.length()) >= NAME_MAX) { | |||||
| GELOGE(GRAPH_FAILED, "File name is too longer!"); | |||||
| return; | |||||
| } | |||||
| std::unique_ptr<char[]> real_path(new (std::nothrow) char[PATH_MAX]{0}); | |||||
| if (real_path == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "New real_path failed."); | |||||
| return; | |||||
| } | |||||
| /// Returning nullptr means 3 case as follows: | |||||
| /// a.path is PATH_MAX chars or more | |||||
| /// b.the file does not exist | |||||
| /// c.the path has no permissions | |||||
| /// Distinguish between last the two cases in the function WriteProtoToTextFile call open() | |||||
| if (realpath(proto_file.c_str(), real_path.get()) == nullptr) { | |||||
| // For case a | |||||
| if (errno == ENAMETOOLONG) { | |||||
| GELOGE(GRAPH_FAILED, "Call realpath failed: path is PATH_MAX chars or more."); | |||||
| return; | |||||
| } | |||||
| } | |||||
| // 3. Serialize to file in current path | |||||
| GraphUtils::WriteProtoToTextFile(model_proto, real_path.get()); | |||||
| } | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file, | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::LoadGEGraphFromOnnx(const char *file, | ||||
| ge::ComputeGraph &compute_graph) { | ge::ComputeGraph &compute_graph) { | ||||
| if (file == nullptr) { | if (file == nullptr) { | ||||
| @@ -1419,7 +1539,7 @@ GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| op_desc->SetName(prefix + n->GetName()); | |||||
| op_desc->SetName(n->GetName() + prefix); | |||||
| NodePtr node = new_graph->AddNode(op_desc); | NodePtr node = new_graph->AddNode(op_desc); | ||||
| GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | GE_CHK_BOOL_EXEC(node != nullptr, return nullptr, "Add node[%s] to graph failed", op_desc->GetName().c_str()); | ||||
| all_new_nodes[node->GetName()] = node; | all_new_nodes[node->GetName()] = node; | ||||
| @@ -1445,6 +1565,17 @@ GraphUtils::CloneGraph(const ComputeGraphPtr &graph, const std::string &prefix, | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| } | } | ||||
| // copy info of output nodes from old graph to new graph. | |||||
| std::vector<std::pair<NodePtr, int32_t>> out_nodes_info = graph->GetGraphOutNodesInfo(); | |||||
| std::vector<std::pair<NodePtr, int32_t>> new_out_nodes_info; | |||||
| for (const auto &info : out_nodes_info) { | |||||
| auto it = all_new_nodes.find(info.first->GetName()); | |||||
| if (it != all_new_nodes.end()) { | |||||
| new_out_nodes_info.emplace_back(it->second, info.second); | |||||
| } | |||||
| } | |||||
| new_graph->SetGraphOutNodesInfo(new_out_nodes_info); | |||||
| return new_graph; | return new_graph; | ||||
| } | } | ||||
| @@ -1501,7 +1632,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| auto it = all_nodes.find(prefix + node->GetName()); | |||||
| auto it = all_nodes.find(node->GetName() + prefix); | |||||
| if (it == all_nodes.end()) { | if (it == all_nodes.end()) { | ||||
| GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); | GELOGE(GRAPH_FAILED, "node[%s] not found", node->GetName().c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| @@ -1517,7 +1648,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref | |||||
| } | } | ||||
| GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); | GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); | ||||
| it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); | |||||
| it = all_nodes.find(out_anchor->GetOwnerNode()->GetName() + prefix); | |||||
| if (it == all_nodes.end()) { | if (it == all_nodes.end()) { | ||||
| GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); | GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| @@ -1535,7 +1666,7 @@ graphStatus GraphUtils::RelinkGraphEdges(const NodePtr &node, const string &pref | |||||
| GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); | GE_CHK_BOOL_EXEC(out_anchor != nullptr, continue, "Peer out anchor is null: %s", node->GetName().c_str()); | ||||
| GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); | GE_CHK_BOOL_EXEC(out_anchor->GetOwnerNode() != nullptr, return GRAPH_FAILED, "Peer out node is null"); | ||||
| it = all_nodes.find(prefix + out_anchor->GetOwnerNode()->GetName()); | |||||
| it = all_nodes.find(out_anchor->GetOwnerNode()->GetName() + prefix); | |||||
| if (it == all_nodes.end()) { | if (it == all_nodes.end()) { | ||||
| GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); | GELOGE(GRAPH_FAILED, "node[%s] not found", out_anchor->GetOwnerNode()->GetName().c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| @@ -1736,7 +1867,7 @@ graphStatus GraphUtils::HandleMergeInput(const NodePtr &node, | |||||
| if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { | if (AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_NEXT_ITERATION, next_name) && !next_name.empty()) { | ||||
| ComputeGraphPtr graph = node->GetOwnerComputeGraph(); | ComputeGraphPtr graph = node->GetOwnerComputeGraph(); | ||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| ge::NodePtr next_node = graph->FindNode(next_name); | |||||
| ge::NodePtr next_node = FindNodeFromAllNodes(graph, next_name); | |||||
| GE_CHECK_NOTNULL(next_node); | GE_CHECK_NOTNULL(next_node); | ||||
| // NextIteration has and only has one output | // NextIteration has and only has one output | ||||
| peer_out_anchor = next_node->GetOutDataAnchor(0); | peer_out_anchor = next_node->GetOutDataAnchor(0); | ||||
| @@ -2332,15 +2463,12 @@ CompleteGraphBuilder &CompleteGraphBuilder::SetOutputMapping(const std::map<uint | |||||
| /// | /// | ||||
| ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { | ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string &error_msg) { | ||||
| owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_)); | owner_graph_ = shared_ptr<ComputeGraph>(new (std::nothrow) ComputeGraph(name_)); | ||||
| if ((owner_graph_ == nullptr) || (parent_node_ == nullptr)) { | |||||
| if (owner_graph_ == nullptr) { | |||||
| error_code = GRAPH_FAILED; | error_code = GRAPH_FAILED; | ||||
| error_msg = "graph / parent_node is NULL."; | |||||
| error_msg = "graph is NULL."; | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| owner_graph_->SetParentNode(parent_node_); | |||||
| owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); | |||||
| BuildNodes(error_code, error_msg); | BuildNodes(error_code, error_msg); | ||||
| if (error_code != GRAPH_SUCCESS) { | if (error_code != GRAPH_SUCCESS) { | ||||
| return nullptr; | return nullptr; | ||||
| @@ -2361,37 +2489,27 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| AddRetValNodes(error_code, error_msg); | |||||
| if (error_code != GRAPH_SUCCESS) { | |||||
| return nullptr; | |||||
| if (retval_flag_) { | |||||
| AddRetValNodes(error_code, error_msg); | |||||
| if (error_code != GRAPH_SUCCESS) { | |||||
| return nullptr; | |||||
| } | |||||
| BuildGraphTargets(error_code, error_msg); | |||||
| if (error_code != GRAPH_SUCCESS) { | |||||
| return nullptr; | |||||
| } | |||||
| } else { | |||||
| AddNetOutputNode(error_code, error_msg); | |||||
| if (error_code != GRAPH_SUCCESS) { | |||||
| return nullptr; | |||||
| } | |||||
| } | } | ||||
| BuildGraphTargets(error_code, error_msg); | |||||
| PostProcess(error_code, error_msg); | |||||
| if (error_code != GRAPH_SUCCESS) { | if (error_code != GRAPH_SUCCESS) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // ATTR_NAME_SESSION_GRAPH_ID | |||||
| std::string graph_id; | |||||
| if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "Get attr session_graph_id failed."; | |||||
| return nullptr; | |||||
| } | |||||
| if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "Set attr session_graph_id failed."; | |||||
| return nullptr; | |||||
| } | |||||
| // refresh node name | |||||
| for (const NodePtr &node : owner_graph_->GetDirectNode()) { | |||||
| if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { | |||||
| continue; | |||||
| } | |||||
| node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); | |||||
| } | |||||
| return owner_graph_; | return owner_graph_; | ||||
| } | } | ||||
| @@ -2586,7 +2704,144 @@ void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::strin | |||||
| target_nodes.emplace_back(target_iter->second); | target_nodes.emplace_back(target_iter->second); | ||||
| } | } | ||||
| owner_graph_->SetGraphTargetNodesInfo(target_nodes); | owner_graph_->SetGraphTargetNodesInfo(target_nodes); | ||||
| return; | |||||
| } | |||||
| /// | |||||
| /// @brief Add NetOutput node | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void CompleteGraphBuilder::AddNetOutputNode(graphStatus &error_code, std::string &error_msg) { | |||||
| std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT; | |||||
| OpDescPtr net_output_desc = shared_ptr<OpDesc>(new (std::nothrow) OpDesc(NODE_NAME_NET_OUTPUT, NETOUTPUT)); | |||||
| if (net_output_desc == nullptr) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = log_msg + " failed: op_desc is NULL."; | |||||
| return; | |||||
| } | |||||
| size_t output_num = graph_outputs_.size(); | |||||
| std::vector<OutDataAnchorPtr> peer_out_anchors(output_num); | |||||
| for (size_t i = 0; i < output_num; i++) { | |||||
| int32_t index = graph_outputs_[i].second; | |||||
| auto out_iter = node_names_.find(graph_outputs_[i].first); | |||||
| if (out_iter == node_names_.end()) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "AddNetOutputNode failed: node " + graph_outputs_[i].first + " not exist in graph."; | |||||
| return; | |||||
| } | |||||
| NodePtr node = out_iter->second; | |||||
| if ((node == nullptr) || (node->GetOpDesc() == nullptr)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "AddNetOutputNode failed: node is NULL."; | |||||
| return; | |||||
| } | |||||
| ge::GeTensorDesc tensor = node->GetOpDesc()->GetOutputDesc(index); | |||||
| uint32_t update_index = i; | |||||
| auto iter = output_mapping_.find(i); | |||||
| if (iter != output_mapping_.end()) { | |||||
| update_index = iter->second; | |||||
| } | |||||
| if (!ge::AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, update_index)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "AddNetOutputNode failed: set attr PARENT_NODE_INDEX failed."; | |||||
| return; | |||||
| } | |||||
| if (net_output_desc->AddInputDesc(tensor) != GRAPH_SUCCESS) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "AddNetOutputNode failed: add input_desc ailed."; | |||||
| return; | |||||
| } | |||||
| peer_out_anchors[i] = node->GetOutDataAnchor(index); | |||||
| } | |||||
| BuildNetOutputNodeWithLink(net_output_desc, peer_out_anchors, error_code, error_msg); | |||||
| if (error_code != GRAPH_SUCCESS) { | |||||
| return; | |||||
| } | |||||
| GELOGD("%s succ.", log_msg.c_str()); | |||||
| } | |||||
| /// | |||||
| /// @brief Build NetOutput nodes with data & ctrl edges | |||||
| /// @param [in] net_output_desc | |||||
| /// @param [in] peer_out_anchors | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void CompleteGraphBuilder::BuildNetOutputNodeWithLink(const OpDescPtr &net_output_desc, | |||||
| const std::vector<OutDataAnchorPtr> &peer_out_anchors, | |||||
| graphStatus &error_code, std::string &error_msg) { | |||||
| std::string log_msg = "AddNetOutputNode name:" + std::string(NODE_NAME_NET_OUTPUT) + ", type:" + NETOUTPUT; | |||||
| NodePtr net_output = owner_graph_->AddNode(net_output_desc); | |||||
| if (net_output == nullptr) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = log_msg + " failed: add NetOutput node failed."; | |||||
| return; | |||||
| } | |||||
| size_t output_num = graph_outputs_.size(); | |||||
| for (size_t i = 0; i < output_num; i++) { | |||||
| if (GraphUtils::AddEdge(peer_out_anchors[i], net_output->GetInDataAnchor(i)) != GRAPH_SUCCESS) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "AddNetOutputNode failed: add data-edge " + peer_out_anchors[i]->GetOwnerNode()->GetName() + ":" + | |||||
| std::to_string(peer_out_anchors[i]->GetIdx()) + "->" + NODE_NAME_NET_OUTPUT + ":" + | |||||
| std::to_string(i) + " failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| for (const std::string &target_name : graph_targets_) { | |||||
| auto target_iter = node_names_.find(target_name); | |||||
| if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; | |||||
| return; | |||||
| } | |||||
| const auto &target_node = target_iter->second; | |||||
| if (GraphUtils::AddEdge(target_node->GetOutControlAnchor(), net_output->GetInControlAnchor()) != GRAPH_SUCCESS) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = | |||||
| "AddNetOutputNode failed: add ctrl-edge " + target_node->GetName() + "->" + NODE_NAME_NET_OUTPUT + " failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| } | |||||
| /// | |||||
| /// @brief process after build | |||||
| /// @param [out] error_code | |||||
| /// @param [out] error_msg | |||||
| /// @return void | |||||
| /// | |||||
| void CompleteGraphBuilder::PostProcess(graphStatus &error_code, std::string &error_msg) { | |||||
| if (parent_node_ != nullptr) { | |||||
| owner_graph_->SetParentNode(parent_node_); | |||||
| owner_graph_->SetParentGraph(parent_node_->GetOwnerComputeGraph()); | |||||
| // ATTR_NAME_SESSION_GRAPH_ID | |||||
| std::string graph_id; | |||||
| if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "Get attr session_graph_id failed."; | |||||
| return; | |||||
| } | |||||
| if (!AttrUtils::SetStr(owner_graph_, ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | |||||
| error_code = GRAPH_FAILED; | |||||
| error_msg = "Set attr session_graph_id failed."; | |||||
| return; | |||||
| } | |||||
| } | |||||
| // refresh node name | |||||
| for (const NodePtr &node : owner_graph_->GetDirectNode()) { | |||||
| if ((node->GetOpDesc() == nullptr) || (node->GetType() == VARIABLE) || (node->GetType() == VARIABLEV2)) { | |||||
| continue; | |||||
| } | |||||
| node->GetOpDesc()->SetName(owner_graph_->GetName() + "/" + node->GetName()); | |||||
| } | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -391,7 +391,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendInpu | |||||
| GELOGE(GRAPH_FAILED, "Add input desc failed"); | GELOGE(GRAPH_FAILED, "Add input desc failed"); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| } | |||||
| for (size_t i = node->in_data_anchors_.size(); i < num; ++i) { | |||||
| auto anchor = ComGraphMakeShared<InDataAnchor>(node, i); | auto anchor = ComGraphMakeShared<InDataAnchor>(node, i); | ||||
| if (anchor == nullptr) { | if (anchor == nullptr) { | ||||
| GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); | GELOGE(OUT_OF_MEMORY, "Current in data anchor is null, make shared_ptr failed."); | ||||
| @@ -444,7 +446,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::AppendOutp | |||||
| GELOGE(GRAPH_FAILED, "Add output desc failed"); | GELOGE(GRAPH_FAILED, "Add output desc failed"); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| } | |||||
| for (size_t i = node->out_data_anchors_.size(); i < num; ++i) { | |||||
| auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i); | auto anchor = ComGraphMakeShared<OutDataAnchor>(node, i); | ||||
| if (anchor == nullptr) { | if (anchor == nullptr) { | ||||
| GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); | GELOGE(OUT_OF_MEMORY, "Current out data anchor is null, make shared_ptr failed."); | ||||
| @@ -644,6 +648,20 @@ std::string NodeUtils::GetNodeType(const Node &node) { | |||||
| std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); } | std::string NodeUtils::GetNodeType(const NodePtr &node) { return node == nullptr ? "" : GetNodeType(*node); } | ||||
| std::vector<ComputeGraphPtr> NodeUtils::GetAllSubgraphs(const Node &node) { | |||||
| auto op_desc = node.GetOpDesc(); | |||||
| if (op_desc == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to get op desc from node %s ", node.GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| auto root_graph = GraphUtils::FindRootGraph(node.GetOwnerComputeGraph()); | |||||
| if (root_graph == nullptr) { | |||||
| GELOGE(GRAPH_FAILED, "Failed to find root graph from node %s ", node.GetName().c_str()); | |||||
| return {}; | |||||
| } | |||||
| return root_graph->GetAllSubgraphs(); | |||||
| } | |||||
| ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | ComputeGraphPtr NodeUtils::GetSubgraph(const Node &node, uint32_t index) { | ||||
| auto op_desc = node.GetOpDesc(); | auto op_desc = node.GetOpDesc(); | ||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| @@ -1002,4 +1020,23 @@ vector<pair<InDataAnchorPtr, NodePtr>> NodeUtils::GetOutDataNodesWithAnchorByInd | |||||
| } | } | ||||
| ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } | ConstNodePtr NodeUtils::GetNodeFromOperator(const Operator &oprt) { return oprt.GetNode(); } | ||||
| std::string NodeUtils::GetInConstNodeTypeCrossSubgraph(const NodePtr &node) { | |||||
| NodePtr input_node = node; | |||||
| while (input_node != nullptr) { | |||||
| if (input_node->GetType() != DATA) { | |||||
| return input_node->GetType(); | |||||
| } | |||||
| auto owner_graph = input_node->GetOwnerComputeGraph(); | |||||
| auto parent_node = owner_graph->GetParentNode(); | |||||
| if ((parent_node == nullptr) || (kWhileOpTypes.count(parent_node->GetType()) > 0)) { | |||||
| return node->GetType(); // not in subgraph or while subgraph. | |||||
| } | |||||
| input_node = GetParentInput(input_node); | |||||
| } | |||||
| return ""; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -28,7 +28,6 @@ | |||||
| using std::vector; | using std::vector; | ||||
| /*lint -e512 -e737 -e752*/ | |||||
| namespace ge { | namespace ge { | ||||
| const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; | const char OP_DESC_QUANT_PARAMS[] = "quantize_factor"; | ||||
| static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; | static const int CONST_OP_NORMAL_WEIGHT_SIZE = 1; | ||||
| @@ -133,11 +132,11 @@ graphStatus OpDescUtils::GetQuantizeFactorParams(const OpDesc &op_desc, Quantize | |||||
| GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
| OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { | OpDescUtils::SetQuantizeFactorParams(const OpDescPtr &op_desc, const QuantizeFactorParams &quant) { | ||||
| GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | GE_CHK_BOOL_EXEC_INFO(op_desc != nullptr, return GRAPH_FAILED, "op_desc is nullptr"); | ||||
| return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||||
| return op_desc->SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); | |||||
| } | } | ||||
| graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { | graphStatus OpDescUtils::SetQuantizeFactorParams(OpDesc &op_desc, const QuantizeFactorParams &quant) { | ||||
| return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); // lint !e732 | |||||
| return op_desc.SetAttr(OP_DESC_QUANT_PARAMS, GeAttrValue::CreateFrom<QuantizeFactorParams>(quant)); | |||||
| } | } | ||||
| GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { | GeTensorPtr OpDescUtils::MutableWeights(OpDesc &op_desc) { | ||||
| @@ -255,7 +254,7 @@ size_t OpDescUtils::GetNonConstInputsSize(const ge::Node &node) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| } | } | ||||
| return input_num; // lint !e712 | |||||
| return input_num; | |||||
| } else { | } else { | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| node.GetInDataNodes().size() < GetConstInputs(node).size(), | node.GetInDataNodes().size() < GetConstInputs(node).size(), | ||||
| @@ -360,7 +359,7 @@ bool OpDescUtils::IsNonConstInput(const ge::Node &node, const size_t index) { | |||||
| bool ret = false; | bool ret = false; | ||||
| if (index < node.GetAllInDataAnchors().size()) { | if (index < node.GetAllInDataAnchors().size()) { | ||||
| if (NodeUtils::IsAnchorStatusSet(node)) { | if (NodeUtils::IsAnchorStatusSet(node)) { | ||||
| ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); // lint !e712 | |||||
| ret = (ge::AnchorUtils::GetStatus(node.GetInDataAnchor(static_cast<int>(index))) == ANCHOR_DATA); | |||||
| } else { | } else { | ||||
| for (const auto &anchor : node.GetAllInDataAnchors()) { | for (const auto &anchor : node.GetAllInDataAnchors()) { | ||||
| if (anchor->GetIdx() != static_cast<int>(index)) { | if (anchor->GetIdx() != static_cast<int>(index)) { | ||||
| @@ -822,4 +821,3 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus OpDescUtils::SetSubgr | |||||
| return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | return op_desc->SetSubgraphInstanceName(iter->second, subgraph_instance_name); | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| /*lint +e512 +e737 +e752*/ | |||||
| @@ -17,8 +17,10 @@ | |||||
| #include "graph/tuning_utils.h" | #include "graph/tuning_utils.h" | ||||
| #include "../debug/ge_util.h" | #include "../debug/ge_util.h" | ||||
| #include "../debug/ge_op_types.h" | #include "../debug/ge_op_types.h" | ||||
| #include "framework/common/scope_guard.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const std::string peer_node_name_attr = "_peerNodeName"; | const std::string peer_node_name_attr = "_peerNodeName"; | ||||
| const std::string parent_node_name_attr = "_parentNodeName"; | const std::string parent_node_name_attr = "_parentNodeName"; | ||||
| const std::string alias_name_attr = "_aliasName"; | const std::string alias_name_attr = "_aliasName"; | ||||
| @@ -28,6 +30,7 @@ const std::string tuning_subgraph_prefix = "/aicore_subgraph_"; | |||||
| const std::string non_tuning_subgraph_prefix = "/subgraph_"; | const std::string non_tuning_subgraph_prefix = "/subgraph_"; | ||||
| const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END}; | const std::set<std::string> kPartitionOpTypes = {PLACEHOLDER, END}; | ||||
| const std::set<std::string> kExeTypes = {DATA, NETOUTPUT}; | const std::set<std::string> kExeTypes = {DATA, NETOUTPUT}; | ||||
| } // namespace | |||||
| NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; | NodeNametoNodeNameMap TuningUtils::data_2_netoutput_; | ||||
| NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; | NodetoNodeNameMap TuningUtils::data_node_2_netoutput_; | ||||
| NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; | NodetoNodeMap TuningUtils::data_node_2_netoutput_node_; | ||||
| @@ -116,6 +119,10 @@ graphStatus TuningUtils::ConvertGraphToFile(std::vector<ComputeGraphPtr> tuning_ | |||||
| // +---------------+ | // +---------------+ | ||||
| graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { | graphStatus TuningUtils::MakeExeGraph(ComputeGraphPtr &exe_graph, const HelpInfo &help_info) { | ||||
| GE_CHECK_NOTNULL(exe_graph); | GE_CHECK_NOTNULL(exe_graph); | ||||
| // clear graph id | |||||
| GELOGI("TUU:clear [%s] session_graph_id %s", exe_graph->GetName().c_str(), | |||||
| (AttrUtils::SetStr(*exe_graph, ATTR_NAME_SESSION_GRAPH_ID, "") ? "success" : "not success")); | |||||
| // if not make exe, just dump and return | // if not make exe, just dump and return | ||||
| if (!help_info.exe_flag) { | if (!help_info.exe_flag) { | ||||
| DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); | DumpGraphToPath(exe_graph, help_info.index, help_info.is_tuning_graph, help_info.path); | ||||
| @@ -346,7 +353,9 @@ graphStatus TuningUtils::LinkEnd2NetOutput(NodePtr &end_node, NodePtr &out_node) | |||||
| AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) | AnchorPtr end_in_anchor = (end_node->GetInDataAnchor(0)->GetFirstPeerAnchor() == nullptr) | ||||
| ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor()) | ? Anchor::DynamicAnchorCast<Anchor>(end_node->GetInControlAnchor()) | ||||
| : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0)); | : Anchor::DynamicAnchorCast<Anchor>(end_node->GetInDataAnchor(0)); | ||||
| GE_CHECK_NOTNULL(end_in_anchor); | |||||
| auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 | auto src_anchor = end_in_anchor->GetFirstPeerAnchor(); // src_anchor should be only 1 | ||||
| GE_CHECK_NOTNULL(src_anchor); | |||||
| if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { | if (GraphUtils::RemoveEdge(src_anchor, end_in_anchor) != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | GELOGE(FAILED, "TUU:remove end input edge from from %s(%d) to %s(%d) failed. node_name:%s, graph_name:%s", | ||||
| GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | GetNodeNameByAnchor(src_anchor.get()).c_str(), src_anchor->GetIdx(), | ||||
| @@ -447,6 +456,14 @@ graphStatus TuningUtils::HandleEnd(NodePtr &node) { | |||||
| // part 2 | // part 2 | ||||
| graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) { | graphStatus TuningUtils::ConvertFileToGraph(const map<int64_t, string> &options, ge::Graph &graph) { | ||||
| std::function<void()> callback = [&]() { | |||||
| data_2_netoutput_.clear(); | |||||
| data_node_2_netoutput_.clear(); | |||||
| data_node_2_netoutput_node_.clear(); | |||||
| netoutput_nodes_.clear(); | |||||
| merged_graph_nodes_.clear(); | |||||
| }; | |||||
| GE_MAKE_GUARD(release, callback); | |||||
| // 1. get all subgraph object | // 1. get all subgraph object | ||||
| std::vector<ComputeGraphPtr> graphs; | std::vector<ComputeGraphPtr> graphs; | ||||
| // options format like {index:"subgraph_path"} | // options format like {index:"subgraph_path"} | ||||
| @@ -666,7 +683,9 @@ graphStatus TuningUtils::GetInAndOutAnchorPair(NodePtr &data_node, NodePtr &out_ | |||||
| GE_CHECK_NOTNULL(src_anchor); | GE_CHECK_NOTNULL(src_anchor); | ||||
| auto src_node = src_anchor->GetOwnerNode(); | auto src_node = src_anchor->GetOwnerNode(); | ||||
| GE_CHECK_NOTNULL(src_node); | GE_CHECK_NOTNULL(src_node); | ||||
| if (src_node->GetName() == netoutput_input_name && src_anchor->GetIdx() == parent_node_anchor_index) { | |||||
| std::string src_node_name = src_node->GetName(); | |||||
| if (src_node_name.find(netoutput_input_name) != src_node_name.npos && | |||||
| src_anchor->GetIdx() == parent_node_anchor_index) { | |||||
| dest_in_anchor = in_anchor; | dest_in_anchor = in_anchor; | ||||
| src_out_anchor = src_anchor; | src_out_anchor = src_anchor; | ||||
| GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", | GELOGD("TUU:get out node:%s 's in anchor(%d) src_node:%s 's out anchor(%d) related with data node:%s", | ||||
| @@ -39,7 +39,7 @@ ge_protobuf_generate(ge PROTO_HEADER_SRCS PROTO_HEADER_HDRS ${PROTO_HEADER_LIST} | |||||
| # include directories | # include directories | ||||
| include_directories(${CMAKE_CURRENT_LIST_DIR}) | include_directories(${CMAKE_CURRENT_LIST_DIR}) | ||||
| include_directories(${GE_SOURCE_DIR}) | include_directories(${GE_SOURCE_DIR}) | ||||
| include_directories(${GE_SOURCE_DIR}/src) | |||||
| include_directories(${GE_SOURCE_DIR}/src/ge) | |||||
| include_directories(${GE_SOURCE_DIR}/src/ge/analyzer) | include_directories(${GE_SOURCE_DIR}/src/ge/analyzer) | ||||
| include_directories(${GE_SOURCE_DIR}/inc) | include_directories(${GE_SOURCE_DIR}/inc) | ||||
| include_directories(${GE_SOURCE_DIR}/inc/common/util) | include_directories(${GE_SOURCE_DIR}/inc/common/util) | ||||
| @@ -109,6 +109,8 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/manager/graph_mem_allocator.cc" | "graph/manager/graph_mem_allocator.cc" | ||||
| "graph/manager/graph_caching_allocator.cc" | "graph/manager/graph_caching_allocator.cc" | ||||
| "graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
| "graph/manager/host_mem_manager.cc" | |||||
| "graph/manager/memory_api.cc" | |||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| "graph/manager/rdma_pool_allocator.cc" | "graph/manager/rdma_pool_allocator.cc" | ||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| @@ -127,6 +129,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/partition/dynamic_shape_partition.cc" | "graph/partition/dynamic_shape_partition.cc" | ||||
| "graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
| "graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
| "graph/partition/stage_partition.cc" | |||||
| "graph/passes/*.cc" | "graph/passes/*.cc" | ||||
| "graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| @@ -200,6 +203,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "model/ge_root_model.cc" | "model/ge_root_model.cc" | ||||
| "omm/csa_interact.cc" | "omm/csa_interact.cc" | ||||
| "opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
| "opskernel_manager/ops_kernel_builder_manager.cc" | |||||
| "session/inner_session.cc" | "session/inner_session.cc" | ||||
| "session/session_manager.cc" | "session/session_manager.cc" | ||||
| "single_op/*.cc" | "single_op/*.cc" | ||||
| @@ -283,6 +287,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/manager/graph_manager.cc" | "graph/manager/graph_manager.cc" | ||||
| "graph/manager/graph_manager_utils.cc" | "graph/manager/graph_manager_utils.cc" | ||||
| "graph/manager/graph_mem_allocator.cc" | "graph/manager/graph_mem_allocator.cc" | ||||
| "graph/manager/host_mem_manager.cc" | |||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| "graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| @@ -296,6 +301,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "graph/partition/dynamic_shape_partition.cc" | "graph/partition/dynamic_shape_partition.cc" | ||||
| "graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
| "graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
| "graph/partition/stage_partition.cc" | |||||
| "graph/passes/*.cc" | "graph/passes/*.cc" | ||||
| "graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| @@ -349,6 +355,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "model/ge_root_model.cc" | "model/ge_root_model.cc" | ||||
| "omm/csa_interact.cc" | "omm/csa_interact.cc" | ||||
| "opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
| "opskernel_manager/ops_kernel_builder_manager.cc" | |||||
| "session/inner_session.cc" | "session/inner_session.cc" | ||||
| "session/session_manager.cc" | "session/session_manager.cc" | ||||
| "single_op/*.cc" | "single_op/*.cc" | ||||
| @@ -75,9 +75,8 @@ Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| std::lock_guard<std::recursive_mutex> lg(mutex_); | std::lock_guard<std::recursive_mutex> lg(mutex_); | ||||
| auto iter = graph_infos_.find(session_id); | auto iter = graph_infos_.find(session_id); | ||||
| if (iter == graph_infos_.end()) { | if (iter == graph_infos_.end()) { | ||||
| auto p = new (std::nothrow) GraphInfo(); | |||||
| GE_CHECK_NOTNULL(p); | |||||
| std::shared_ptr<GraphInfo> graph_info(p); | |||||
| std::shared_ptr<GraphInfo> graph_info(new (std::nothrow) GraphInfo()); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| std::map<uint64_t, std::shared_ptr<GraphInfo>> graph_map; | std::map<uint64_t, std::shared_ptr<GraphInfo>> graph_map; | ||||
| graph_map[graph_id] = graph_info; | graph_map[graph_id] = graph_info; | ||||
| graph_info->session_id = session_id; | graph_info->session_id = session_id; | ||||
| @@ -86,9 +85,8 @@ Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| } else { | } else { | ||||
| auto iter1 = (iter->second).find(graph_id); | auto iter1 = (iter->second).find(graph_id); | ||||
| if (iter1 == (iter->second).end()) { | if (iter1 == (iter->second).end()) { | ||||
| auto p = new (std::nothrow) GraphInfo(); | |||||
| GE_CHECK_NOTNULL(p); | |||||
| std::shared_ptr<GraphInfo> graph_info(p); | |||||
| std::shared_ptr<GraphInfo> graph_info(new (std::nothrow) GraphInfo()); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| graph_info->session_id = session_id; | graph_info->session_id = session_id; | ||||
| graph_info->graph_id = graph_id; | graph_info->graph_id = graph_id; | ||||
| (iter->second).insert({graph_id, graph_info}); | (iter->second).insert({graph_id, graph_info}); | ||||
| @@ -100,7 +98,14 @@ Status Analyzer::BuildJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| } | } | ||||
| ge::Status Analyzer::Initialize() { | ge::Status Analyzer::Initialize() { | ||||
| ClearHistoryFile(); | |||||
| // Initialize file | |||||
| string real_path = RealPath(kFilePath.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(FAILED, "File path is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| json_file_name_ = real_path + "/" + kAnalyzeFile; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -138,6 +143,7 @@ void Analyzer::DestroyGraphJsonObject(uint64_t session_id, uint64_t graph_id) { | |||||
| if (iter1 == (iter->second).end()) { | if (iter1 == (iter->second).end()) { | ||||
| GELOGW("Can not find the graph json object by session_id[%lu] and graph_id[%lu]. Do nothing.", session_id, | GELOGW("Can not find the graph json object by session_id[%lu] and graph_id[%lu]. Do nothing.", session_id, | ||||
| graph_id); | graph_id); | ||||
| return; | |||||
| } | } | ||||
| (iter->second).erase(iter1); | (iter->second).erase(iter1); | ||||
| } | } | ||||
| @@ -174,15 +180,8 @@ ge::Status Analyzer::CreateAnalyzerFile() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GELOGD("start to create analyzer file!"); | GELOGD("start to create analyzer file!"); | ||||
| // Check whether the manifest exists, if not, create it. | |||||
| string real_path = RealPath(kFilePath.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(FAILED, "File path is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| std::lock_guard<std::mutex> lg(file_mutex_); | std::lock_guard<std::mutex> lg(file_mutex_); | ||||
| json_file_name_ = real_path + "/" + kAnalyzeFile; | |||||
| GELOGD("Created analyzer file:[%s]", json_file_name_.c_str()); | |||||
| int fd = open(json_file_name_.c_str(), O_WRONLY | O_CREAT | O_TRUNC, kFileAuthority); | int fd = open(json_file_name_.c_str(), O_WRONLY | O_CREAT | O_TRUNC, kFileAuthority); | ||||
| if (fd < 0) { | if (fd < 0) { | ||||
| GELOGE(INTERNAL_ERROR, "Fail to open the file: %s.", json_file_name_.c_str()); | GELOGE(INTERNAL_ERROR, "Fail to open the file: %s.", json_file_name_.c_str()); | ||||
| @@ -198,25 +197,27 @@ ge::Status Analyzer::CreateAnalyzerFile() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ge::Status Analyzer::SaveAnalyzerDataToFile() { | |||||
| ge::Status Analyzer::SaveAnalyzerDataToFile(uint64_t session_id, uint64_t graph_id) { | |||||
| GELOGD("start to save analyze file!"); | GELOGD("start to save analyze file!"); | ||||
| auto graph_info = GetJsonObject(session_id, graph_id); | |||||
| GE_CHECK_NOTNULL(graph_info); | |||||
| if (graph_info->op_info.size() == 0) { | |||||
| GELOGD("session_id:%lu graph_id:%lu does not owner op info, break it!", session_id, graph_id); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::lock_guard<std::mutex> lg(file_mutex_); | std::lock_guard<std::mutex> lg(file_mutex_); | ||||
| json_file_.open(json_file_name_, std::ios::out); | |||||
| json_file_.open(json_file_name_, std::ios::app); | |||||
| if (!json_file_.is_open()) { | if (!json_file_.is_open()) { | ||||
| GELOGE(FAILED, "analyzer file does not exist[%s]", json_file_name_.c_str()); | GELOGE(FAILED, "analyzer file does not exist[%s]", json_file_name_.c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| std::lock_guard<std::recursive_mutex> lk(mutex_); | |||||
| for (auto &ele : graph_infos_) { | |||||
| for (auto &ele2 : ele.second) { | |||||
| json jsn; | |||||
| GraphInfoToJson(jsn, *(ele2.second)); | |||||
| json_file_ << jsn.dump(kJsonDumpLevel) << std::endl; | |||||
| } | |||||
| } | |||||
| json jsn; | |||||
| GraphInfoToJson(jsn, *graph_info); | |||||
| json_file_ << jsn.dump(kJsonDumpLevel) << std::endl; | |||||
| json_file_.close(); | json_file_.close(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -237,13 +238,7 @@ ge::Status Analyzer::DoAnalyze(DataInfo &data_info) { | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| // create json file | // create json file | ||||
| status = CreateAnalyzerFile(); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "create analyzer file failed!"); | |||||
| return status; | |||||
| } | |||||
| // save data to file | |||||
| return SaveAnalyzerDataToFile(); | |||||
| return CreateAnalyzerFile(); | |||||
| } | } | ||||
| ge::Status Analyzer::SaveOpInfo(ge::OpDescPtr desc, DataInfo &data_info, | ge::Status Analyzer::SaveOpInfo(ge::OpDescPtr desc, DataInfo &data_info, | ||||
| @@ -156,6 +156,14 @@ class Analyzer { | |||||
| */ | */ | ||||
| ge::Status DoAnalyze(analyzer::DataInfo &data_info); | ge::Status DoAnalyze(analyzer::DataInfo &data_info); | ||||
| /** | |||||
| * @ingroup ge | |||||
| * @brief: Buff analyzed data and output to json file | |||||
| * @param [in]: session id , graph id | |||||
| * @return: 0: SUCCESS other: FAILED | |||||
| */ | |||||
| ge::Status SaveAnalyzerDataToFile(uint64_t session_id, uint64_t graph_id); | |||||
| Analyzer(const Analyzer &) = delete; | Analyzer(const Analyzer &) = delete; | ||||
| Analyzer &operator=(const Analyzer &) = delete; | Analyzer &operator=(const Analyzer &) = delete; | ||||
| Analyzer(Analyzer &&) = delete; | Analyzer(Analyzer &&) = delete; | ||||
| @@ -166,7 +174,6 @@ class Analyzer { | |||||
| void OpInfoToJson(nlohmann::json &j, const analyzer::OpInfo &op_info); | void OpInfoToJson(nlohmann::json &j, const analyzer::OpInfo &op_info); | ||||
| void GraphInfoToJson(nlohmann::json &j, const analyzer::GraphInfo &graph_info); | void GraphInfoToJson(nlohmann::json &j, const analyzer::GraphInfo &graph_info); | ||||
| ge::Status SaveAnalyzerDataToFile(); | |||||
| ge::Status SaveOpInfo(ge::OpDescPtr desc, analyzer::DataInfo &data_info, | ge::Status SaveOpInfo(ge::OpDescPtr desc, analyzer::DataInfo &data_info, | ||||
| std::shared_ptr<analyzer::GraphInfo> graph_info); | std::shared_ptr<analyzer::GraphInfo> graph_info); | ||||
| @@ -324,10 +324,17 @@ Status aclgrphProfStop(aclgrphProfConfig *profiler_config) { | |||||
| return GE_PROF_NOT_INIT; | return GE_PROF_NOT_INIT; | ||||
| } | } | ||||
| Status ret = ProfStopProfiling(&profiler_config->config); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Stop profiling failed, prof result = %d", ret); | |||||
| return ret; | |||||
| for (uint32_t i = 0; i < profiler_config->config.devNums; i++) { | |||||
| uint64_t data_type_config; | |||||
| Status status = ProfGetDataTypeConfig(profiler_config->config.devIdList[i], data_type_config); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "Prof get data type config failed, prof result = %d", status); | |||||
| return status; | |||||
| } | |||||
| if (data_type_config != profiler_config->config.dataTypeConfig) { | |||||
| GELOGE(FAILED, "data type config verify failed"); | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| std::vector<string> prof_params; | std::vector<string> prof_params; | ||||
| @@ -344,12 +351,18 @@ Status aclgrphProfStop(aclgrphProfConfig *profiler_config) { | |||||
| command.module_index = profiler_config->config.dataTypeConfig; | command.module_index = profiler_config->config.dataTypeConfig; | ||||
| GELOGI("Profiling will stop, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(), | GELOGI("Profiling will stop, device nums:%s , deviceID:[%s], data type config: 0x%llx", prof_params[0].c_str(), | ||||
| prof_params[kDeviceListIndex].c_str(), command.module_index); | prof_params[kDeviceListIndex].c_str(), command.module_index); | ||||
| ret = graph_loader.CommandHandle(command); | |||||
| Status ret = graph_loader.CommandHandle(command); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Handle profiling command failed"); | GELOGE(ret, "Handle profiling command failed"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| ret = ProfStopProfiling(&profiler_config->config); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Stop profiling failed, prof result = %d", ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGI("Successfully execute GraphProfStopProfiling."); | GELOGI("Successfully execute GraphProfStopProfiling."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -70,10 +70,9 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libregister \ | libregister \ | ||||
| libge_compiler \ | libge_compiler \ | ||||
| libge_common \ | libge_common \ | ||||
| libmsprof \ | |||||
| stub/libascend_hal | |||||
| libmsprof | |||||
| LOCAL_STATIC_LIBRARIES := libmsprofiler | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| @@ -108,7 +107,6 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libge_common \ | libge_common \ | ||||
| libmsprof | libmsprof | ||||
| LOCAL_STATIC_LIBRARIES := libmsprofiler | |||||
| LOCAL_LDFLAGS := -lrt -ldl | LOCAL_LDFLAGS := -lrt -ldl | ||||
| LOCAL_CFLAGS += \ | LOCAL_CFLAGS += \ | ||||
| @@ -55,9 +55,26 @@ Status FileSaver::OpenFile(int32_t &fd, const std::string &file_path) { | |||||
| Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { | Status FileSaver::WriteData(const void *data, uint32_t size, int32_t fd) { | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size == 0 || data == nullptr, return PARAM_INVALID); | ||||
| mmSsize_t write_count; | |||||
| uint32_t size_2g = ((uint32_t)0x1 << 31); | |||||
| uint32_t size_1g = ((uint32_t)0x1 << 30); | |||||
| // Write data | // Write data | ||||
| int32_t write_count = mmWrite(fd, const_cast<void *>(data), size); | |||||
| if (size > size_2g) { | |||||
| auto seek = reinterpret_cast<uint8_t *>(const_cast<void *>(data)); | |||||
| while (size > size_1g) { | |||||
| write_count = mmWrite(fd, reinterpret_cast<void *>(seek), size_1g); | |||||
| if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | |||||
| GELOGE(FAILED, "Write data failed. mmpa_errorno = %d, %s", write_count, strerror(errno)); | |||||
| return FAILED; | |||||
| } | |||||
| size -= size_1g; | |||||
| seek += size_1g; | |||||
| } | |||||
| write_count = mmWrite(fd, reinterpret_cast<void *>(seek), size); | |||||
| } else { | |||||
| write_count = mmWrite(fd, const_cast<void *>(data), size); | |||||
| } | |||||
| // -1: Failed to write to file; - 2: Illegal parameter | // -1: Failed to write to file; - 2: Illegal parameter | ||||
| if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | if (write_count == EN_INVALID_PARAM || write_count == EN_ERROR) { | ||||
| GELOGE(FAILED, "Write data failed. mmpa_errorno = %d, %s", write_count, strerror(errno)); | GELOGE(FAILED, "Write data failed. mmpa_errorno = %d, %s", write_count, strerror(errno)); | ||||
| @@ -117,6 +134,7 @@ Status FileSaver::SaveWithFileHeader(const std::string &file_path, const ModelFi | |||||
| WriteData(static_cast<const void *>(&model_partition_table), table_size, fd) != SUCCESS, ret = FAILED; break); | WriteData(static_cast<const void *>(&model_partition_table), table_size, fd) != SUCCESS, ret = FAILED; break); | ||||
| // Write partition data | // Write partition data | ||||
| for (const auto &partitionData : partition_datas) { | for (const auto &partitionData : partition_datas) { | ||||
| GELOGI("GC:size[%zu]", partitionData.size); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| WriteData(static_cast<const void *>(partitionData.data), partitionData.size, fd) != SUCCESS, ret = FAILED; | WriteData(static_cast<const void *>(partitionData.data), partitionData.size, fd) != SUCCESS, ret = FAILED; | ||||
| break); | break); | ||||
| @@ -1,248 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // File: pb2json.h | |||||
| // Description: This imply file for protobuf message and json interconversion | |||||
| #include "common/convert/pb2json.h" | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "securec.h" | |||||
| #include "framework/common/fmk_types.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| using std::set; | |||||
| using std::string; | |||||
| namespace ge { | |||||
| namespace { | |||||
| const int kSignificantDigits = 10; | |||||
| } | |||||
| // JSON parses non utf8 character throwing exceptions, so some fields need to be shielded through black fields | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void Pb2Json::Message2Json(const ProtobufMsg &message, | |||||
| const set<string> &black_fields, Json &json, | |||||
| bool enum2str) { | |||||
| auto descriptor = message.GetDescriptor(); | |||||
| auto reflection = message.GetReflection(); | |||||
| if (descriptor == nullptr || reflection == nullptr) { | |||||
| return; | |||||
| } | |||||
| auto count = descriptor->field_count(); | |||||
| for (auto i = 0; i < count; ++i) { | |||||
| const auto field = descriptor->field(i); | |||||
| if (field == nullptr) { | |||||
| return; | |||||
| } | |||||
| // Do not display weight data | |||||
| if (black_fields.find(field->name()) != black_fields.end()) { | |||||
| continue; | |||||
| } | |||||
| if (field->is_repeated()) { | |||||
| if (reflection->FieldSize(message, field) > 0) { | |||||
| RepeatedMessage2Json(message, field, reflection, black_fields, json[field->name()], enum2str); | |||||
| } | |||||
| continue; | |||||
| } | |||||
| if (!reflection->HasField(message, field)) { | |||||
| continue; | |||||
| } | |||||
| OneField2Json(message, field, reflection, black_fields, json, enum2str); | |||||
| } | |||||
| } | |||||
| void Pb2Json::OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||||
| bool enum2str) { | |||||
| switch (field->type()) { | |||||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||||
| const ProtobufMsg &tmp_message = reflection->GetMessage(message, field); | |||||
| if (0 != tmp_message.ByteSize()) { | |||||
| Message2Json(tmp_message, black_fields, json[field->name()], enum2str); | |||||
| } | |||||
| break; | |||||
| } | |||||
| case ProtobufFieldDescriptor::TYPE_BOOL: | |||||
| json[field->name()] = reflection->GetBool(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_ENUM: { | |||||
| auto *enum_value_desc = reflection->GetEnum(message, field); | |||||
| Enum2Json(enum_value_desc, field, enum2str, json); | |||||
| break; | |||||
| } | |||||
| case ProtobufFieldDescriptor::TYPE_INT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED32: | |||||
| json[field->name()] = reflection->GetInt32(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED32: | |||||
| json[field->name()] = reflection->GetUInt32(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_INT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED64: | |||||
| json[field->name()] = reflection->GetInt64(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED64: | |||||
| json[field->name()] = reflection->GetUInt64(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||||
| char str[kSignificantDigits]; | |||||
| if (sprintf_s(str, kSignificantDigits, "%g", reflection->GetFloat(message, field)) != -1) { | |||||
| json[field->name()] = str; | |||||
| } else { | |||||
| json[field->name()] = reflection->GetFloat(message, field); | |||||
| } | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_STRING: | |||||
| json[field->name()] = reflection->GetString(message, field); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_BYTES: { | |||||
| string field_name = field->name(); | |||||
| string type_bytes = reflection->GetString(message, field); | |||||
| json[field_name] = TypeBytes2String(field_name, type_bytes); | |||||
| break; | |||||
| } | |||||
| default: | |||||
| break; | |||||
| } | |||||
| } | |||||
| string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||||
| if (field_name != "offset") { | |||||
| return type_bytes; | |||||
| } | |||||
| string result = ""; | |||||
| for (char temp_value : type_bytes) { | |||||
| uint8_t *value = 0; | |||||
| value = reinterpret_cast<uint8_t *>(&temp_value); | |||||
| char str[kSignificantDigits]; | |||||
| if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) { | |||||
| GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); | |||||
| continue; | |||||
| } | |||||
| result += str; | |||||
| } | |||||
| return result; | |||||
| } | |||||
| void Pb2Json::RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const set<string> &black_fields, Json &json, | |||||
| bool enum2str) { | |||||
| if ((field == nullptr) || (reflection == nullptr)) { | |||||
| Message2Json(message, black_fields, json, enum2str); | |||||
| return; | |||||
| } | |||||
| for (auto i = 0; i < reflection->FieldSize(message, field); ++i) { | |||||
| Json tmp_json; | |||||
| switch (field->type()) { | |||||
| case ProtobufFieldDescriptor::TYPE_MESSAGE: { | |||||
| const ProtobufMsg &tmp_message = reflection->GetRepeatedMessage(message, field, i); | |||||
| if (0 != tmp_message.ByteSize()) { | |||||
| Message2Json(tmp_message, black_fields, tmp_json, enum2str); | |||||
| } | |||||
| } break; | |||||
| case ProtobufFieldDescriptor::TYPE_BOOL: | |||||
| tmp_json = reflection->GetRepeatedBool(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_ENUM: { | |||||
| auto *enum_value_desc = reflection->GetRepeatedEnum(message, field, i); | |||||
| RepeatedEnum2Json(enum_value_desc, enum2str, tmp_json); | |||||
| } break; | |||||
| case ProtobufFieldDescriptor::TYPE_INT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED32: | |||||
| tmp_json = reflection->GetRepeatedInt32(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT32: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED32: | |||||
| tmp_json = reflection->GetRepeatedUInt32(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_INT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_SFIXED64: | |||||
| tmp_json = reflection->GetRepeatedInt64(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_UINT64: | |||||
| case ProtobufFieldDescriptor::TYPE_FIXED64: | |||||
| tmp_json = reflection->GetRepeatedUInt64(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_FLOAT: | |||||
| tmp_json = reflection->GetRepeatedFloat(message, field, i); | |||||
| break; | |||||
| case ProtobufFieldDescriptor::TYPE_STRING: | |||||
| case ProtobufFieldDescriptor::TYPE_BYTES: | |||||
| tmp_json = reflection->GetRepeatedString(message, field, i); | |||||
| break; | |||||
| default: | |||||
| break; | |||||
| } | |||||
| json += tmp_json; | |||||
| } | |||||
| } | |||||
| void Pb2Json::Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | |||||
| bool enum2str, Json &json) { | |||||
| if (enum_value_desc != nullptr) { | |||||
| if (field == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (enum2str) { | |||||
| json[field->name()] = enum_value_desc->name(); | |||||
| } else { | |||||
| json[field->name()] = enum_value_desc->number(); | |||||
| } | |||||
| } | |||||
| } | |||||
| void Pb2Json::RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json) { | |||||
| if (enum_value_desc != nullptr) { | |||||
| if (enum2str) { | |||||
| json = enum_value_desc->name(); | |||||
| } else { | |||||
| json = enum_value_desc->number(); | |||||
| } | |||||
| } | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,68 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| // File: pb2json.h | |||||
| // Description: This header file for protobuf message and json interconversion | |||||
| #ifndef GE_COMMON_CONVERT_PB2JSON_H_ | |||||
| #define GE_COMMON_CONVERT_PB2JSON_H_ | |||||
| #include <functional> | |||||
| #include <memory> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "google/protobuf/descriptor.h" | |||||
| #include "google/protobuf/message.h" | |||||
| #include "nlohmann/json.hpp" | |||||
| namespace ge { | |||||
| using Json = nlohmann::json; | |||||
| using ProtobufMsg = ::google::protobuf::Message; | |||||
| using ProtobufReflection = ::google::protobuf::Reflection; | |||||
| using ProtobufFieldDescriptor = ::google::protobuf::FieldDescriptor; | |||||
| using ProtobufDescriptor = ::google::protobuf::Descriptor; | |||||
| using ProtobufEnumValueDescriptor = ::google::protobuf::EnumValueDescriptor; | |||||
| class Pb2Json { | |||||
| public: | |||||
| /** | |||||
| * @ingroup domi_omg | |||||
| * @brief Transfer protobuf object to JSON object | |||||
| * @param [out] json Converted JSON object | |||||
| * @return void success | |||||
| * @author | |||||
| */ | |||||
| static void Message2Json(const ProtobufMsg &message, const std::set<std::string> &black_fields, Json &json, | |||||
| bool enum2str = false); | |||||
| protected: | |||||
| static void RepeatedMessage2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const std::set<std::string> &black_fields, | |||||
| Json &json, bool enum2str); | |||||
| static void Enum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, const ProtobufFieldDescriptor *field, | |||||
| bool enum2str, Json &json); | |||||
| static void RepeatedEnum2Json(const ProtobufEnumValueDescriptor *enum_value_desc, bool enum2str, Json &json); | |||||
| static void OneField2Json(const ProtobufMsg &message, const ProtobufFieldDescriptor *field, | |||||
| const ProtobufReflection *reflection, const std::set<std::string> &black_fields, Json &json, | |||||
| bool enum2str); | |||||
| static std::string TypeBytes2String(std::string &field_name, std::string &type_bytes); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_CONVERT_PB2JSON_H_ | |||||
| @@ -201,7 +201,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY const std::string &DumpProperti | |||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpOpSwitch( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void DumpProperties::SetDumpOpSwitch( | ||||
| const std::string &dump_op_switch) { | |||||
| const std::string dump_op_switch) { | |||||
| dump_op_switch_ = dump_op_switch; | dump_op_switch_ = dump_op_switch; | ||||
| } | } | ||||
| @@ -65,7 +65,7 @@ class DumpProperties { | |||||
| const std::string &GetDumpStatus() const; | const std::string &GetDumpStatus() const; | ||||
| void SetDumpOpSwitch(const std::string &dump_op_switch); | |||||
| void SetDumpOpSwitch(const std::string dump_op_switch); | |||||
| const std::string &GetDumpOpSwitch() const; | const std::string &GetDumpOpSwitch() const; | ||||
| @@ -94,13 +94,6 @@ void TBEPluginManager::ProcessSoFullName(vector<string> &file_list, string &caff | |||||
| full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), | full_name.compare(full_name.size() - caffe_parser_so_suff.size(), caffe_parser_so_suff.size(), | ||||
| caffe_parser_so_suff) == 0) { | caffe_parser_so_suff) == 0) { | ||||
| caffe_parser_path = full_name; | caffe_parser_path = full_name; | ||||
| } else if ((full_name.size() >= aicpu_so_suff.size() && | |||||
| full_name.compare(full_name.size() - aicpu_so_suff.size(), aicpu_so_suff.size(), aicpu_so_suff) == 0) || | |||||
| (full_name.size() >= aicpu_host_so_suff.size() && | |||||
| full_name.compare(full_name.size() - aicpu_host_so_suff.size(), aicpu_host_so_suff.size(), | |||||
| aicpu_host_so_suff) == 0)) { | |||||
| // aicpu so, Put the file path into the omgcontext and save into the model in the builder stage. | |||||
| domi::GetContext().aicpu_op_run_paths.push_back(full_name); | |||||
| } else { | } else { | ||||
| // Save parser so path into file_list vector | // Save parser so path into file_list vector | ||||
| file_list.push_back(full_name); | file_list.push_back(full_name); | ||||
| @@ -230,39 +223,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::LoadPlug | |||||
| } | } | ||||
| } | } | ||||
| Status TBEPluginManager::CheckCustomAiCpuOpLib() { | |||||
| std::vector<std::string> vec_op_type; | |||||
| domi::OpRegistry::Instance()->GetOpTypeByImplyType(vec_op_type, domi::ImplyType::CUSTOM); | |||||
| for (size_t i = 0; i < vec_op_type.size(); i++) { | |||||
| bool aicpu_so_exist = false; | |||||
| std::string ai_cpu_so_name = "lib" + vec_op_type[i] + "_aicpu.so"; | |||||
| for (size_t j = 0; j < domi::GetContext().aicpu_op_run_paths.size(); j++) { | |||||
| string bin_file_path = domi::GetContext().aicpu_op_run_paths[j]; | |||||
| if (bin_file_path.size() >= ai_cpu_so_name.size() && | |||||
| bin_file_path.compare(bin_file_path.size() - ai_cpu_so_name.size(), ai_cpu_so_name.size(), ai_cpu_so_name) == | |||||
| 0) { | |||||
| aicpu_so_exist = true; | |||||
| break; | |||||
| } | |||||
| } | |||||
| if (!aicpu_so_exist) { | |||||
| GELOGE(FAILED, "Can't find aicpu run so(%s), please check the plugin path!", ai_cpu_so_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::InitPreparation( | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void TBEPluginManager::InitPreparation( | ||||
| const std::map<string, string> &options) { | const std::map<string, string> &options) { | ||||
| options_.insert(options.begin(), options.end()); | options_.insert(options.begin(), options.end()); | ||||
| // Load TBE plugin | // Load TBE plugin | ||||
| TBEPluginManager::Instance().LoadCustomOpLib(); | TBEPluginManager::Instance().LoadCustomOpLib(); | ||||
| Status ret = CheckCustomAiCpuOpLib(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Check custom aicpu run so failed!"); | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -62,7 +62,6 @@ class TBEPluginManager { | |||||
| static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | static void GetPluginSoFileList(const string &path, vector<string> &file_list, string &caffe_parser_path); | ||||
| static void GetCustomOpPath(std::string &customop_path); | static void GetCustomOpPath(std::string &customop_path); | ||||
| void LoadCustomOpLib(); | void LoadCustomOpLib(); | ||||
| static Status CheckCustomAiCpuOpLib(); | |||||
| SoHandlesVec handles_vec_; | SoHandlesVec handles_vec_; | ||||
| static std::map<string, string> options_; | static std::map<string, string> options_; | ||||
| @@ -71,7 +71,10 @@ GE_COMMON_LOCAL_C_INCLUDES := \ | |||||
| $(TOPDIR)third_party/openssl/include/x86/include \ | $(TOPDIR)third_party/openssl/include/x86/include \ | ||||
| $(TOPDIR)framework/domi \ | $(TOPDIR)framework/domi \ | ||||
| $(TOPDIR)framework/domi/common \ | $(TOPDIR)framework/domi/common \ | ||||
| $(TOPDIR)framework/domi/common/op | |||||
| $(TOPDIR)framework/domi/common/op \ | |||||
| $(TOPDIR)graphengine/ge \ | |||||
| $(TOPDIR)graphengine/ge/common \ | |||||
| $(TOPDIR)graphengine/ge/common/op \ | |||||
| #compile host libge_common | #compile host libge_common | ||||
| include $(CLEAR_VARS) | include $(CLEAR_VARS) | ||||
| @@ -1497,7 +1497,6 @@ Status ModelCacheHelper::ParseMemResourceFromJson(const Json &json, map<rtMemTyp | |||||
| } | } | ||||
| mem_resource.clear(); | mem_resource.clear(); | ||||
| for (const Json &mem_resource_json : json) { | for (const Json &mem_resource_json : json) { | ||||
| MemResource var_addr_mgr; | |||||
| try { | try { | ||||
| rtMemType_t mem_type = mem_resource_json[kMemType].get<rtMemType_t>(); | rtMemType_t mem_type = mem_resource_json[kMemType].get<rtMemType_t>(); | ||||
| uint64_t var_mem_size = mem_resource_json[kVarMemSize].get<int64_t>(); | uint64_t var_mem_size = mem_resource_json[kVarMemSize].get<int64_t>(); | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "framework/common/op/attr_value_util.h" | #include "framework/common/op/attr_value_util.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "register/register_types.h" | |||||
| namespace ge { | namespace ge { | ||||
| #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ | #define DEFINE_SET_ATTR_VALUE_ONE(ARG_TYPE, FIELD) \ | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "framework/common/op/attr_value_util.h" | #include "framework/common/op/attr_value_util.h" | ||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "graph/anchor.h" | #include "graph/anchor.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| @@ -353,20 +353,18 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::StopProf | |||||
| } | } | ||||
| uint64_t module = GetProfilingModule(); | uint64_t module = GetProfilingModule(); | ||||
| int32_t device_num = static_cast<int32_t>(device_id_.size()); | int32_t device_num = static_cast<int32_t>(device_id_.size()); | ||||
| uint32_t *device_id_ptr = new (std::nothrow) uint32_t[device_num]; | |||||
| auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]); | |||||
| if (device_id_ptr == nullptr) { | if (device_id_ptr == nullptr) { | ||||
| GELOGE(FAILED, "Stop profiling device id ptr is null."); | |||||
| GELOGE(FAILED, "Stop profiling: device id ptr is null."); | |||||
| return; | return; | ||||
| } | } | ||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]); | device_id_ptr[i] = static_cast<uint32_t>(device_id_[i]); | ||||
| } | } | ||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr); | |||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); | GELOGW("Call rtProfilerStop failed, ret:%d", rt_ret); | ||||
| } | } | ||||
| delete[] device_id_ptr; | |||||
| device_id_ptr = nullptr; | |||||
| for (size_t i = 0; i < prof_handle_vec_.size(); ++i) { | for (size_t i = 0; i < prof_handle_vec_.size(); ++i) { | ||||
| int result = ProfMgrStop(prof_handle_vec_[i]); | int result = ProfMgrStop(prof_handle_vec_[i]); | ||||
| @@ -732,23 +730,21 @@ ProfilingManager::ProfStartProfiling(uint64_t module, const std::map<std::string | |||||
| GELOGE(FAILED, "Prof start parse param failed."); | GELOGE(FAILED, "Prof start parse param failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto *device_id = new (std::nothrow) uint32_t[device_num]; | |||||
| if (device_id == nullptr) { | |||||
| GELOGE(FAILED, "Prof start parse param failed."); | |||||
| auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]); | |||||
| if (device_id_ptr == nullptr) { | |||||
| GELOGE(FAILED, "Prof start: device id ptr is null."); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| device_id[i] = static_cast<uint32_t>(device_list[i]); | |||||
| device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | |||||
| } | } | ||||
| GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num); | GELOGI("Runtime config param: 0x%llx, device num: %d.", module, device_num); | ||||
| rtError_t rt_ret = rtProfilerStart(module, device_num, device_id); | |||||
| rtError_t rt_ret = rtProfilerStart(module, device_num, device_id_ptr.get()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| delete[] device_id; | |||||
| GELOGE(FAILED, "Runtime profiler config proc failed."); | GELOGE(FAILED, "Runtime profiler config proc failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| delete[] device_id; | |||||
| device_id = nullptr; | |||||
| if ((module & PROF_MODEL_EXECUTE_MASK) == PROF_MODEL_EXECUTE_MASK) { | if ((module & PROF_MODEL_EXECUTE_MASK) == PROF_MODEL_EXECUTE_MASK) { | ||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| if (std::find(device_id_.begin(), device_id_.end(), device_list[i]) == device_id_.end()) { | if (std::find(device_id_.begin(), device_id_.end(), device_list[i]) == device_id_.end()) { | ||||
| @@ -776,23 +772,20 @@ ProfilingManager::ProfStopProfiling(uint64_t module, const std::map<std::string, | |||||
| GELOGE(FAILED, "Prof stop parse param failed."); | GELOGE(FAILED, "Prof stop parse param failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto *device_id = new (std::nothrow) uint32_t[device_num]; | |||||
| if (device_id == nullptr) { | |||||
| GELOGE(FAILED, "Prof stop parse param failed."); | |||||
| auto device_id_ptr = std::unique_ptr<uint32_t[]>(new (std::nothrow) uint32_t[device_num]); | |||||
| if (device_id_ptr == nullptr) { | |||||
| GELOGE(FAILED, "Prof stop: device id ptr is null."); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| device_id[i] = static_cast<uint32_t>(device_list[i]); | |||||
| device_id_ptr[i] = static_cast<uint32_t>(device_list[i]); | |||||
| } | } | ||||
| GELOGI("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num); | GELOGI("Prof stop: runtime config param: 0x%llx, device num: %d", module, device_num); | ||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id); | |||||
| rtError_t rt_ret = rtProfilerStop(module, device_num, device_id_ptr.get()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| delete[] device_id; | |||||
| GELOGE(FAILED, "Prof stop: runtime profiler config proc failed."); | GELOGE(FAILED, "Prof stop: runtime profiler config proc failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| delete[] device_id; | |||||
| device_id = nullptr; | |||||
| uint64_t execute_model_mask = module & PROF_MODEL_EXECUTE_MASK; | uint64_t execute_model_mask = module & PROF_MODEL_EXECUTE_MASK; | ||||
| if (execute_model_mask == PROF_MODEL_EXECUTE_MASK) { | if (execute_model_mask == PROF_MODEL_EXECUTE_MASK) { | ||||
| for (int32_t i = 0; i < device_num; i++) { | for (int32_t i = 0; i < device_num; i++) { | ||||
| @@ -384,6 +384,7 @@ REGISTER_OPTYPE_DEFINE(HCOMREDUCESCATTER, "HcomReduceScatter"); | |||||
| REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend"); | REGISTER_OPTYPE_DEFINE(HCOMSEND, "HcomSend"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive"); | REGISTER_OPTYPE_DEFINE(HCOMRECEIVE, "HcomReceive"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEREAD, "HcomRemoteRead"); | ||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTEREFREAD, "HcomRemoteRefRead"); | |||||
| REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | REGISTER_OPTYPE_DEFINE(HCOMREMOTEWRITE, "HcomRemoteWrite"); | ||||
| REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); | REGISTER_OPTYPE_DEFINE(VARASSIGN, "VarAssign"); | ||||
| @@ -54,8 +54,7 @@ const int kProtoReadBytesLimit = INT_MAX; // Max size of 2 GB minus 1 byte. | |||||
| const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | const int kWarningThreshold = 536870912 * 2; // 536870912 represent 512M | ||||
| /// The maximum length of the file. | /// The maximum length of the file. | ||||
| /// Based on the security coding specification and the current actual (protobuf) model size, it is determined as 2G-1 | |||||
| const int kMaxFileSizeLimit = INT_MAX; | |||||
| const uint32_t kMaxFileSizeLimit = UINT32_MAX; // 4G for now | |||||
| const int kMaxBuffSize = 256; | const int kMaxBuffSize = 256; | ||||
| const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; | const char *const kPathValidReason = "The path can only contain 'a-z' 'A-Z' '0-9' '-' '.' '_' and chinese character"; | ||||
| constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024; | constexpr uint32_t kMaxConfigFileByte = 10 * 1024 * 1024; | ||||
| @@ -186,7 +185,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadBytesFromBinaryFile(co | |||||
| std::streamsize size = file.tellg(); | std::streamsize size = file.tellg(); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((size <= 0), file.close(); return false, "file length <= 0, not valid."); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > kMaxFileSizeLimit, file.close(); | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(size > static_cast<int64_t>(kMaxFileSizeLimit), file.close(); | |||||
| return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); | return false, "file size %ld is out of limit: %d.", size, kMaxFileSizeLimit); | ||||
| file.seekg(0, std::ios::beg); // [no need to check value] | file.seekg(0, std::ios::beg); // [no need to check value] | ||||
| @@ -304,7 +303,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromMem(const cha | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestap() { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY uint64_t GetCurrentTimestamp() { | |||||
| struct timeval tv {}; | struct timeval tv {}; | ||||
| int ret = gettimeofday(&tv, nullptr); | int ret = gettimeofday(&tv, nullptr); | ||||
| GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); | GE_LOGE_IF(ret != 0, "Func gettimeofday may failed: ret=%d", ret); | ||||
| @@ -216,9 +216,9 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { | |||||
| if (kernel_info_store != kernel_map.end()) { | if (kernel_info_store != kernel_map.end()) { | ||||
| std::string unsupported_reason; | std::string unsupported_reason; | ||||
| // It will be replaced by engine' checksupport | // It will be replaced by engine' checksupport | ||||
| uint64_t start_time = GetCurrentTimestap(); | |||||
| uint64_t start_time = GetCurrentTimestamp(); | |||||
| if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | ||||
| checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; | |||||
| checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; | |||||
| op_desc->SetOpEngineName(it.engine); | op_desc->SetOpEngineName(it.engine); | ||||
| op_desc->SetOpKernelLibName(kernel_name); | op_desc->SetOpKernelLibName(kernel_name); | ||||
| // set attrs for taking information when load txt to graph object | // set attrs for taking information when load txt to graph object | ||||
| @@ -228,7 +228,7 @@ std::string DNNEngineManager::GetDNNEngineName(const ge::NodePtr &node_ptr) { | |||||
| it.engine.c_str(), op_desc->GetName().c_str()); | it.engine.c_str(), op_desc->GetName().c_str()); | ||||
| return it.engine; | return it.engine; | ||||
| } else { | } else { | ||||
| checksupport_cost_[kernel_name] += GetCurrentTimestap() - start_time; | |||||
| checksupport_cost_[kernel_name] += GetCurrentTimestamp() - start_time; | |||||
| bool is_custom_op = false; | bool is_custom_op = false; | ||||
| if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) { | if ((ge::AttrUtils::GetBool(op_desc, kCustomOpFlag, is_custom_op)) && is_custom_op) { | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"}, | ErrorManager::GetInstance().ATCReportErrMessage("E13001", {"kernelname", "optype", "opname"}, | ||||
| @@ -41,6 +41,13 @@ | |||||
| "skip_assign_stream": false, | "skip_assign_stream": false, | ||||
| "attach": true | "attach": true | ||||
| }, | }, | ||||
| { | |||||
| "id": "DNN_VM_AICPU_ASCEND", | |||||
| "name": "AICPU_ASCEND", | |||||
| "independent": false, | |||||
| "skip_assign_stream": false, | |||||
| "attach": true | |||||
| }, | |||||
| { | { | ||||
| "id": "DNN_HCCL", | "id": "DNN_HCCL", | ||||
| "name": "HCCL", | "name": "HCCL", | ||||
| @@ -38,6 +38,7 @@ | |||||
| #include "single_op/single_op_manager.h" | #include "single_op/single_op_manager.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
| #include "opskernel_manager/ops_kernel_builder_manager.h" | |||||
| using std::string; | using std::string; | ||||
| using std::vector; | using std::vector; | ||||
| @@ -241,12 +242,16 @@ Status GeExecutor::Initialize() { | |||||
| } | } | ||||
| std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM); | std::vector<rtMemType_t> mem_type(1, RT_MEMORY_HBM); | ||||
| mem_type.push_back(RT_MEMORY_P2P_DDR); | |||||
| auto ret = MemManager::Instance().Initialize(mem_type); | auto ret = MemManager::Instance().Initialize(mem_type); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Memory Manager init failed."); | GELOGE(ret, "Memory Manager init failed."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(OpsKernelBuilderManager::Instance().Initialize({}, false), | |||||
| "Failed to initialize OpsKernelBuilders"); | |||||
| // Start profiling | // Start profiling | ||||
| Options profiling_options; | Options profiling_options; | ||||
| profiling_options.device_id = 0; | profiling_options.device_id = 0; | ||||
| @@ -265,6 +270,8 @@ Status GeExecutor::Finalize() { | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| (void)OpsKernelBuilderManager::Instance().Finalize(); | |||||
| // Stop profiling | // Stop profiling | ||||
| if (ProfilingManager::Instance().ProfilingOn()) { | if (ProfilingManager::Instance().ProfilingOn()) { | ||||
| ProfilingManager::Instance().StopProfiling(); | ProfilingManager::Instance().StopProfiling(); | ||||
| @@ -282,11 +289,14 @@ Status GeExecutor::SetDynamicBatchSize(uint32_t model_id, void *dynamic_input_ad | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| uint64_t size = sizeof(uint64_t); | |||||
| uint64_t size = sizeof(uint32_t); | |||||
| if (length < size) { | if (length < size) { | ||||
| GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, size); | GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, size); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| if (length >= sizeof(uint64_t)) { | |||||
| size = sizeof(uint64_t); | |||||
| } | |||||
| // Verify whether the input dynamic batch matches the model gear | // Verify whether the input dynamic batch matches the model gear | ||||
| std::vector<std::vector<int64_t>> batch_info; | std::vector<std::vector<int64_t>> batch_info; | ||||
| @@ -324,12 +334,15 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| uint64_t dynamic_input_size = kDynamicImageSizeInputSize * sizeof(uint64_t); | |||||
| uint64_t dynamic_input_size = kDynamicImageSizeInputSize * sizeof(uint32_t); | |||||
| if (length < dynamic_input_size) { | if (length < dynamic_input_size) { | ||||
| GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); | GELOGE(PARAM_INVALID, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| uint64_t size = sizeof(uint32_t); | |||||
| if (length >= kDynamicImageSizeInputSize * sizeof(uint64_t)) { | |||||
| size = sizeof(uint64_t); | |||||
| } | |||||
| // Verify whether the input dynamic resolution matches the model gear | // Verify whether the input dynamic resolution matches the model gear | ||||
| std::vector<std::vector<int64_t>> batch_info; | std::vector<std::vector<int64_t>> batch_info; | ||||
| std::vector<uint64_t> batch_num{image_height, image_width}; | std::vector<uint64_t> batch_num{image_height, image_width}; | ||||
| @@ -350,18 +363,18 @@ Status GeExecutor::SetDynamicImageSize(uint32_t model_id, void *dynamic_input_ad | |||||
| GELOGE(ret, "Set dynamic size failed"); | GELOGE(ret, "Set dynamic size failed"); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| // Memcpy dynamic resolution height from host to device | // Memcpy dynamic resolution height from host to device | ||||
| rtError_t rt_ret = | |||||
| rtMemcpy(dynamic_input_addr, sizeof(uint64_t), &image_height, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| rtError_t rt_ret = rtMemcpy(dynamic_input_addr, size, &image_height, size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "memcpy dynamic resolution input data failed! ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "memcpy dynamic resolution input data failed! ret: 0x%X", rt_ret); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
| } | } | ||||
| uint64_t remain_size = length - sizeof(uint64_t); | |||||
| uint64_t remain_size = length - size; | |||||
| // Memcpy dynamic resolution width from host to device | // Memcpy dynamic resolution width from host to device | ||||
| if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + sizeof(uint64_t)), | |||||
| remain_size, &image_width, sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { | |||||
| if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + size), remain_size, | |||||
| &image_width, size, RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "memcpy dynamic resolution input data failed!"); | GELOGE(FAILED, "memcpy dynamic resolution input data failed!"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -401,16 +414,19 @@ Status GeExecutor::SetDynamicDims(uint32_t model_id, void *dynamic_input_addr, u | |||||
| } | } | ||||
| size_t dynamic_dim_num = cur_dynamic_dims.size(); | size_t dynamic_dim_num = cur_dynamic_dims.size(); | ||||
| uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint64_t)); | |||||
| uint64_t dynamic_input_size = static_cast<uint64_t>(dynamic_dim_num * sizeof(uint32_t)); | |||||
| if (length < dynamic_input_size) { | if (length < dynamic_input_size) { | ||||
| GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); | GELOGE(FAILED, "Dynamic input size [%lu] is less than [%lu]!", length, dynamic_input_size); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| uint64_t size = sizeof(uint32_t); | |||||
| if (length >= dynamic_dim_num * sizeof(uint64_t)) { | |||||
| size = sizeof(uint64_t); | |||||
| } | |||||
| for (uint32_t i = 0; i < dynamic_dim_num; ++i) { | for (uint32_t i = 0; i < dynamic_dim_num; ++i) { | ||||
| // Memcpy dynamic dim[i] from host to device | // Memcpy dynamic dim[i] from host to device | ||||
| if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + sizeof(uint64_t) * i), | |||||
| length - sizeof(uint64_t) * i, &cur_dynamic_dims[i], sizeof(uint64_t), | |||||
| RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { | |||||
| if (rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(dynamic_input_addr) + size * i), | |||||
| length - size * i, &cur_dynamic_dims[i], size, RT_MEMCPY_HOST_TO_DEVICE) != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "memcpy dynamic resolution input data failed!"); | GELOGE(FAILED, "memcpy dynamic resolution input data failed!"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -1113,7 +1129,7 @@ Status GeExecutor::SetDump(const DumpConfig &dump_config) { | |||||
| GELOGE(ret, "Set dump conf failed"); | GELOGE(ret, "Set dump conf failed"); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| GELOGI("Set dump config succ."); | |||||
| GELOGI("Set dump config successfully"); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -50,6 +50,7 @@ local_ge_executor_src_files := \ | |||||
| ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | ../graph/load/new_model_manager/task_info/end_graph_task_info.cc \ | ||||
| ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | ../graph/load/new_model_manager/task_info/super_kernel/super_kernel_factory.cc \ | ||||
| ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | ../graph/load/new_model_manager/task_info/super_kernel/super_kernel.cc \ | ||||
| ../opskernel_manager/ops_kernel_builder_manager.cc \ | |||||
| ../single_op/single_op_manager.cc \ | ../single_op/single_op_manager.cc \ | ||||
| ../single_op/single_op_model.cc \ | ../single_op/single_op_model.cc \ | ||||
| ../single_op/single_op.cc \ | ../single_op/single_op.cc \ | ||||
| @@ -74,6 +75,7 @@ local_ge_executor_c_include := \ | |||||
| $(TOPDIR)inc/framework \ | $(TOPDIR)inc/framework \ | ||||
| $(TOPDIR)inc \ | $(TOPDIR)inc \ | ||||
| $(LOCAL_PATH)/../ \ | $(LOCAL_PATH)/../ \ | ||||
| $(TOPDIR)graphengine/ge \ | |||||
| $(TOPDIR)libc_sec/include \ | $(TOPDIR)libc_sec/include \ | ||||
| third_party/protobuf/include \ | third_party/protobuf/include \ | ||||
| third_party/json/include \ | third_party/json/include \ | ||||
| @@ -89,7 +91,6 @@ local_ge_executor_shared_library := \ | |||||
| libregister \ | libregister \ | ||||
| libmsprof \ | libmsprof \ | ||||
| liberror_manager \ | liberror_manager \ | ||||
| libascend_hal | |||||
| local_ge_executor_ldflags := -lrt -ldl \ | local_ge_executor_ldflags := -lrt -ldl \ | ||||
| @@ -105,7 +106,12 @@ LOCAL_SRC_FILES := $(local_ge_executor_src_files) | |||||
| LOCAL_C_INCLUDES := $(local_ge_executor_c_include) | LOCAL_C_INCLUDES := $(local_ge_executor_c_include) | ||||
| LOCAL_SHARED_LIBRARIES := $(local_ge_executor_shared_library) | LOCAL_SHARED_LIBRARIES := $(local_ge_executor_shared_library) | ||||
| LOCAL_STATIC_LIBRARIES := libmsprofiler | |||||
| LOCAL_SHARED_LIBRARIES += libascend_hal | |||||
| LOCAL_STATIC_LIBRARIES := \ | |||||
| libmsprofiler \ | |||||
| ifeq ($(device_os),android) | ifeq ($(device_os),android) | ||||
| LOCAL_LDFLAGS += -ldl | LOCAL_LDFLAGS += -ldl | ||||
| LOCAL_LDLIBS += -L$(PWD)/prebuilts/clang/linux-x86/aarch64/android-ndk-r21/sysroot/usr/lib/aarch64-linux-android/29 -llog | LOCAL_LDLIBS += -L$(PWD)/prebuilts/clang/linux-x86/aarch64/android-ndk-r21/sysroot/usr/lib/aarch64-linux-android/29 -llog | ||||
| @@ -142,9 +148,10 @@ LOCAL_SHARED_LIBRARIES := \ | |||||
| libregister \ | libregister \ | ||||
| libmsprof \ | libmsprof \ | ||||
| liberror_manager \ | liberror_manager \ | ||||
| stub/libascend_hal | |||||
| stub/libascend_hal \ | |||||
| LOCAL_STATIC_LIBRARIES := libmsprofiler | |||||
| LOCAL_STATIC_LIBRARIES := \ | |||||
| libmsprofiler \ | |||||
| LOCAL_LDFLAGS += $(local_ge_executor_ldflags) | LOCAL_LDFLAGS += $(local_ge_executor_ldflags) | ||||
| @@ -42,6 +42,7 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||||
| session/session_manager.cc \ | session/session_manager.cc \ | ||||
| engine_manager/dnnengine_manager.cc \ | engine_manager/dnnengine_manager.cc \ | ||||
| opskernel_manager/ops_kernel_manager.cc \ | opskernel_manager/ops_kernel_manager.cc \ | ||||
| opskernel_manager/ops_kernel_builder_manager.cc \ | |||||
| graph/manager/graph_manager.cc \ | graph/manager/graph_manager.cc \ | ||||
| graph/manager/graph_manager_utils.cc \ | graph/manager/graph_manager_utils.cc \ | ||||
| graph/manager/graph_context.cc \ | graph/manager/graph_context.cc \ | ||||
| @@ -57,9 +58,11 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||||
| graph/partition/engine_place.cc \ | graph/partition/engine_place.cc \ | ||||
| graph/partition/graph_partition.cc \ | graph/partition/graph_partition.cc \ | ||||
| graph/partition/dynamic_shape_partition.cc \ | graph/partition/dynamic_shape_partition.cc \ | ||||
| graph/partition/stage_partition.cc \ | |||||
| generator/ge_generator.cc \ | generator/ge_generator.cc \ | ||||
| generator/generator_api.cc \ | generator/generator_api.cc \ | ||||
| graph/manager/graph_var_manager.cc \ | graph/manager/graph_var_manager.cc \ | ||||
| graph/manager/host_mem_manager.cc \ | |||||
| graph/manager/rdma_pool_allocator.cc \ | graph/manager/rdma_pool_allocator.cc \ | ||||
| graph/manager/graph_mem_allocator.cc \ | graph/manager/graph_mem_allocator.cc \ | ||||
| graph/manager/graph_caching_allocator.cc \ | graph/manager/graph_caching_allocator.cc \ | ||||
| @@ -178,6 +181,7 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/multi_batch_pass.cc \ | graph/passes/multi_batch_pass.cc \ | ||||
| graph/passes/multi_batch_clone_pass.cc \ | graph/passes/multi_batch_clone_pass.cc \ | ||||
| graph/passes/subexpression_migration_pass.cc \ | graph/passes/subexpression_migration_pass.cc \ | ||||
| graph/passes/subgraph_const_migration_pass.cc \ | |||||
| graph/passes/unused_args_clean_pass.cc \ | graph/passes/unused_args_clean_pass.cc \ | ||||
| graph/passes/next_iteration_pass.cc \ | graph/passes/next_iteration_pass.cc \ | ||||
| graph/passes/control_trigger_pass.cc \ | graph/passes/control_trigger_pass.cc \ | ||||
| @@ -343,6 +347,7 @@ DEVICE_LOCAL_C_INCLUDES := \ | |||||
| $(TOPDIR)inc/runtime \ | $(TOPDIR)inc/runtime \ | ||||
| $(TOPDIR)ops/built-in/op_proto/inc \ | $(TOPDIR)ops/built-in/op_proto/inc \ | ||||
| $(TOPDIR)framework/domi \ | $(TOPDIR)framework/domi \ | ||||
| $(TOPDIR)graphengine/ge \ | |||||
| $(TOPDIR)toolchain/ide/ide-daemon/external \ | $(TOPDIR)toolchain/ide/ide-daemon/external \ | ||||
| third_party/json/include \ | third_party/json/include \ | ||||
| third_party/protobuf/include \ | third_party/protobuf/include \ | ||||