@@ -176,7 +176,7 @@ cd ${BASEPATH}
 mkdir -p output/plugin/nnengine/ge_config/
 find output/ -name graphengine_lib.tar -exec rm {} \;
 cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/
-find output/ -maxdepth 1 -name libengine.so -exec mv {} output/plugin/nnengine/ \;
+find output/ -maxdepth 1 -name libengine.so -exec mv -f {} output/plugin/nnengine/ \;
 tar -cf graphengine_lib.tar output/*
 mv -f graphengine_lib.tar output
 echo "---------------- GraphEngine package archive generated ----------------"
@@ -0,0 +1,36 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef COMPRESS_H
+#define COMPRESS_H
+
+#include <uchar.h>
+
+enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 };
+
+struct CompressConfig {
+  size_t inputSize;    // length of the data to compress
+  size_t engineNum;    // number of decompress engines
+  size_t maxRatio;     // max ratio of a basic compression block; only 64 is supported now (8x: 64, 4x: 32)
+  size_t channel;      // number of L2 or DDR channels, used for load balancing
+  size_t fractalSize;  // size of a compression block
+  bool isTight;        // whether to pack the compressed data tightly
+};
+
+CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output,
+                          size_t& compressedLength);
+
+#endif // COMPRESS_H
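The interface above takes raw weight data plus a CompressConfig and produces a compressed stream, a per-block index table, and the achieved length. A minimal calling sketch, assuming hypothetical buffer sizes and config values; the real index-buffer layout and the engine/channel counts are device-specific:

// Hypothetical usage sketch for CompressWeights; all sizes below are assumptions.
#include <cstddef>
#include <vector>
#include "compress.h"

int main() {
  std::vector<char> weights(4096, 0);     // raw weight data to compress
  std::vector<char> index(1024, 0);       // per-block index table (size is an assumption)
  std::vector<char> compressed(4096, 0);  // output buffer, worst case same size as input

  CompressConfig config{};
  config.inputSize = weights.size();
  config.engineNum = 4;      // assumed number of decompress engines
  config.maxRatio = 64;      // only 64 is supported, per the comment in the header
  config.channel = 8;        // assumed channel count for load balancing
  config.fractalSize = 512;  // assumed compression block size
  config.isTight = true;

  size_t compressed_length = 0;
  CmpStatus status = CompressWeights(weights.data(), config, index.data(),
                                     compressed.data(), compressed_length);
  return status == RET_SUCCESS ? 0 : 1;
}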
@@ -0,0 +1,97 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef PLATFORM_INFO_H
+#define PLATFORM_INFO_H
+
+#include <map>
+#include <string>
+#include <vector>
+#include "platform_info_def.h"
+
+using std::map;
+using std::string;
+using std::vector;
+
+namespace fe {
+class PlatformInfoManager {
+ public:
+  PlatformInfoManager(const PlatformInfoManager &) = delete;
+  PlatformInfoManager &operator=(const PlatformInfoManager &) = delete;
+
+  static PlatformInfoManager &Instance();
+
+  uint32_t InitializePlatformInfo();
+  uint32_t Finalize();
+
+  uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo);
+
+  void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo);
+
+ private:
+  PlatformInfoManager();
+  ~PlatformInfoManager();
+
+  uint32_t LoadIniFile(string iniFileRealPath);
+
+  void Trim(string &str);
+
+  uint32_t LoadConfigFile(string realPath);
+
+  string RealPath(const std::string &path);
+
+  string GetSoFilePath();
+
+  void ParseVersion(map<string, string> &versionMap, string &socVersion, PlatformInfo &platformInfoTemp);
+
+  void ParseSocInfo(map<string, string> &socInfoMap, PlatformInfo &platformInfoTemp);
+
+  void ParseCubeOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
+
+  void ParseBufferOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
+
+  void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
+
+  void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp);
+
+  void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
+
+  void ParseAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
+
+  void ParseUBOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
+
+  void ParseAICoreintrinsicDtypeMap(map<string, string> &aiCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp);
+
+  void ParseVectorCoreSpec(map<string, string> &vectorCoreSpecMap, PlatformInfo &platformInfoTemp);
+
+  void ParseVectorCoreMemoryRates(map<string, string> &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp);
+
+  void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vectorCoreintrinsicDtypeMap,
+                                        PlatformInfo &platformInfoTemp);
+
+  uint32_t ParsePlatformInfoFromStrToStruct(map<string, map<string, string>> &contentInfoMap, string &socVersion,
+                                            PlatformInfo &platformInfoTemp);
+
+  uint32_t AssemblePlatformInfoVector(map<string, map<string, string>> &contentInfoMap);
+
+ private:
+  bool initFlag_;
+  map<string, PlatformInfo> platformInfoMap_;
+  OptionalInfo optiCompilationInfo_;
+};
+
+} // namespace fe
+#endif
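PlatformInfoManager is a non-copyable singleton: callers go through Instance(), initialize once, then query per SoC version. A hedged usage sketch; the version string "Ascend910" and the 0-means-success convention for these uint32_t return codes are assumptions:

// Hypothetical usage sketch for PlatformInfoManager; needs the GE SDK headers.
#include "platform_info.h"

uint32_t QueryAiCoreCount(uint32_t &ai_core_cnt) {
  fe::PlatformInfoManager &manager = fe::PlatformInfoManager::Instance();
  uint32_t ret = manager.InitializePlatformInfo();
  if (ret != 0) {
    return ret;  // assumption: 0 means success for these return codes
  }
  fe::PlatformInfo platform_info;
  fe::OptionalInfo optional_info;
  ret = manager.GetPlatformInfo("Ascend910", platform_info, optional_info);  // hypothetical SoC version
  if (ret == 0) {
    ai_core_cnt = platform_info.socInfo.aiCoreCnt;
  }
  (void)manager.Finalize();
  return ret;
}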
@@ -0,0 +1,122 @@
+/**
+ * Copyright 2019-2020 Huawei Technologies Co., Ltd
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+#ifndef PLATFORM_INFO_DEF_H
+#define PLATFORM_INFO_DEF_H
+
+#include <map>
+#include <string>
+#include <vector>
+
+using std::map;
+using std::string;
+using std::vector;
+
+namespace fe {
+enum MemoryType { DDR = 0, HBM };
+
+enum L2Type { Cache = 0, Buff };
+
+typedef struct tagStrInfo {
+  string aicVersion;
+  string ccecAICVersion;
+  string ccecAIVVersion;
+  string isSupportAIcpuCompiler;
+} StrInfo;
+
+typedef struct tagSoCInfo {
+  uint32_t aiCoreCnt;
+  uint32_t vectorCoreCnt;
+  uint32_t aiCpuCnt;
+  MemoryType memoryType;
+  uint64_t memorySize;
+  L2Type l2Type;
+  uint64_t l2Size;
+  uint32_t l2PageNum;
+} SoCInfo;
+
+typedef struct tagAiCoreSpec {
+  double cubeFreq;
+  uint64_t cubeMSize;
+  uint64_t cubeNSize;
+  uint64_t cubeKSize;
+  uint64_t vecCalcSize;
+  uint64_t l0ASize;
+  uint64_t l0BSize;
+  uint64_t l0CSize;
+  uint64_t l1Size;
+  uint64_t smaskBuffer;
+  uint64_t ubSize;
+  uint64_t ubblockSize;
+  uint64_t ubbankSize;
+  uint64_t ubbankNum;
+  uint64_t ubburstInOneBlock;
+  uint64_t ubbankGroupNum;
+} AiCoreSpec;
+
+typedef struct tagAiCoreMemoryRates {
+  double ddrRate;
+  double l2Rate;
+  double l2ReadRate;
+  double l2WriteRate;
+  double l1ToL0ARate;
+  double l1ToL0BRate;
+  double l1ToUBRate;
+  double l0CToUBRate;
+  double ubToL2Rate;
+  double ubToDdrRate;
+  double ubToL1Rate;
+} AiCoreMemoryRates;
+
+typedef struct tagVectorCoreSpec {
+  uint64_t vecCalcSize;
+  uint64_t smaskBuffer;
+  uint64_t ubSize;
+  uint64_t ubblockSize;
+  uint64_t ubbankSize;
+  uint64_t ubbankNum;
+  uint64_t ubburstInOneBlock;
+  uint64_t ubbankGroupNum;
+} VectorCoreSpec;
+
+typedef struct tagVectorCoreMemoryRates {
+  double ddrRate;
+  double l2Rate;
+  double l2ReadRate;
+  double l2WriteRate;
+  double ubToL2Rate;
+  double ubToDdrRate;
+} VectorCoreMemoryRates;
+
+typedef struct tagPlatformInfo {
+  StrInfo strInfo;
+  SoCInfo socInfo;
+  AiCoreSpec aiCoreSpec;
+  AiCoreMemoryRates aiCoreMemoryRates;
+  map<string, vector<string>> aiCoreIntrinsicDtypeMap;
+  VectorCoreSpec vectorCoreSpec;
+  VectorCoreMemoryRates vectorCoreMemoryRates;
+  map<string, vector<string>> vectorCoreIntrinsicDtypeMap;
+} PlatformInfo;
+
+typedef struct tagOptionalInfo {
+  string socVersion;
+  string coreType;
+  uint32_t aiCoreNum;
+  string l1FusionFlag;
+} OptionalInfo;
+
+} // namespace fe
+#endif
@@ -40,6 +40,8 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath";
 const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump";
 const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath";
 const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep";
+const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild";
+const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath";
 // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0
 const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag";
 const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic";
@@ -116,27 +116,5 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver {
 namespace ge {
 using OpRegistrationData = domi::OpRegistrationData;
 using OpReceiver = domi::OpReceiver;
-
-class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp {
- public:
-  HostCpuOp() = default;
-  virtual ~HostCpuOp() = default;
-
-  virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs,
-                              std::map<std::string, Tensor> &outputs) = 0;
-};
-
-class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar {
- public:
-  HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)());
-};
-
-#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op)
-
-#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op)
-
-#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op)                              \
-  static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \
-      ::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); })
 } // namespace ge
 #endif // INC_EXTERNAL_REGISTER_REGISTER_H_
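The removed macros implement the usual static-registrar idiom: a file-scope HostCpuOpRegistrar object, uniquified with __COUNTER__, registers a factory lambda before main() runs. A standalone sketch of the same idiom, not GE's implementation:

// Minimal sketch of the static-registrar pattern behind the removed macros.
#include <functional>
#include <iostream>
#include <map>
#include <string>

struct Op { virtual ~Op() = default; };

std::map<std::string, std::function<Op *()>> &Registry() {
  static std::map<std::string, std::function<Op *()>> registry;  // construct-on-first-use
  return registry;
}

struct Registrar {
  Registrar(const char *op_type, Op *(*create_fn)()) { Registry()[op_type] = create_fn; }
};

#define REGISTER_OP(ctr, name, op) \
  static Registrar register_op##ctr(name, []() -> Op * { return new op(); })

struct NoOp : Op {};
REGISTER_OP(0, "NoOp", NoOp);  // runs during static initialization, before main()

int main() {
  std::cout << "registered ops: " << Registry().size() << "\n";
  delete Registry()["NoOp"]();  // create via the registered factory, then free
  return 0;
}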
@@ -434,6 +434,7 @@ REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch");
 REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN");
 REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive");
 REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync");
+REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync");
 REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge");
 REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph");
 REGISTER_OPTYPE_DECLARE(SEND, "Send");
@@ -441,6 +442,7 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv");
 REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet");
 REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto");
+REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx");
 REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch");
 REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex");
@@ -979,9 +979,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED;
+
+// functional ops attr
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND;
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY;
+
 // used for label switch
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST;

 // Varible
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME;
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME;
@@ -22,7 +22,7 @@
 #include <string>
 #include <vector>
 #include "graph/anchor.h"
-#include "detail/attributes_holder.h"
+#include "graph/detail/attributes_holder.h"
 #include "graph/ge_tensor.h"
 #include "graph/graph.h"
 #include "graph/node.h"
@@ -262,6 +262,8 @@ class GraphUtils {
   static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node);

   static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph);
+
+  static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec);
 };

 class ComputeGraphBuilder {
@@ -54,17 +54,34 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS
   return s;
 }
 GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const {
-  vector<NodePtr> all_nodes(nodes_.size());
-  (void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin());
-  for (const auto &sub_graph : sub_graph_) {
-    if (sub_graph == nullptr) {
-      GELOGW("sub graph is nullptr");
+  if (sub_graph_.empty()) {
+    return Vistor<NodePtr>(shared_from_this(), nodes_);
+  }
+
+  std::vector<NodePtr> all_nodes;
+  std::deque<NodePtr> candidates;
+
+  candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end());
+
+  while (!candidates.empty()) {
+    NodePtr node = candidates.front();
+    all_nodes.emplace_back(node);
+    candidates.pop_front();
+
+    OpDescPtr op_desc = node->GetOpDesc();
+    if (op_desc == nullptr) {
       continue;
     }
-    for (const auto &node : sub_graph->GetAllNodes()) {
-      all_nodes.push_back(node);
+
+    const auto &subgraph_names = op_desc->GetSubgraphInstanceNames();
+    for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) {
+      auto subgraph = GetSubgraph(*name_iter);
+      if (subgraph != nullptr) {
+        candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end());
+      }
     }
   }
+
   return Vistor<NodePtr>(shared_from_this(), all_nodes);
 }
 size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); }
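The rewritten GetAllNodes drops the recursive walk over sub_graph_ in favor of an explicit deque: each popped node's subgraph nodes are spliced onto the front of the queue, so a subgraph is visited immediately after its owning node and before that node's siblings. A self-contained sketch of the same traversal order, with a stand-in Node type rather than ge::Node:

// Generic sketch of the deque-based preorder traversal used above.
#include <deque>
#include <iostream>
#include <string>
#include <vector>

struct Node {
  std::string name;
  std::vector<std::vector<Node>> subgraphs;  // nodes of each subgraph this node owns
};

std::vector<std::string> CollectAll(const std::vector<Node> &roots) {
  std::vector<std::string> all;
  std::deque<Node> candidates(roots.begin(), roots.end());
  while (!candidates.empty()) {
    Node node = candidates.front();
    candidates.pop_front();
    all.push_back(node.name);
    // Push subgraph nodes to the *front* (in reverse subgraph order) so they
    // are visited right after their owner, before the owner's siblings.
    for (auto it = node.subgraphs.rbegin(); it != node.subgraphs.rend(); ++it) {
      candidates.insert(candidates.begin(), it->begin(), it->end());
    }
  }
  return all;
}

int main() {
  std::vector<Node> graph{{"A", {{{"A/sub1", {}}}}}, {"B", {}}};
  for (const auto &name : CollectAll(graph)) std::cout << name << "\n";  // A, A/sub1, B
}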
@@ -602,7 +619,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE
 graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec,
                                                 std::map<NodePtr, uint32_t> &map_in_edge_num,
                                                 std::vector<NodePtr> &stack) {
-  GELOGI("Runing_Dfs_Sort");
+  GELOGI("Running_Dfs_Sort: %s", name_.c_str());
   // Record the number of non data nodes but no input nodes
   GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");

@@ -647,7 +664,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec,
 graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec,
                                                 std::map<NodePtr, uint32_t> &map_in_edge_num,
                                                 std::deque<NodePtr> &stack) {
-  GELOGI("Runing_Bfs_Sort");
+  GELOGI("Running_Bfs_Sort: %s", name_.c_str());
   std::vector<NodePtr> stack_input;
   std::map<string, NodePtr> breadth_node_map;
   // Record the number of non data nodes but no input nodes
@@ -735,7 +752,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog
       use_BFS = true;
     }
   } else {
-    GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default.");
+    GELOGW("OPTION_GRAPH_RUN_MODE not set, using BFSTopologicalSorting by default.");
   }

   if (use_BFS) {
@@ -955,11 +955,8 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format";
 const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type";

 // functional ops attr
-const std::string ATTR_NAME_TCOND = "Tcond";
-const std::string ATTR_NAME_TIN = "Tin";
-const std::string ATTR_NAME_TOUT = "Tout";
-const std::string ATTR_NAME_THEN_BRANCH = "then_branch";
-const std::string ATTR_NAME_ELSE_BRANCH = "else_branch";
+const std::string ATTR_NAME_WHILE_COND = "cond";
+const std::string ATTR_NAME_WHILE_BODY = "body";

 // used for label switch
 const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index";
@@ -28,6 +28,7 @@
 #include <cstring>
 #include <fstream>
 #include <iomanip>
+#include <queue>

 #include "./ge_context.h"
 #include "debug/ge_util.h"
@@ -1999,4 +2000,60 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string &
   GELOGD("Build exist nodes succ.");
 }
+
+GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
+GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) {
+  std::vector<NodePtr> stack_input;
+  std::map<NodePtr, uint32_t> map_in_edge_num;
+  graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num);
+  if (ret != GRAPH_SUCCESS) {
+    GELOGE(GRAPH_FAILED, "Sort nodes failed.");
+    return GRAPH_FAILED;
+  }
+
+  const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1;
+  std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index,
+            [](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); });
+  std::queue<NodePtr> stack;
+  NodePtr cur_node = nullptr;
+  std::map<string, NodePtr> name_node_map;
+  vector<string> nodes_name;
+  while (!stack_input.empty() || !stack.empty()) {
+    if (!stack.empty()) {
+      cur_node = stack.front();
+      stack.pop();
+    } else {
+      cur_node = stack_input.back();
+      stack_input.pop_back();
+    }
+    node_vec.emplace_back(cur_node);
+    compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map);
+    for (const auto &iter : name_node_map) {
+      nodes_name.emplace_back(iter.first);
+    }
+    std::sort(nodes_name.begin(), nodes_name.end());
+    for (const auto &iter : nodes_name) {
+      stack.push(name_node_map[iter]);
+    }
+    name_node_map.clear();
+    nodes_name.clear();
+  }
+
+  // If they are not equal, there is a closed loop
+  if (node_vec.size() != compute_graph->nodes_.size()) {
+    std::set<Node *> itered_nodes_set;
+    for (auto &node : node_vec) {
+      itered_nodes_set.insert(node.get());
+    }
+    GE_LOGE("Failed to do topo sorting, total %zu, iterated %zu; a closed loop exists in the graph.",
+            compute_graph->nodes_.size(), node_vec.size());
+    for (auto &node : compute_graph->nodes_) {
+      if (itered_nodes_set.count(node.get()) == 0) {
+        GE_LOGE("The node %s was not iterated during topological sorting.", node->GetName().c_str());
+      }
+    }
+    return GRAPH_FAILED;
+  }
+  return GRAPH_SUCCESS;
+}
 } // namespace ge
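TopologicalSortingByName is a Kahn-style sort that breaks ties by node name, so the output order is stable across runs; the size check at the end detects cycles. A self-contained sketch of the same idea on plain strings, where std::map and a min-heap stand in for GE's structures:

// Sketch of deterministic (name-ordered) topological sorting.
#include <iostream>
#include <map>
#include <queue>
#include <set>
#include <string>
#include <vector>

int main() {
  // adjacency: node -> successors; names chosen arbitrarily
  std::map<std::string, std::set<std::string>> succ = {{"a", {"c"}}, {"b", {"c"}}, {"c", {}}};
  std::map<std::string, int> in_degree = {{"a", 0}, {"b", 0}, {"c", 2}};

  // A min-heap over names breaks ties the same way on every run.
  std::priority_queue<std::string, std::vector<std::string>, std::greater<std::string>> ready;
  for (const auto &kv : in_degree)
    if (kv.second == 0) ready.push(kv.first);

  std::vector<std::string> order;
  while (!ready.empty()) {
    std::string node = ready.top();
    ready.pop();
    order.push_back(node);
    for (const auto &next : succ[node])
      if (--in_degree[next] == 0) ready.push(next);
  }
  if (order.size() != succ.size()) std::cout << "cycle detected\n";  // mirrors the check above
  for (const auto &n : order) std::cout << n << " ";  // prints: a b c, every run
}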
@@ -41,6 +41,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph)
 include_directories(${GE_SOURCE_DIR}/inc/framework)
 include_directories(${GE_SOURCE_DIR}/inc/framework/common)
 include_directories(${GE_SOURCE_DIR}/inc/runtime)
+include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib)
 include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc)
 include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce)
 include_directories(${GE_SOURCE_DIR}/third_party/securec/include)
@@ -55,6 +56,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
     "common/formats/utils/formats_trans_utils.cc"
     "common/fp16_t.cc"
     "common/ge/plugin_manager.cc"
+    "common/helper/model_cache_helper.cc"
     "common/profiling/profiling_manager.cc"
     "engine_manager/dnnengine_manager.cc"
     "ge_local_engine/engine/host_cpu_engine.cc"
@@ -92,6 +94,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
     "graph/load/new_model_manager/task_info/kernel_task_info.cc"
     "graph/load/new_model_manager/task_info/label_goto_task_info.cc"
     "graph/load/new_model_manager/task_info/label_set_task_info.cc"
+    "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc"
     "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc"
     "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc"
     "graph/load/new_model_manager/task_info/stream_active_task_info.cc"
@@ -269,6 +272,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
     "common/formats/utils/formats_trans_utils.cc"
     "common/fp16_t.cc"
     "common/ge/plugin_manager.cc"
+    "common/helper/model_cache_helper.cc"
     "common/profiling/profiling_manager.cc"
     "engine_manager/dnnengine_manager.cc"
     "ge_local_engine/engine/host_cpu_engine.cc"
@@ -305,6 +309,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR}
     "graph/load/new_model_manager/task_info/kernel_task_info.cc"
     "graph/load/new_model_manager/task_info/label_goto_task_info.cc"
     "graph/load/new_model_manager/task_info/label_set_task_info.cc"
+    "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc"
     "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc"
     "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc"
     "graph/load/new_model_manager/task_info/stream_active_task_info.cc"
@@ -470,7 +475,7 @@ target_link_libraries(ge_compiler
     ${slog}
     ${mmpa}
     ${msprof}
-    ${runtime}
+    ${runtime_compiler}
     ${resouce}
     rt
     dl)
@@ -134,10 +134,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
   }
   auto trans_mode = iter->second;

-  if (args.src_data_size == 0) {
-    GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size);
-    return PARAM_INVALID;
-  }
   int size = GetSizeByDataType(args.dst_data_type);
   if (size <= 0) {
     GELOGE(PARAM_INVALID, "Failed to calc size from data type %s",
@@ -149,6 +145,12 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
     return PARAM_INVALID;
   }
   size_t total_size = static_cast<size_t>(args.src_data_size * size);
+  result.length = total_size;
+  if (total_size == 0) {
+    GELOGI("In TransDataType, total_size is zero; there is no data to transform.");
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size);
@@ -162,7 +164,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result
     return INTERNAL_ERROR;
   }
   result.data = dst;
-  result.length = total_size;
   return SUCCESS;
 }
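Moving the result.length assignment ahead of the allocation lets the zero-element case return SUCCESS with a valid length and no buffer, instead of failing as it did before. A minimal sketch of the pattern, with stand-in types rather than GE's TransResult:

// Sketch of the zero-size early-return pattern applied across these transfers.
#include <cstddef>
#include <memory>
#include <new>

struct Result {
  std::shared_ptr<unsigned char> data;
  size_t length = 0;
};

int TransSketch(size_t total_size, Result &result) {
  result.length = total_size;  // set first, so empty tensors still report a valid length
  if (total_size == 0) {
    return 0;  // nothing to transform; result.data stays null by design
  }
  result.data.reset(new (std::nothrow) unsigned char[total_size],
                    std::default_delete<unsigned char[]>());
  return result.data == nullptr ? -1 : 0;
}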
@@ -134,6 +134,11 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu
   int size = GetSizeByDataType(args.src_data_type);
   int64_t total_size = GetItemNumByShape(args.dst_shape) * size;
   if (total_size <= 0) {
+    int64_t src_size = GetItemNumByShape(args.src_shape);
+    if (total_size == 0 && src_size == 0) {
+      result.length = static_cast<size_t>(total_size);
+      return SUCCESS;
+    }
     GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
            ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
     return PARAM_INVALID;
@@ -88,6 +88,11 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) {
     dst_size *= dim;
   }
   dst_size *= data_size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -89,6 +89,11 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul
     dst_size *= dim;
   }
   dst_size *= data_size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -116,6 +116,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) {
 Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) {
   int size = GetSizeByDataType(args.src_data_type);
   int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -184,6 +189,11 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con
 Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) {
   int size = GetSizeByDataType(args.src_data_type);
   int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -119,6 +119,11 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) {
   int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt;
   int size = GetSizeByDataType(args.src_data_type);
   int64_t dst_size = total_ele_cnt * size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -194,6 +199,11 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) {
     dst_size *= dim;
   }
   dst_size *= data_size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -259,6 +269,11 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) {
     dst_size *= dim;
   }
   dst_size *= data_size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -117,6 +117,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) {
 Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) {
   int size = GetSizeByDataType(args.src_data_type);
   int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -153,8 +158,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
         auto src_offset = (src_h_head + w1_idx * w0) * size;
         auto dst_offset = (h0_head + w1_idx * h0w0) * size;
         auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
-                                ? dst_size - dst_offset
-                                : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
+                                  ? dst_size - dst_offset
+                                  : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
         auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
                             static_cast<size_t>(size * w0));
         if (ret != EOK) {
@@ -169,8 +174,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
           auto src_offset = (src_h_head + src_w_idx) * size;
           auto dst_offset = (w0_head + w0_idx) * size;
           auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
-                                  ? dst_size - dst_offset
-                                  : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
+                                    ? dst_size - dst_offset
+                                    : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
           auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
                               static_cast<size_t>(size));
           if (ret != EOK) {
@@ -189,6 +194,11 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con
 Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) {
   int size = GetSizeByDataType(args.src_data_type);
   int64_t dst_size = GetItemNumByShape(args.dst_shape) * size;
+  if (dst_size == 0) {
+    result.length = static_cast<size_t>(dst_size);
+    return SUCCESS;
+  }
+
   std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>());
   if (dst == nullptr) {
     GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld",
@@ -226,8 +236,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
         auto src_offset = (h0_head + w1_idx * h0w0) * size;
         auto dst_offset = (dst_h_head + w1_idx * w0) * size;
         auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
-                                ? dst_size - dst_offset
-                                : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
+                                  ? dst_size - dst_offset
+                                  : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
         auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
                             static_cast<size_t>(size * w0));
         if (ret != EOK) {
@@ -242,8 +252,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con
           auto dst_w_idx = w1_head + w0_idx;
           auto dst_offset = (dst_h_head + dst_w_idx) * size;
           auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
-                                  ? dst_size - dst_offset
-                                  : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
+                                    ? dst_size - dst_offset
+                                    : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
           auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
                               static_cast<size_t>(size));
           if (ret != EOK) {
@@ -133,6 +133,12 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult &
   int size = GetSizeByDataType(args.src_data_type);
   auto total_size = GetItemNumByShape(args.dst_shape) * size;
   if (total_size <= 0) {
+    int64_t src_size = GetItemNumByShape(args.src_shape);
+    if (total_size == 0 && src_size == 0) {
+      result.length = static_cast<size_t>(total_size);
+      return SUCCESS;
+    }
+
     GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
            ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
     return PARAM_INVALID;
@@ -133,6 +133,12 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult &
   int size = GetSizeByDataType(args.src_data_type);
   auto total_size = GetItemNumByShape(args.dst_shape) * size;
   if (total_size <= 0) {
+    int64_t src_size = GetItemNumByShape(args.src_shape);
+    if (total_size == 0 && src_size == 0) {
+      result.length = static_cast<size_t>(total_size);
+      return SUCCESS;
+    }
+
     GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
            ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
     return PARAM_INVALID;
@@ -140,6 +146,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult &
   GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld",
          ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
          ShapeToString(args.dst_shape).c_str(), total_size);
+
   if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) {
     GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld",
            ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(),
@@ -132,6 +132,12 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult &
   int size = GetSizeByDataType(args.src_data_type);
   auto total_size = GetItemNumByShape(args.dst_shape) * size;
   if (total_size <= 0) {
+    int64_t src_size = GetItemNumByShape(args.src_shape);
+    if (total_size == 0 && src_size == 0) {
+      result.length = static_cast<size_t>(total_size);
+      return SUCCESS;
+    }
+
    GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
            ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
     return PARAM_INVALID;
@@ -35,7 +35,7 @@ Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<in
                                  std::vector<int64_t> &dst_shape) {
   auto cube_size = GetCubeSizeByDataType(data_type);
   dst_shape.clear();
-  dst_shape.push_back((src_shape.at(kHwcnC) - 1) / cube_size + 1);
+  dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size)));
   dst_shape.push_back(src_shape.at(kHwcnH));
   dst_shape.push_back(src_shape.at(kHwcnW));
   dst_shape.push_back(src_shape.at(kHwcnN));
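Ceil(x, c) and the old (x - 1) / c + 1 agree for x > 0, but they differ at x == 0, where the old expression yields 1 for an empty dimension. A hedged sketch of an integer ceiling division consistent with how it is used here; GE's actual helper may differ in signature and overflow handling:

// Sketch of integer ceiling division as used in the shape checks above.
#include <cassert>
#include <cstdint>

int64_t Ceil(int64_t x, int64_t y) {
  if (y == 0) {
    return 0;  // assumption: a degenerate divisor maps to 0 rather than trapping
  }
  return (x + y - 1) / y;  // note: can overflow for x near INT64_MAX
}

int main() {
  assert(Ceil(17, 16) == 2);  // C = 17, c0 = 16 -> C1 = 2
  assert(Ceil(16, 16) == 1);
  assert(Ceil(0, 16) == 0);   // zero-channel shape, where (x - 1) / c + 1 would give 1
  return 0;
}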
@@ -169,6 +169,12 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu
   int size = GetSizeByDataType(args.src_data_type);
   auto total_size = GetItemNumByShape(args.dst_shape) * size;
   if (total_size <= 0) {
+    int64_t src_size = GetItemNumByShape(args.src_shape);
+    if (total_size == 0 && src_size == 0) {
+      result.length = static_cast<size_t>(total_size);
+      return SUCCESS;
+    }
+
     GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
            ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
     return PARAM_INVALID;
@@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) {
   }
   if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) ||
       src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 ||
-      src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNchwC) - 1) / c0 + 1) {
+      src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) {
     GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
            ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str());
     return PARAM_INVALID;
@@ -102,8 +102,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in
           auto src_offset = src_idx * size;
           auto dst_offset = dst_idx * size;
           auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN)
-                                  ? total_size - dst_offset
-                                  : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
+                                    ? total_size - dst_offset
+                                    : static_cast<int64_t>(SECUREC_MEM_MAX_LEN);
           auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset,
                               static_cast<size_t>(size));
           if (ret != EOK) {
@@ -130,6 +130,12 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult
   int size = GetSizeByDataType(args.src_data_type);
   auto total_size = GetItemNumByShape(args.dst_shape) * size;
   if (total_size <= 0) {
+    int64_t src_size = GetItemNumByShape(args.src_shape);
+    if (total_size == 0 && src_size == 0) {
+      result.length = static_cast<size_t>(total_size);
+      return SUCCESS;
+    }
+
     GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size,
            ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str());
     return PARAM_INVALID;
@@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) {
   }
   if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) ||
       src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 ||
-      src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNhwcC) - 1) / c0 + 1) {
+      src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) {
     GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s",
            ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str());
     return PARAM_INVALID;
| @@ -102,8 +102,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
| auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
| auto dst_offset = dst_idx * size; | auto dst_offset = dst_idx * size; | ||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
| static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -130,6 +130,12 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||||
| int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
| if (total_size <= 0) { | if (total_size <= 0) { | ||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -134,6 +134,10 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
| GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); | GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); | ||||
| return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
| int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
| if (dst_size == 0) { | |||||
| result.length = 0; | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
| if (dst == nullptr) { | if (dst == nullptr) { | ||||
| @@ -219,6 +223,10 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
| return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
| int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
| if (dst_size == 0) { | |||||
| return SUCCESS; | |||||
| } | |||||
| dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
| if (dst == nullptr) { | if (dst == nullptr) { | ||||
| GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
| @@ -40,7 +40,7 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
| } | } | ||||
| dst_shape.clear(); | dst_shape.clear(); | ||||
| dst_shape.push_back(src_shape.at(kNchwN)); | dst_shape.push_back(src_shape.at(kNchwN)); | ||||
| dst_shape.push_back((src_shape.at(kNchwC) - 1) / c0 + 1); | |||||
| dst_shape.push_back(Ceil(src_shape.at(kNchwC), c0)); | |||||
| dst_shape.push_back(src_shape.at(kNchwH)); | dst_shape.push_back(src_shape.at(kNchwH)); | ||||
| dst_shape.push_back(src_shape.at(kNchwW)); | dst_shape.push_back(src_shape.at(kNchwW)); | ||||
| dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
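Swapping the hand-rolled `(c - 1) / c0 + 1` for `Ceil(c, c0)` spells out what NC1HWC0 does: the channel dimension is split into C1 blocks of C0 channels each, with the last block zero-padded. A small self-contained check of the mapping (Ceil is duplicated here for the example; its definition appears later in this diff):

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Same semantics as the Ceil helper in formats_trans_utils.
int64_t Ceil(int64_t n1, int64_t n2) {
  if (n1 == 0) {
    return 0;
  }
  return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}

std::vector<int64_t> NchwToNc1hwc0(const std::vector<int64_t> &nchw, int64_t c0) {
  return {nchw[0], Ceil(nchw[1], c0), nchw[2], nchw[3], c0};
}

int main() {
  // 17 channels need Ceil(17, 16) = 2 blocks of 16, the last padded with zeros.
  auto dst = NchwToNc1hwc0({2, 17, 5, 5}, 16);
  assert((dst == std::vector<int64_t>{2, 2, 5, 5, 16}));
  return 0;
}
```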
| @@ -74,25 +74,8 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace | |||||
| Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| // Guarantee the validity of parameters in check function | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GELOGD( | |||||
| "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
| "%s, dst shape %s memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
| if (dst == nullptr) { | if (dst == nullptr) { | ||||
| GELOGE(OUT_OF_MEMORY, | GELOGE(OUT_OF_MEMORY, | ||||
| @@ -132,8 +115,8 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
| int64_t dst_index = c0_idx + w_head_addr; | int64_t dst_index = c0_idx + w_head_addr; | ||||
| int64_t dst_offset = dst_index * size; | int64_t dst_offset = dst_index * size; | ||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| int64_t cIdx = c0_idx + c1_idx * c0; | int64_t cIdx = c0_idx + c1_idx * c0; | ||||
| int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; | int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; | ||||
| auto src_offset = srcIdx * size; | auto src_offset = srcIdx * size; | ||||
| @@ -150,7 +133,7 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
| } | } | ||||
| } else { | } else { | ||||
| auto ret = | auto ret = | ||||
| memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
| memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
| "Failed to set to 0 to " | "Failed to set to 0 to " | ||||
| @@ -169,6 +152,39 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
| result.length = static_cast<size_t>(total_size); | result.length = static_cast<size_t>(total_size); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace | |||||
| Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
| if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| // Parameter validity is guaranteed by the check function above | |||||
| int size = GetSizeByDataType(args.src_data_type); | |||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
| if (total_size <= 0) { | |||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| GELOGD( | |||||
| "Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
| "%s, dst shape %s memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
| ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
| ShapeToString(args.dst_shape).c_str(), total_size); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
| DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
| @@ -38,7 +38,7 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
| } | } | ||||
| dst_shape.clear(); | dst_shape.clear(); | ||||
| dst_shape.push_back(src_shape.at(kNhwcN)); | dst_shape.push_back(src_shape.at(kNhwcN)); | ||||
| dst_shape.push_back((src_shape.at(kNhwcC) - 1) / c0 + 1); | |||||
| dst_shape.push_back(Ceil(src_shape.at(kNhwcC), c0)); | |||||
| dst_shape.push_back(src_shape.at(kNhwcH)); | dst_shape.push_back(src_shape.at(kNhwcH)); | ||||
| dst_shape.push_back(src_shape.at(kNhwcW)); | dst_shape.push_back(src_shape.at(kNhwcW)); | ||||
| dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
| @@ -119,8 +119,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
| int64_t dst_idx = c0_idx + w_head_addr; | int64_t dst_idx = c0_idx + w_head_addr; | ||||
| int64_t dst_offset = dst_idx * size; | int64_t dst_offset = dst_idx * size; | ||||
| auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? total_size - dst_offset | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| int64_t c_idx = c0_idx + c1_idx * c0; | int64_t c_idx = c0_idx + c1_idx * c0; | ||||
| int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; | int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; | ||||
| auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
| @@ -161,6 +161,12 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
| int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
| auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
| if (total_size <= 0) { | if (total_size <= 0) { | ||||
| int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
| if (total_size == 0 && src_size == 0) { | |||||
| result.length = static_cast<size_t>(total_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
| ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| @@ -27,22 +27,22 @@ namespace ge { | |||||
| namespace formats { | namespace formats { | ||||
| namespace { | namespace { | ||||
| std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | ||||
| {FORMAT_NCHW, | |||||
| {{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
| {FORMAT_NHWC, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
| {FORMAT_HWCN, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}}, | |||||
| {FORMAT_CHWN, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}}, | |||||
| {FORMAT_NCHW, | |||||
| {{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
| {FORMAT_NHWC, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
| {FORMAT_HWCN, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
| {FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}}, | |||||
| {FORMAT_CHWN, | |||||
| {{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
| {FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
| {FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}}, | |||||
| }; | }; | ||||
| bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | ||||
| @@ -51,8 +51,8 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||||
| return false; | return false; | ||||
| } | } | ||||
| for (auto dim : src_shape) { | for (auto dim : src_shape) { | ||||
| if (dim <= 0) { | |||||
| GELOGE(PARAM_INVALID, "Failed to transpose, zero dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
| if (dim < 0) { | |||||
| GELOGE(PARAM_INVALID, "Failed to transpose, negative dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
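The perm_args table above only moved; its meaning is unchanged: each perm vector maps destination axis i to source axis perm[i]. A quick sketch of applying a perm to a shape, using the NCHW to NHWC entry {0, 2, 3, 1}:

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// dst_shape[i] = src_shape[perm[i]]: destination axis i comes from source
// axis perm[i], matching the perm_args table above.
std::vector<int64_t> PermuteShape(const std::vector<int64_t> &src_shape,
                                  const std::vector<int64_t> &perm) {
  std::vector<int64_t> dst_shape(perm.size());
  for (size_t i = 0; i < perm.size(); ++i) {
    dst_shape[i] = src_shape[static_cast<size_t>(perm[i])];
  }
  return dst_shape;
}

int main() {
  // NCHW (8, 3, 224, 224) under perm {0, 2, 3, 1} becomes NHWC (8, 224, 224, 3).
  assert((PermuteShape({8, 3, 224, 224}, {0, 2, 3, 1}) ==
          std::vector<int64_t>{8, 224, 224, 3}));
  return 0;
}
```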
| @@ -146,20 +146,24 @@ Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Data | |||||
| int64_t dst_ele_num = GetItemNumByShape(dst_shape); | int64_t dst_ele_num = GetItemNumByShape(dst_shape); | ||||
| int64_t data_size = GetSizeByDataType(src_data_type); | int64_t data_size = GetSizeByDataType(src_data_type); | ||||
| int64_t dst_size = data_size * dst_ele_num; | int64_t dst_size = data_size * dst_ele_num; | ||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | ||||
| JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | ||||
| if (dst_ele_num == 0) { | |||||
| result.length = static_cast<size_t>(dst_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
| int64_t dst_index = 0; | int64_t dst_index = 0; | ||||
| std::vector<int64_t> dst_indexes(dst_shape.size()); | std::vector<int64_t> dst_indexes(dst_shape.size()); | ||||
| while (dst_index < dst_ele_num) { | while (dst_index < dst_ele_num) { | ||||
| auto src_offset = GenOffset(src_heads, dst_indexes) * data_size; | auto src_offset = GenOffset(src_heads, dst_indexes) * data_size; | ||||
| auto dst_offset_bytes = dst_index * data_size; | auto dst_offset_bytes = dst_index * data_size; | ||||
| auto protected_size = dst_size - dst_offset_bytes < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset_bytes < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
| ? dst_size - dst_offset_bytes | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| ? dst_size - dst_offset_bytes | |||||
| : static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
| auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | ||||
| static_cast<size_t>(data_size)); | static_cast<size_t>(data_size)); | ||||
| if (ret != EOK) { | if (ret != EOK) { | ||||
| @@ -24,6 +24,7 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include "common/formats/utils/formats_trans_utils.h" | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -38,10 +39,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg | |||||
| TypeUtils::FormatToSerialString(args.dst_format).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str()); | ||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (args.data == nullptr) { | |||||
| auto src_shape_size = GetItemNumByShape(args.src_shape); | |||||
| if (args.data == nullptr && src_shape_size != 0) { | |||||
| GELOGE(PARAM_INVALID, "Invalid input null data"); | GELOGE(PARAM_INVALID, "Invalid input null data"); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| return transfer->TransFormat(args, result); | return transfer->TransFormat(args, result); | ||||
| } | } | ||||
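With the empty-tensor paths above, a null data pointer is now acceptable exactly when the source shape holds zero elements; only a null pointer paired with a non-empty shape is rejected. The predicate, reduced to one line (simplified signature, an assumption):

```cpp
#include <cstdint>

// Null input data is valid only for an empty source tensor; a null pointer
// with a non-empty shape is still rejected as PARAM_INVALID.
bool IsInputDataValid(const void *data, int64_t src_item_num) {
  return data != nullptr || src_item_num == 0;
}
```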
| @@ -71,6 +75,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr | |||||
| TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | ||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| if (args.data == nullptr && args.src_data_size != 0) { | |||||
| GELOGE(PARAM_INVALID, "Invalid input null data"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| return transfer->TransDataType(args, result); | return transfer->TransDataType(args, result); | ||||
| } | } | ||||
| @@ -69,11 +69,11 @@ bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
| } | } | ||||
| int64_t num = 1; | int64_t num = 1; | ||||
| for (auto dim : shape) { | for (auto dim : shape) { | ||||
| if (dim < 1) { | |||||
| GELOGE(PARAM_INVALID, "Invalid zero dim in the shape %s", ShapeToString(shape).c_str()); | |||||
| if (dim < 0) { | |||||
| GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| if (kShapeItemNumMAX / dim < num) { | |||||
| if (dim != 0 && kShapeItemNumMAX / dim < num) { | |||||
| GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | ||||
| return false; | return false; | ||||
| } | } | ||||
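Allowing zero dims forces two changes in the overflow guard: `kShapeItemNumMAX / dim` would divide by zero, and a zero dim can never overflow the running product, so it is simply skipped. A compact version of the check under an assumed limit:

```cpp
#include <cstdint>
#include <vector>

constexpr int64_t kShapeItemNumMax = INT64_MAX;  // assumed limit, illustration only

// Rejects negative dims and products that would exceed kShapeItemNumMax.
bool IsShapeValidSketch(const std::vector<int64_t> &shape) {
  int64_t num = 1;
  for (int64_t dim : shape) {
    if (dim < 0) {
      return false;
    }
    // Zero dims cannot overflow, and dividing by them would be undefined.
    if (dim != 0 && kShapeItemNumMax / dim < num) {
      return false;
    }
    num *= dim;
  }
  return true;
}
```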
| @@ -64,6 +64,9 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst); | |||||
| template <typename T> | template <typename T> | ||||
| T Ceil(T n1, T n2) { | T Ceil(T n1, T n2) { | ||||
| if (n1 == 0) { | |||||
| return 0; | |||||
| } | |||||
| return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | ||||
| } | } | ||||
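The new guard in Ceil fixes a real edge case: without it, Ceil(0, c0) evaluated (0 - 1) / c0 + 1, and since C++ integer division truncates toward zero that is 0 + 1 = 1, so an empty dimension produced one spurious block. A quick regression check:

```cpp
#include <cassert>

template <typename T>
T Ceil(T n1, T n2) {
  if (n1 == 0) {
    return 0;  // the new guard: empty dims round to zero blocks
  }
  return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0;
}

int main() {
  assert(Ceil(0, 16) == 0);   // previously 1: (0 - 1) / 16 truncates to 0
  assert(Ceil(1, 16) == 1);
  assert(Ceil(16, 16) == 1);
  assert(Ceil(17, 16) == 2);
  assert(Ceil(5, 0) == 0);    // division by zero is mapped to 0
  return 0;
}
```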
| @@ -0,0 +1,121 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
| #define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
| #include <nlohmann/json.hpp> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "ge/ge_api_error_codes.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/manager/graph_var_manager.h" | |||||
| #include "model/ge_model.h" | |||||
| namespace ge { | |||||
| using Json = nlohmann::json; | |||||
| struct CacheInfo { | |||||
| size_t node_num; | |||||
| size_t edge_num; | |||||
| size_t graph_hash; | |||||
| map<std::string, size_t> nodes_hash; | |||||
| CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} | |||||
| }; | |||||
| class ModelCacheHelper { | |||||
| public: | |||||
| ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); | |||||
| Status SaveCacheInfoToCache() const; | |||||
| Status SaveVarManagerToCache(bool before_build) const; | |||||
| Status SaveOmModelToCache(const GeModelPtr &ge_model) const; | |||||
| bool IsModelCacheHit() const; | |||||
| Status RecoverVarManagerFromCache() const; | |||||
| Status LoadOmModelFromCache(GeModelPtr &ge_model) const; | |||||
| Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); | |||||
| Status ClearCache(uint32_t graph_id) const; | |||||
| private: | |||||
| Status GetComputeGraphHash(size_t &hash) const; | |||||
| Status GetNodesHash(map<std::string, size_t> &hash_map) const; | |||||
| Status GetCacheInfo(CacheInfo &cache_info) const; | |||||
| Status RecoverMemResource(const Json &json) const; | |||||
| Status RecoverAllocatedGraphId(const Json &json) const; | |||||
| Status RecoverChangedGraphId(const Json &json) const; | |||||
| Status RecoverVarAddrAndTensorDesc(const Json &json) const; | |||||
| Status RecoverBroadcastInfo(const Json &json) const; | |||||
| Status RecoverTransRoads(const Json &json) const; | |||||
| static Status RecompileNodes(GeModelPtr &ge_model); | |||||
| bool IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const; | |||||
| bool IsMemResourceSameAsCache(Json &json) const; | |||||
| bool IsChangedGraphIdSameAsCache(Json &json) const; | |||||
| bool IsAllocatedGraphIdSameAsCache(Json &json) const; | |||||
| bool IsCurVarTensorDescSameAsCache(Json &json) const; | |||||
| bool IsVarAddrMgrMapSameAsCache(Json &json) const; | |||||
| bool IsBroadcastInfoSameAsCache(Json &json) const; | |||||
| bool IsTransRoadsSameAsCache(Json &json) const; | |||||
| bool IsVarManagerSameAsCache(Json &json) const; | |||||
| bool IsVarManagerParamSameAsCache(Json &json) const; | |||||
| Status SaveJsonToFile(const string &file_name, const Json &json) const; | |||||
| Status LoadJsonFromFile(const string &file_name, Json &json) const; | |||||
| Status GetNodesHashMapJson(Json &json) const; | |||||
| Status GetMemResourceMap(Json &json) const; | |||||
| Status GetVarAddrMgrMapJson(Json &json) const; | |||||
| Status GetCurVarTensorDescMapJson(Json &json) const; | |||||
| Status GetTransRoadsJson(Json &json) const; | |||||
| Status GetChangedGraphIdJson(Json &json) const; | |||||
| Status GetAllocatedGraphIdJson(Json &json) const; | |||||
| Status GetBroadcastInfoJson(Json &json) const; | |||||
| Status GetVarResourceJson(Json &json) const; | |||||
| Status GetVarManagerJson(Json &json) const; | |||||
| static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); | |||||
| static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); | |||||
| static Status ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource); | |||||
| static Status ParseVarAddrMgrMapFromJson(const Json &json, | |||||
| std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector, | |||||
| std::unordered_set<uint64_t> &var_offset_set); | |||||
| static Status ParseCurVarTensorDescMapFromJson( | |||||
| const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map); | |||||
| static Status ParseTransRoadsFromJson(const Json &json, | |||||
| std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads); | |||||
| static Status ParseChangedGraphIdFromJson(const Json &json, | |||||
| std::unordered_map<std::string, uint32_t> &changed_graph_id); | |||||
| static Status ParseAllocatedGraphIdFromJson(const Json &json, | |||||
| std::unordered_map<std::string, uint32_t> &allocated_graph_id); | |||||
| static Status ParseBroadcastInfoFromJson(const Json &json, | |||||
| std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info); | |||||
| static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); | |||||
| uint64_t session_id_; | |||||
| uint32_t graph_id_; | |||||
| string cache_path_; | |||||
| ComputeGraphPtr compute_graph_; | |||||
| std::set<string> var_names_; | |||||
| bool is_cache_path_valid_for_output; | |||||
| static map<uint32_t, uint32_t> graph_id_run_times_; | |||||
| }; | |||||
| using ModelCacheHelperPtr = std::shared_ptr<ModelCacheHelper>; | |||||
| } // namespace ge | |||||
| #endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
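The method names in this new header imply a probe-load-or-rebuild flow. The sketch below is a hypothetical caller, not code from this change: the control flow, the error handling, and the surrounding build step are all assumptions drawn from the declared interface.

```cpp
// Hypothetical usage of the ModelCacheHelper interface declared above.
Status BuildWithCache(uint64_t session_id, uint32_t graph_id,
                      ComputeGraphPtr &graph, GeModelPtr &ge_model) {
  ModelCacheHelper helper(session_id, graph_id, graph);
  if (helper.IsModelCacheHit()) {
    // Cache hit: restore variable-manager state, then load the saved model.
    if (helper.RecoverVarManagerFromCache() == SUCCESS &&
        helper.LoadOmModelFromCache(ge_model) == SUCCESS) {
      return SUCCESS;
    }
    // A stale or unreadable cache falls through to a fresh build.
    (void)helper.ClearCache(graph_id);
  }
  GE_CHK_STATUS_RET(helper.SaveVarManagerToCache(true), "Save VarManager failed.");
  // ... build ge_model from graph here (omitted) ...
  GE_CHK_STATUS_RET(helper.SaveOmModelToCache(ge_model), "Save model failed.");
  GE_CHK_STATUS_RET(helper.SaveCacheInfoToCache(), "Save cache info failed.");
  return SUCCESS;
}
```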
| @@ -385,6 +385,7 @@ REGISTER_OPTYPE_DEFINE(STREAMSWITCH, "StreamSwitch"); | |||||
| REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); | REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); | ||||
| REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); | REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); | ||||
| REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | ||||
| REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||||
| REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | ||||
| REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | ||||
| REGISTER_OPTYPE_DEFINE(SEND, "Send"); | REGISTER_OPTYPE_DEFINE(SEND, "Send"); | ||||
| @@ -392,6 +393,7 @@ REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | |||||
| REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | ||||
| REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | ||||
| REGISTER_OPTYPE_DEFINE(LABELGOTOEX, "LabelGotoEx"); | |||||
| REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); | REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); | ||||
| REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | ||||
| @@ -196,7 +196,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
| GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | ||||
| auto dir_path_len = directory_path.length(); | auto dir_path_len = directory_path.length(); | ||||
| if (dir_path_len >= PATH_MAX) { | if (dir_path_len >= PATH_MAX) { | ||||
| GELOGE(ge::FAILED, "Directory path is too long."); | |||||
| GELOGW("Directory path is too long."); | |||||
| return -1; | return -1; | ||||
| } | } | ||||
| char tmp_dir_path[PATH_MAX] = {0}; | char tmp_dir_path[PATH_MAX] = {0}; | ||||
| @@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
| int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
| GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
| GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
| directory_path.c_str()); | directory_path.c_str()); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -218,8 +218,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
| int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
| if (ret != 0) { | if (ret != 0) { | ||||
| if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
| GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
| directory_path.c_str()); | |||||
| GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } | } | ||||
| @@ -339,7 +338,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { | ||||
| // The specified path is empty | // The specified path is empty | ||||
| if (file_path.empty()) { | if (file_path.empty()) { | ||||
| GELOGE(ge::FAILED, "Path is empty."); | |||||
| GELOGW("Path is empty."); | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -358,23 +357,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
| std::string real_path = RealPath(file_path.c_str()); | std::string real_path = RealPath(file_path.c_str()); | ||||
| // Unable to get absolute path (does not exist or does not have permission to access) | // Unable to get absolute path (does not exist or does not have permission to access) | ||||
| if (real_path.empty()) { | if (real_path.empty()) { | ||||
| GELOGE(ge::FAILED, "Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
| GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
| return false; | return false; | ||||
| } | } | ||||
| // The absolute path points to a file that is not readable | // The absolute path points to a file that is not readable | ||||
| if (access(real_path.c_str(), R_OK) != 0) { | if (access(real_path.c_str(), R_OK) != 0) { | ||||
| GELOGE(ge::FAILED, "Can not read file in %s, %s", file_path.c_str(), strerror(errno)); | |||||
| GELOGW("Can not read file in %s, %s", file_path.c_str(), strerror(errno)); | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| } | } | ||||
| FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { | |||||
| // The specified path is empty | // The specified path is empty | ||||
| if (file_path.empty()) { | if (file_path.empty()) { | ||||
| GELOGE(ge::FAILED, "Path is empty."); | |||||
| GELOGW("Path is empty."); | |||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -394,8 +393,8 @@ FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) | |||||
| // Can get absolute path (file exists) | // Can get absolute path (file exists) | ||||
| if (!real_path.empty()) { | if (!real_path.empty()) { | ||||
| // File is not readable or writable | // File is not readable or writable | ||||
| if (access(real_path.c_str(), R_OK | W_OK | F_OK) != 0) { | |||||
| GELOGE(ge::FAILED, "Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); | |||||
| if (access(real_path.c_str(), W_OK | F_OK) != 0) { | |||||
| GELOGW("Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -413,7 +412,7 @@ FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) | |||||
| std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | ||||
| // Determine whether the specified path is valid by creating the path | // Determine whether the specified path is valid by creating the path | ||||
| if (CreateDirectory(prefix_path) != 0) { | if (CreateDirectory(prefix_path) != 0) { | ||||
| GELOGE(ge::FAILED, "Can not create prefix path for path[ %s ].", file_path.c_str()); | |||||
| GELOGW("Can not create prefix path for path[ %s ].", file_path.c_str()); | |||||
| return false; | return false; | ||||
| } | } | ||||
| } | } | ||||
| @@ -47,6 +47,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
| "../graph/load/new_model_manager/task_info/kernel_task_info.cc" | "../graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
| "../graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "../graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
| "../graph/load/new_model_manager/task_info/label_set_task_info.cc" | "../graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
| "../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
| "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
| "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
| "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "graph/operator.h" | #include "graph/operator.h" | ||||
| #include "register/register.h" | |||||
| #include "inc/register/register.h" | |||||
| namespace ge { | namespace ge { | ||||
| class HostCpuEngine { | class HostCpuEngine { | ||||
| @@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde | |||||
| DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | ||||
| bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | ||||
| if (!ret) { | if (!ret) { | ||||
| GELOGE(FAILED, "Copy data to host error. index: %lu", i); | |||||
| GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| data_index = data_begin + data_count; | data_index = data_begin + data_count; | ||||
| @@ -96,6 +96,7 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
| GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret); | ||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("stream index:%u, stream:%p.", i, stream); | |||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -446,8 +447,11 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
| /// The logic of GetShapeSize is wrong: a scalar tensor's GetShapeSize is zero | /// The logic of GetShapeSize is wrong: a scalar tensor's GetShapeSize is zero | ||||
| /// and that of unknown shape is zero too. | /// and that of unknown shape is zero too. | ||||
| /// Unknown shape will not appear here, so we can use zero to judge whether a tensor is a scalar or not. | /// Unknown shape will not appear here, so we can use zero to judge whether a tensor is a scalar or not. | ||||
| int64_t elem_num = | |||||
| (constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||||
| int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||||
| if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||||
| elem_num = 1; | |||||
| } | |||||
| if (constant->weight_data.size() < sizeof(uint64_t)) { | if (constant->weight_data.size() < sizeof(uint64_t)) { | ||||
| GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | ||||
| return false; | return false; | ||||
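The rewritten element count narrows the scalar special case: a zero GetShapeSize() becomes one element only when the recorded tensor size is also zero, so a genuinely empty weight now keeps an element count of zero. A stand-in sketch of the distinction (field names are assumptions):

```cpp
#include <cstdint>

// Simplified stand-in for the fields consulted by InitConstantInfo.
struct WeightTensor {
  int64_t shape_size;  // GetShapeSize(): 0 for scalars and for empty tensors
  int64_t size;        // recorded tensor size
};

int64_t ElemNum(const WeightTensor &t) {
  int64_t elem_num = t.shape_size;
  // Only a zero shape size paired with a zero recorded size is a scalar;
  // otherwise a zero shape size stays zero (an empty tensor).
  if (elem_num == 0 && t.size == 0) {
    elem_num = 1;
  }
  return elem_num;
}
```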
| @@ -82,6 +82,7 @@ bool CceTask::Distribute() { | |||||
| stub_func_ = nullptr; | stub_func_ = nullptr; | ||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("CCETask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||||
| // Flowtable | // Flowtable | ||||
| if (is_flowtable_) { | if (is_flowtable_) { | ||||
| @@ -43,6 +43,8 @@ EventRecordTask::EventRecordTask(const ModelContext &model_context, | |||||
| EventRecordTask::~EventRecordTask() {} | EventRecordTask::~EventRecordTask() {} | ||||
| bool EventRecordTask::Distribute() { | bool EventRecordTask::Distribute() { | ||||
| GELOGI("EventRecordTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||||
| task_info_->stream_id(), task_info_->event_id()); | |||||
| rtError_t rt_ret = rtEventRecord(event_, stream_); | rtError_t rt_ret = rtEventRecord(event_, stream_); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
| @@ -42,6 +42,9 @@ EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::share | |||||
| EventWaitTask::~EventWaitTask() {} | EventWaitTask::~EventWaitTask() {} | ||||
| bool EventWaitTask::Distribute() { | bool EventWaitTask::Distribute() { | ||||
| GELOGI("EventWaitTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||||
| task_info_->stream_id(), task_info_->event_id()); | |||||
| rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret); | ||||
| @@ -101,6 +101,7 @@ bool HcclTask::Distribute() { | |||||
| char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | ||||
| auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | ||||
| GELOGI("the first address of the custom info, privateDef=%p", private_def); | |||||
| GELOGI("hcclStreamNum =%ld", task_info_->hccl_stream_num()); | GELOGI("hcclStreamNum =%ld", task_info_->hccl_stream_num()); | ||||
| for (int64_t i = 0; i < task_info_->hccl_stream_num(); ++i) { | for (int64_t i = 0; i < task_info_->hccl_stream_num(); ++i) { | ||||
| @@ -117,6 +118,7 @@ bool HcclTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("hccl_stream addr is=%p", stream); | |||||
| slave_stream_list_.push_back(stream); | slave_stream_list_.push_back(stream); | ||||
| } | } | ||||
| @@ -62,6 +62,9 @@ bool StreamSwitchTask::Distribute() { | |||||
| rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | ||||
| rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | ||||
| GELOGI("InitStreamSwitchTask, cond:%d, trueStream:%p, trueStreamID:%ld, datatype:%ld.", cond, true_stream, | |||||
| task_info_->true_stream_id(), task_info_->data_type()); | |||||
| GELOGI("StreamSwitchTask Distribute Start."); | GELOGI("StreamSwitchTask Distribute Start."); | ||||
| rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -69,6 +72,7 @@ bool StreamSwitchTask::Distribute() { | |||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("Distribute StreamSwitch, cond:%d, trueStream:%p, datatype:%ld.", cond, true_stream, task_info_->data_type()); | |||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -69,6 +69,7 @@ bool TbeTask::Distribute() { | |||||
| stub_func_ = nullptr; | stub_func_ = nullptr; | ||||
| return false; | return false; | ||||
| } | } | ||||
| GELOGI("TbeTask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||||
| // Get args | // Get args | ||||
| std::vector<void *> tensor_device_addrs; | std::vector<void *> tensor_device_addrs; | ||||
| @@ -18,8 +18,8 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/helper/model_helper.h" | #include "common/helper/model_helper.h" | ||||
| #include "common/opskernel/ops_kernel_info_types.h" | #include "common/opskernel/ops_kernel_info_types.h" | ||||
| #include "graph/build/stream_graph_optimizer.h" | |||||
| #include "graph/build/run_context.h" | #include "graph/build/run_context.h" | ||||
| #include "graph/build/stream_graph_optimizer.h" | |||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -98,8 +98,10 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
| Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | ||||
| GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | ||||
| auto subgraph_map = graph_partitioner_.GetSubGraphMap(); | |||||
| GE_TIMESTAMP_START(BuildSubgraph); | GE_TIMESTAMP_START(BuildSubgraph); | ||||
| ge::ModelBuilder builder(comp_graph, subgraph_ptr_list, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
| ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
| GELOGI("[Build] invoke the other opskernel to generate task."); | GELOGI("[Build] invoke the other opskernel to generate task."); | ||||
| @@ -135,7 +137,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
| } | } | ||||
| GE_TIMESTAMP_START(GetTaskInfo); | GE_TIMESTAMP_START(GetTaskInfo); | ||||
| ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_ptr_list, session_id); | |||||
| ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); | |||||
| GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | ||||
| GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | ||||
| @@ -155,7 +157,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
| } | } | ||||
| Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, | Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, | ||||
| ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
| ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, | |||||
| uint64_t session_id) { | uint64_t session_id) { | ||||
| GE_CHECK_NOTNULL(model_ptr); | GE_CHECK_NOTNULL(model_ptr); | ||||
| GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
| @@ -190,7 +192,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr | |||||
| } | } | ||||
| StreamGraphOptimizer stream_optimizer; | StreamGraphOptimizer stream_optimizer; | ||||
| ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_ptr_list, run_context.GetRunContext()); | |||||
| ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_map, run_context.GetRunContext()); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Optimize streamed subGraph fail."); | GELOGE(ret, "Optimize streamed subGraph fail."); | ||||
| return ret; | return ret; | ||||
| @@ -53,7 +53,7 @@ class GraphBuilder { | |||||
| private: | private: | ||||
| Status CalcOpParam(const ge::ComputeGraphPtr &graph); | Status CalcOpParam(const ge::ComputeGraphPtr &graph); | ||||
| Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, | Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, | ||||
| std::vector<SubGraphInfoPtr> &subgraph_ptr_list, uint64_t session_id = INVALID_SESSION_ID); | |||||
| Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); | |||||
| Status SetInputSize(const ge::NodePtr &node_ptr); | Status SetInputSize(const ge::NodePtr &node_ptr); | ||||
| Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | ||||
| Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | ||||
| @@ -70,7 +70,7 @@ bool LogicalStreamPass::HasNonConstInputNode(const Subgraph &subgraph) const { | |||||
| return false; | return false; | ||||
| } | } | ||||
| Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| bool changed = false; | bool changed = false; | ||||
| int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
| map<string, int64_t> label_streams; | map<string, int64_t> label_streams; | ||||
| @@ -97,7 +97,7 @@ Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector<Subgraph | |||||
| return changed ? SUCCESS : NOT_CHANGED; | return changed ? SUCCESS : NOT_CHANGED; | ||||
| } | } | ||||
| Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| bool changed = false; | bool changed = false; | ||||
| int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
| @@ -129,8 +129,7 @@ Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector<Subg | |||||
| return changed ? SUCCESS : NOT_CHANGED; | return changed ? SUCCESS : NOT_CHANGED; | ||||
| } | } | ||||
| Status AssignByDependencyPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, | |||||
| Context &context) { | |||||
| Status AssignByDependencyPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| bool changed = false; | bool changed = false; | ||||
| if (IsHeadNodeExceeded(subgraphs)) { | if (IsHeadNodeExceeded(subgraphs)) { | ||||
| int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
| @@ -298,7 +297,7 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { | |||||
| subgraph->stream_id = stream_id; | subgraph->stream_id = stream_id; | ||||
| engine_next_streams_[engine_name] = stream_id + 1; | engine_next_streams_[engine_name] = stream_id + 1; | ||||
| assigned_subgraphs_.emplace(subgraph); | |||||
| assigned_subgraphs_.emplace_back(subgraph); | |||||
| if ((stream_id + 1) > engine_stream_num_[engine_name]) { | if ((stream_id + 1) > engine_stream_num_[engine_name]) { | ||||
| engine_stream_num_[engine_name] = stream_id + 1; | engine_stream_num_[engine_name] = stream_id + 1; | ||||
| @@ -311,6 +310,15 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { | |||||
| } | } | ||||
| void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | ||||
| // If the parent stream is valid, the first assigned stream will reuse the parent stream id | |||||
| // and the other streams use new ids. To keep the newly allocated stream ids | |||||
| // contiguous, we first subtract one from next_stream. | |||||
| int64_t to_be_updated_stream = kInvalidStream; | |||||
| if (context.parent_stream != kInvalidStream) { | |||||
| context.next_stream--; | |||||
| to_be_updated_stream = context.next_stream; | |||||
| } | |||||
| // Update the starting stream id for each engine. | // Update the starting stream id for each engine. | ||||
| int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
| map<string, int64_t> engine_start_streams; | map<string, int64_t> engine_start_streams; | ||||
| @@ -320,10 +328,16 @@ void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | |||||
| next_stream += stream_count; | next_stream += stream_count; | ||||
| } | } | ||||
| // Update the subgraphs assigned by the engine. | |||||
| // Update the subgraph streams assigned by each engine. | |||||
| for (auto &subgraph : assigned_subgraphs_) { | for (auto &subgraph : assigned_subgraphs_) { | ||||
| subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; | subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; | ||||
| GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); | |||||
| if (subgraph->stream_id == to_be_updated_stream) { | |||||
| subgraph->stream_id = context.parent_stream; | |||||
| GELOGI("Subgraph %s of engine %s reuses parent stream %ld.", subgraph->name.c_str(), | |||||
| subgraph->engine_conf.id.c_str(), context.parent_stream); | |||||
| } else { | |||||
| GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
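To make the new comment concrete, here is a toy replay with assumed numbers: one engine, three locally assigned streams, parent stream 7, and next_stream at 10 on entry. The decrement frees id 9, the first subgraph lands on it and is remapped onto the parent, and the remaining ids stay contiguous:

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  const int64_t parent_stream = 7;
  int64_t next_stream = 10;                    // value entering UpdateAssignedSubgraphs
  std::vector<int64_t> local_ids = {0, 1, 2};  // ids assigned by AssignNewStream

  next_stream--;                               // free one id for the parent to occupy
  const int64_t to_be_updated = next_stream;   // == 9
  const int64_t engine_start = next_stream;    // the single engine starts here
  next_stream += static_cast<int64_t>(local_ids.size());

  std::vector<int64_t> final_ids;
  for (int64_t id : local_ids) {
    int64_t stream = id + engine_start;
    final_ids.push_back(stream == to_be_updated ? parent_stream : stream);
  }
  assert((final_ids == std::vector<int64_t>{7, 10, 11}));
  assert(next_stream == 12);                   // ids 10 and 11 remain in use
  return 0;
}
```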
| @@ -337,7 +351,7 @@ void AssignByDependencyPass::UpdateReusedSubgraphs() { | |||||
| } | } | ||||
| } | } | ||||
| Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| // Check if all subgraphs have been assigned a stream. | // Check if all subgraphs have been assigned a stream. | ||||
| for (const SubgraphPtr &subgraph : subgraphs) { | for (const SubgraphPtr &subgraph : subgraphs) { | ||||
| const string &engine_name = subgraph->engine_conf.id; | const string &engine_name = subgraph->engine_conf.id; | ||||
| @@ -353,7 +367,7 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<Subgr | |||||
| } | } | ||||
| // Init the stream id of node. | // Init the stream id of node. | ||||
| for (NodePtr &node : whole_graph->GetDirectNode()) { | |||||
| for (NodePtr &node : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| node->GetOpDesc()->SetStreamId(kInvalidStream); | node->GetOpDesc()->SetStreamId(kInvalidStream); | ||||
| } | } | ||||
| @@ -375,76 +389,11 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<Subgr | |||||
| } | } | ||||
| // Update stream id for nodes belong to skipped engine subgraph | // Update stream id for nodes belong to skipped engine subgraph | ||||
| GE_CHK_STATUS_RET(UpdateForSkippedEngine(whole_graph, subgraphs)); | |||||
| RefreshContinuousStreams(whole_graph, context); | |||||
| GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs)); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AllReduceParallelPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| if (!context.hcom_parallel) { | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| GELOGI("AllReduceParallelPass is enabled."); | |||||
| GraphUtils::DumpGEGraph(whole_graph, "BeforeAllReduceParallel"); | |||||
| // All successors of HcomAllReduce. | |||||
| set<NodePtr> all_reduce_succs; | |||||
| for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
| if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
| continue; | |||||
| } | |||||
| string reduce_stream_label; | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| // ATTR_NAME_STREAM_LABEL is optional. | |||||
| (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); | |||||
| set<NodePtr> cur_nodes = {node}; | |||||
| while (!cur_nodes.empty()) { | |||||
| set<NodePtr> all_out_data_nodes; | |||||
| for (auto &curr_node : cur_nodes) { | |||||
| for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { | |||||
| string out_stream_label; | |||||
| GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
| // ATTR_NAME_STREAM_LABEL is optional. | |||||
| (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | |||||
| if (out_stream_label == reduce_stream_label) { | |||||
| all_reduce_succs.emplace(out_node); | |||||
| all_out_data_nodes.emplace(out_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| cur_nodes = all_out_data_nodes; | |||||
| } | |||||
| } | |||||
| map<int64_t, int64_t> old_stream_to_new; | |||||
| for (const NodePtr &node : all_reduce_succs) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| auto old_stream = node->GetOpDesc()->GetStreamId(); | |||||
| if (old_stream != kInvalidStream) { | |||||
| int64_t new_stream = kInvalidStream; | |||||
| auto iter = old_stream_to_new.find(old_stream); | |||||
| if (iter != old_stream_to_new.end()) { | |||||
| new_stream = iter->second; | |||||
| } else { | |||||
| new_stream = context.next_stream; | |||||
| context.next_stream++; | |||||
| old_stream_to_new.emplace(old_stream, new_stream); | |||||
| } | |||||
| GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
| node->GetOpDesc()->SetStreamId(new_stream); | |||||
| } | |||||
| } | |||||
| return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | |||||
| } | |||||
| int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | ||||
| set<int64_t> stream_ids; | set<int64_t> stream_ids; | ||||
| @@ -472,11 +421,11 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | |||||
| return kInvalidStream; | return kInvalidStream; | ||||
| } | } | ||||
| Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, | |||||
| Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, | |||||
| const vector<SubgraphPtr> &subgraphs) { | const vector<SubgraphPtr> &subgraphs) { | ||||
| set<OpDescPtr> nodes_to_be_updated; | set<OpDescPtr> nodes_to_be_updated; | ||||
| // Check if sub graph is engine skipped and without stream label or not | |||||
| // Check whether the subgraph is engine-skipped and has no stream label | |||||
| for (const SubgraphPtr &subgraph : subgraphs) { | for (const SubgraphPtr &subgraph : subgraphs) { | ||||
| if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { | if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { | ||||
| auto graph = subgraph->subgraph_info.GetSubGraph(); | auto graph = subgraph->subgraph_info.GetSubGraph(); | ||||
| @@ -492,7 +441,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole | |||||
| } | } | ||||
| // Try reassign the stream id | // Try reassign the stream id | ||||
| for (ge::NodePtr &node : whole_graph->GetDirectNode()) { | |||||
| for (ge::NodePtr &node : graph->GetDirectNode()) { | |||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| int64_t stream_id = op_desc->GetStreamId(); | int64_t stream_id = op_desc->GetStreamId(); | ||||
| @@ -509,6 +458,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -525,40 +475,65 @@ bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { | |||||
| return true; | return true; | ||||
| } | } | ||||
| void NodeStreamUpdatePass::RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const { | |||||
| int64_t stream_num = context.next_stream; | |||||
| vector<bool> stream_has_node(stream_num); | |||||
| Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
| if (!context.hcom_parallel) { | |||||
| return NOT_CHANGED; | |||||
| } | |||||
| for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
| if (node != nullptr) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc != nullptr) { | |||||
| int64_t stream_id = op_desc->GetStreamId(); | |||||
| if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
| stream_has_node[stream_id] = true; | |||||
| } | |||||
| } | |||||
| GELOGI("AllReduceParallelPass is enabled."); | |||||
| GraphUtils::DumpGEGraph(graph, "BeforeAllReduceParallel"); | |||||
| // All successors of HcomAllReduce. | |||||
| set<NodePtr> all_reduce_succs; | |||||
| for (const NodePtr &node : graph->GetDirectNode()) { | |||||
| if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
| continue; | |||||
| } | } | ||||
| } | |||||
| context.next_stream = 0; | |||||
| vector<int64_t> old_to_new_streams(stream_num, kInvalidStream); | |||||
| for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { | |||||
| if (stream_has_node[old_stream]) { | |||||
| old_to_new_streams[old_stream] = context.next_stream; | |||||
| ++context.next_stream; | |||||
| string reduce_stream_label; | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); | |||||
| set<NodePtr> cur_nodes = {node}; | |||||
| while (!cur_nodes.empty()) { | |||||
| set<NodePtr> all_out_data_nodes; | |||||
| for (auto &curr_node : cur_nodes) { | |||||
| for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { | |||||
| string out_stream_label; | |||||
| GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
| (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | |||||
| if (out_stream_label == reduce_stream_label) { | |||||
| all_reduce_succs.emplace(out_node); | |||||
| all_out_data_nodes.emplace(out_node); | |||||
| } | |||||
| } | |||||
| } | |||||
| cur_nodes = all_out_data_nodes; | |||||
| } | } | ||||
| } | } | ||||
| for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc != nullptr) { | |||||
| int64_t stream_id = op_desc->GetStreamId(); | |||||
| if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
| op_desc->SetStreamId(old_to_new_streams[stream_id]); | |||||
| map<int64_t, int64_t> old_stream_to_new; | |||||
| for (const NodePtr &node : all_reduce_succs) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| auto old_stream = node->GetOpDesc()->GetStreamId(); | |||||
| if (old_stream != kInvalidStream) { | |||||
| int64_t new_stream = kInvalidStream; | |||||
| auto iter = old_stream_to_new.find(old_stream); | |||||
| if (iter != old_stream_to_new.end()) { | |||||
| new_stream = iter->second; | |||||
| } else { | |||||
| new_stream = context.next_stream; | |||||
| context.next_stream++; | |||||
| old_stream_to_new.emplace(old_stream, new_stream); | |||||
| } | } | ||||
| GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
| node->GetOpDesc()->SetStreamId(new_stream); | |||||
| } | } | ||||
| } | } | ||||
| return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | |||||
| } | } | ||||
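Editor's note: the relocated pass above moves every successor of an HcomAllReduce that shares the reduce node's stream label onto fresh streams, one new stream per distinct old stream, so the AllReduce can overlap with its backward peers. A minimal standalone sketch of that remapping step, using a hypothetical FakeNode type rather than the GE API:

#include <cstdint>
#include <map>
#include <vector>

struct FakeNode {
  int64_t stream_id = -1;  // -1 stands in for kInvalidStream
};

// Remap each distinct old stream id among the collected successors to a fresh
// id drawn from next_stream; nodes that shared a stream keep sharing one.
void MoveSuccessorsToNewStreams(std::vector<FakeNode *> &succs, int64_t &next_stream) {
  std::map<int64_t, int64_t> old_to_new;
  for (FakeNode *node : succs) {
    if (node->stream_id == -1) {
      continue;  // unassigned nodes are left untouched
    }
    auto it = old_to_new.find(node->stream_id);
    int64_t new_stream;
    if (it != old_to_new.end()) {
      new_stream = it->second;
    } else {
      new_stream = next_stream++;
      old_to_new.emplace(node->stream_id, new_stream);
    }
    node->stream_id = new_stream;
  }
}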
| LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | ||||
| @@ -567,9 +542,10 @@ LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> | |||||
| context_.hcom_parallel = hcom_parallel; | context_.hcom_parallel = hcom_parallel; | ||||
| } | } | ||||
| Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const vector<SubGraphInfoPtr> &subgraph_infos, | |||||
| Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
| int64_t &stream_num) { | int64_t &stream_num) { | ||||
| GE_CHECK_NOTNULL(whole_graph); | GE_CHECK_NOTNULL(whole_graph); | ||||
| map<string, EngineConfPtr> engine_confs; | map<string, EngineConfPtr> engine_confs; | ||||
| GE_TIMESTAMP_START(InitEngineConfs); | GE_TIMESTAMP_START(InitEngineConfs); | ||||
| for (const auto &item : scheduler_confs_) { | for (const auto &item : scheduler_confs_) { | ||||
| @@ -583,16 +559,64 @@ Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const | |||||
| } | } | ||||
| GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); | GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); | ||||
| Status status = DoAssign(whole_graph, subgraph_map, engine_confs); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "Assign streams failed."); | |||||
| return status; | |||||
| } | |||||
| vector<ComputeGraphPtr> subgraphs = whole_graph->GetAllSubgraphs(); | |||||
| for (const ComputeGraphPtr &subgraph : subgraphs) { | |||||
| Status status = DoAssign(subgraph, subgraph_map, engine_confs); | |||||
| if (status != SUCCESS) { | |||||
| GELOGE(status, "Assign streams failed."); | |||||
| return status; | |||||
| } | |||||
| } | |||||
| RefreshContinuousStreams(whole_graph); | |||||
| stream_num = context_.next_stream; | |||||
| GELOGI("Assigned logical stream num: %ld.", stream_num); | |||||
| return SUCCESS; | |||||
| } | |||||
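Editor's note: several signatures in this change swap vector<SubGraphInfoPtr> for Graph2SubGraphInfoList. Its typedef lives outside these hunks, but DoAssign below looks a ComputeGraphPtr up in the map and treats the mapped value as a vector<SubGraphInfoPtr>, so a plausible alias is the sketch below — an assumption, the real definition may use an unordered map or different qualifiers:

#include <map>
#include <vector>

// Hypothetical alias inferred from usage in DoAssign and RefreshNodeId:
// each compute graph maps to the list of engine subgraph infos it owns.
using Graph2SubGraphInfoList = std::map<ge::ComputeGraphPtr, std::vector<ge::SubGraphInfoPtr>>;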
| Status LogicalStreamAllocator::DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
| const map<string, EngineConfPtr> &engine_confs) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| NodePtr parent_node = graph->GetParentNode(); | |||||
| if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { | |||||
| context_.parent_stream = kInvalidStream; | |||||
| } else { | |||||
| context_.parent_stream = parent_node->GetOpDesc()->GetStreamId(); | |||||
| } | |||||
| auto iter = subgraph_map.find(graph); | |||||
| if (iter == subgraph_map.end()) { | |||||
| GELOGE(FAILED, "Graph %s not found.", graph->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| const vector<SubGraphInfoPtr> &subgraph_info_list = iter->second; | |||||
| vector<SubgraphPtr> subgraphs; | vector<SubgraphPtr> subgraphs; | ||||
| GE_TIMESTAMP_START(ConvertSubgraphs); | GE_TIMESTAMP_START(ConvertSubgraphs); | ||||
| Status status = ConvertSubgraphs(subgraph_infos, engine_confs, subgraphs); | |||||
| Status status = ConvertSubgraphs(subgraph_info_list, engine_confs, subgraphs); | |||||
| GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); | GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| GELOGE(status, "Create subgraphs failed."); | GELOGE(status, "Create subgraphs failed."); | ||||
| return status; | return status; | ||||
| } | } | ||||
| return RunPasses(whole_graph, subgraphs, stream_num); | |||||
| GELOGI("Subgraphs of graph %s:", graph->GetName().c_str()); | |||||
| for (const auto &subgraph : subgraphs) { | |||||
| if (subgraph != nullptr) { | |||||
| GELOGI("subgraph: %s", subgraph->name.c_str()); | |||||
| } | |||||
| } | |||||
| return RunPasses(graph, subgraphs); | |||||
| } | } | ||||
| Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &subgraph_infos, | Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &subgraph_infos, | ||||
| @@ -631,8 +655,7 @@ Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &s | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, const vector<SubgraphPtr> &subgraphs, | |||||
| int64_t &stream_num) { | |||||
| Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vector<SubgraphPtr> &subgraphs) { | |||||
| vector<LogicalStreamPassPtr> passes; | vector<LogicalStreamPassPtr> passes; | ||||
| passes.emplace_back(MakeShared<AssignByLabelPass>()); | passes.emplace_back(MakeShared<AssignByLabelPass>()); | ||||
| passes.emplace_back(MakeShared<IndependentStreamPass>()); | passes.emplace_back(MakeShared<IndependentStreamPass>()); | ||||
| @@ -643,7 +666,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, con | |||||
| for (auto &pass : passes) { | for (auto &pass : passes) { | ||||
| GE_CHECK_NOTNULL(pass); | GE_CHECK_NOTNULL(pass); | ||||
| Status status = pass->Run(whole_graph, subgraphs, context_); | |||||
| Status status = pass->Run(graph, subgraphs, context_); | |||||
| if (status == SUCCESS) { | if (status == SUCCESS) { | ||||
| GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); | GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); | ||||
| } else if (status == NOT_CHANGED) { | } else if (status == NOT_CHANGED) { | ||||
| @@ -654,9 +677,42 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, con | |||||
| } | } | ||||
| } | } | ||||
| stream_num = context_.next_stream; | |||||
| GELOGI("Assigned logical stream num: %ld.", stream_num); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &graph) { | |||||
| int64_t stream_num = context_.next_stream; | |||||
| vector<bool> stream_has_node(stream_num); | |||||
| for (const NodePtr &node : graph->GetAllNodes()) { | |||||
| if (node != nullptr) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc != nullptr) { | |||||
| int64_t stream_id = op_desc->GetStreamId(); | |||||
| if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
| stream_has_node[stream_id] = true; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| context_.next_stream = 0; | |||||
| vector<int64_t> old_to_new_streams(stream_num, kInvalidStream); | |||||
| for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { | |||||
| if (stream_has_node[old_stream]) { | |||||
| old_to_new_streams[old_stream] = context_.next_stream; | |||||
| ++context_.next_stream; | |||||
| } | |||||
| } | |||||
| for (const NodePtr &node : graph->GetAllNodes()) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (op_desc != nullptr) { | |||||
| int64_t stream_id = op_desc->GetStreamId(); | |||||
| if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
| op_desc->SetStreamId(old_to_new_streams[stream_id]); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
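Editor's note: RefreshContinuousStreams only renumbers, it never merges. If after the per-subgraph passes only streams 0, 2 and 5 actually own nodes, they become 0, 1 and 2 and next_stream ends up 3. The same two-phase compaction as a standalone sketch over plain integers instead of OpDesc (hypothetical helper, not the GE API):

#include <cstdint>
#include <vector>

// Phase 1 marks the stream ids that are actually used, phase 2 assigns them
// compact new ids in ascending order, then every node's id is rewritten.
// Ids of -1 (the stand-in for kInvalidStream) are left untouched.
int64_t CompactStreamIds(std::vector<int64_t> &node_streams, int64_t stream_num) {
  std::vector<bool> used(stream_num, false);
  for (int64_t s : node_streams) {
    if (s >= 0 && s < stream_num) used[s] = true;
  }
  std::vector<int64_t> old_to_new(stream_num, -1);
  int64_t next = 0;
  for (int64_t s = 0; s < stream_num; ++s) {
    if (used[s]) old_to_new[s] = next++;
  }
  for (int64_t &s : node_streams) {
    if (s >= 0 && s < stream_num) s = old_to_new[s];
  }
  return next;  // the new stream count
}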
| } // namespace ge | } // namespace ge | ||||
| @@ -60,7 +60,7 @@ class LogicalStreamPass { | |||||
| }; | }; | ||||
| struct Context { | struct Context { | ||||
| // Next stream id. | |||||
| int64_t parent_stream = kInvalidStream; | |||||
| int64_t next_stream = 0; | int64_t next_stream = 0; | ||||
| bool hcom_parallel = false; | bool hcom_parallel = false; | ||||
| }; | }; | ||||
| @@ -71,7 +71,7 @@ class LogicalStreamPass { | |||||
| virtual ~LogicalStreamPass() = default; | virtual ~LogicalStreamPass() = default; | ||||
| const std::string &GetName() const; | const std::string &GetName() const; | ||||
| virtual Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0; | |||||
| virtual Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0; | |||||
| protected: | protected: | ||||
| bool IsEngineSkip(const Subgraph &subgraph) const; | bool IsEngineSkip(const Subgraph &subgraph) const; | ||||
| @@ -93,21 +93,21 @@ using LogicalStreamPassPtr = std::shared_ptr<LogicalStreamPass>; | |||||
| class AssignByLabelPass : public LogicalStreamPass { | class AssignByLabelPass : public LogicalStreamPass { | ||||
| public: | public: | ||||
| STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); | STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); | ||||
| Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| }; | }; | ||||
| // Engines such as hccl require independent Stream. | // Engines such as hccl require independent Stream. | ||||
| class IndependentStreamPass : public LogicalStreamPass { | class IndependentStreamPass : public LogicalStreamPass { | ||||
| public: | public: | ||||
| STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); | STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); | ||||
| Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| }; | }; | ||||
| // Reuse streams or assign new streams based on dependencies. | // Reuse streams or assign new streams based on dependencies. | ||||
| class AssignByDependencyPass : public LogicalStreamPass { | class AssignByDependencyPass : public LogicalStreamPass { | ||||
| public: | public: | ||||
| STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); | STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); | ||||
| Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| private: | private: | ||||
| void InitEndSubgraphMap(const std::vector<SubgraphPtr> &subgraphs, std::map<NodePtr, SubgraphPtr> &end_subgraph_map); | void InitEndSubgraphMap(const std::vector<SubgraphPtr> &subgraphs, std::map<NodePtr, SubgraphPtr> &end_subgraph_map); | ||||
| @@ -132,7 +132,7 @@ class AssignByDependencyPass : public LogicalStreamPass { | |||||
| std::map<std::string, int64_t> engine_stream_num_; | std::map<std::string, int64_t> engine_stream_num_; | ||||
| // Subgraphs of assign stream by engine | // Subgraphs of assign stream by engine | ||||
| std::set<SubgraphPtr> assigned_subgraphs_; | |||||
| std::vector<SubgraphPtr> assigned_subgraphs_; | |||||
| // <current subgraph, reused subgraph> | // <current subgraph, reused subgraph> | ||||
| std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_; | std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_; | ||||
| @@ -142,7 +142,7 @@ class AssignByDependencyPass : public LogicalStreamPass { | |||||
| class NodeStreamUpdatePass : public LogicalStreamPass { | class NodeStreamUpdatePass : public LogicalStreamPass { | ||||
| public: | public: | ||||
| STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); | STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); | ||||
| Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| private: | private: | ||||
| /// Optimize for case like: | /// Optimize for case like: | ||||
| @@ -150,19 +150,18 @@ class NodeStreamUpdatePass : public LogicalStreamPass { | |||||
| /// To case: | /// To case: | ||||
| /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) | /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) | ||||
| /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) | /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) | ||||
| Status UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
| Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
| int64_t GetSingleInoutStream(const NodePtr &node) const; | int64_t GetSingleInoutStream(const NodePtr &node) const; | ||||
| // Judge if all predecessors' streams of node are INVALID_STREAM | // Judge if all predecessors' streams of node are INVALID_STREAM | ||||
| bool AreAllPredStreamsInvalid(const NodePtr &node) const; | bool AreAllPredStreamsInvalid(const NodePtr &node) const; | ||||
| void RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const; | |||||
| }; | }; | ||||
| // AllReduce and backward operators execute in parallel. | // AllReduce and backward operators execute in parallel. | ||||
| class AllReduceParallelPass : public LogicalStreamPass { | class AllReduceParallelPass : public LogicalStreamPass { | ||||
| public: | public: | ||||
| STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | ||||
| Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
| }; | }; | ||||
| // Assign logical streams, which are not limited by the number of tasks. | |||||
| @@ -178,13 +177,16 @@ class LogicalStreamAllocator { | |||||
| LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; | LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; | ||||
| ~LogicalStreamAllocator() = default; | ~LogicalStreamAllocator() = default; | ||||
| Status Assign(const ComputeGraphPtr &whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs, int64_t &stream_num); | |||||
| Status Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, int64_t &stream_num); | |||||
| private: | private: | ||||
| Status DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
| const map<string, EngineConfPtr> &engine_confs); | |||||
| Status ConvertSubgraphs(const std::vector<SubGraphInfoPtr> &subgraph_infos, | Status ConvertSubgraphs(const std::vector<SubGraphInfoPtr> &subgraph_infos, | ||||
| const std::map<std::string, EngineConfPtr> &engine_confs, | const std::map<std::string, EngineConfPtr> &engine_confs, | ||||
| std::vector<SubgraphPtr> &subgraphs); | std::vector<SubgraphPtr> &subgraphs); | ||||
| Status RunPasses(const ComputeGraphPtr &whole_graph, const std::vector<SubgraphPtr> &subgraphs, int64_t &stream_num); | |||||
| Status RunPasses(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
| void RefreshContinuousStreams(const ComputeGraphPtr &graph); | |||||
| const std::map<std::string, SchedulerConf> &scheduler_confs_; | const std::map<std::string, SchedulerConf> &scheduler_confs_; | ||||
| const std::map<std::string, int> &max_parallel_num_; | const std::map<std::string, int> &max_parallel_num_; | ||||
| @@ -805,6 +805,9 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t | |||||
| } | } | ||||
| } | } | ||||
| op_desc->SetOutputOffset(output_list); | op_desc->SetOutputOffset(output_list); | ||||
| GELOGI("[IMAS]Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", | |||||
| graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, | |||||
| real_size); | |||||
| } else if (node_type_index.mem_type == kWorkspace) { | } else if (node_type_index.mem_type == kWorkspace) { | ||||
| vector<int64_t> workspace_list; | vector<int64_t> workspace_list; | ||||
| workspace_list = op_desc->GetWorkspace(); | workspace_list = op_desc->GetWorkspace(); | ||||
| @@ -821,6 +824,9 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t | |||||
| workspace_list.at(node_type_index.index) = offset; | workspace_list.at(node_type_index.index) = offset; | ||||
| } | } | ||||
| op_desc->SetWorkspace(workspace_list); | op_desc->SetWorkspace(workspace_list); | ||||
| GELOGI("[IMAS]Set %s name[%s] workspace[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", | |||||
| graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, | |||||
| real_size); | |||||
| } | } | ||||
| } | } | ||||
| @@ -310,6 +310,11 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node) | |||||
| if (is_tensor_actual_size == 0) { | if (is_tensor_actual_size == 0) { | ||||
| AlignMemOffset(MEM_ALIGN_SIZE); | AlignMemOffset(MEM_ALIGN_SIZE); | ||||
| } | } | ||||
| GELOGI( | |||||
| "[IMAS]Continuous input : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%zu] " | |||||
| "real_size[%ld].", | |||||
| node->GetOwnerComputeGraph()->GetName().c_str(), peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx(), | |||||
| pre_mem_offset, peer_op_desc->GetStreamId(), (memory_offset_[0].mem_offset_ - pre_mem_offset), tensor_desc_size); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -340,6 +345,11 @@ Status GraphMemoryAssigner::AssignContinuousOutputMemory(const ge::NodePtr &node | |||||
| memory_offset_[0].mem_offset_ += tensor_desc_size; | memory_offset_[0].mem_offset_ += tensor_desc_size; | ||||
| AlignMemOffset(MEM_ALIGN_SIZE); | AlignMemOffset(MEM_ALIGN_SIZE); | ||||
| GELOGI( | |||||
| "[IMAS]Continuous output : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%zu] " | |||||
| "real_size[%ld].", | |||||
| node->GetOwnerComputeGraph()->GetName().c_str(), out_op_desc->GetName().c_str(), out_data_anchor->GetIdx(), | |||||
| pre_mem_offset, out_op_desc->GetStreamId(), (memory_offset_[0].mem_offset_ - pre_mem_offset), tensor_desc_size); | |||||
| } | } | ||||
| out_op_desc->SetOutputOffset(output_list); | out_op_desc->SetOutputOffset(output_list); | ||||
| @@ -413,8 +423,10 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousInputMemory() { | |||||
| pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); | pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); | ||||
| } | } | ||||
| memory_offset_[0].mem_offset_ += extra_memory_size; | memory_offset_[0].mem_offset_ += extra_memory_size; | ||||
| GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); | |||||
| size_t after_mem_offset = memory_offset_[0].mem_offset_; | |||||
| AlignMemOffset(MEM_ALIGN_SIZE); | |||||
| GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -499,8 +511,10 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousOutputMemory() { | |||||
| } | } | ||||
| op_desc->SetOutputOffset(output_list); | op_desc->SetOutputOffset(output_list); | ||||
| memory_offset_[0].mem_offset_ += extra_memory_size; | memory_offset_[0].mem_offset_ += extra_memory_size; | ||||
| GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); | |||||
| size_t after_mem_offset = memory_offset_[0].mem_offset_; | |||||
| AlignMemOffset(MEM_ALIGN_SIZE); | |||||
| GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
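Editor's note: both the virtual-input and virtual-output paths now call AlignMemOffset(MEM_ALIGN_SIZE) after adding extra_memory_size, and log the offset before and after alignment. AlignMemOffset is defined elsewhere; the round-up it presumably performs looks like this sketch (an assumption about the helper's body and MEM_ALIGN_SIZE):

#include <cstddef>

// Round offset up to the next multiple of align; e.g. AlignUp(1000, 512) == 1024.
static size_t AlignUp(size_t offset, size_t align) {
  return (align == 0) ? offset : (offset + align - 1) / align * align;
}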
| @@ -567,6 +581,11 @@ Status GraphMemoryAssigner::ReAssignMergeMemory() { | |||||
| output_list[index] = data_output_offset; | output_list[index] = data_output_offset; | ||||
| src_node->GetOpDesc()->SetOutputOffset(output_list); | src_node->GetOpDesc()->SetOutputOffset(output_list); | ||||
| GELOGI( | |||||
| "[IMAS]ReAssignMergeMemory : Set %s name[%s] output[%d] offset to [%ld] stream_id[%ld] size[%ld] " | |||||
| "real_size[%ld].", | |||||
| n->GetOwnerComputeGraph()->GetName().c_str(), src_node->GetOpDesc()->GetName().c_str(), index, | |||||
| data_output_offset, src_node->GetOpDesc()->GetStreamId(), max_output_size, max_output_size); | |||||
| input_list.emplace_back(data_output_offset); | input_list.emplace_back(data_output_offset); | ||||
| } | } | ||||
| @@ -897,6 +916,9 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { | |||||
| } | } | ||||
| output_list[output_index] = memory_offset_[0].mem_offset_; | output_list[output_index] = memory_offset_[0].mem_offset_; | ||||
| GELOGI("[IMAS]Atomic output : Set %s name[%s] output[%ld] offset to [%zu] stream_id[%ld] size[%ld] real_size[%ld].", | |||||
| compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), output_index, memory_offset_[0].mem_offset_, | |||||
| op_desc->GetStreamId(), size, size); | |||||
| memory_offset_[0].mem_offset_ += size; | memory_offset_[0].mem_offset_ += size; | ||||
| AlignMemOffset(MEM_ALIGN_SIZE); | AlignMemOffset(MEM_ALIGN_SIZE); | ||||
| @@ -933,6 +955,11 @@ Status GraphMemoryAssigner::AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDesc | |||||
| } | } | ||||
| workspace_vector[workspace_index] = memory_offset_[0].mem_offset_; | workspace_vector[workspace_index] = memory_offset_[0].mem_offset_; | ||||
| GELOGI( | |||||
| "[IMAS]Atomic ordinary workspace : Set %s name[%s] workspace[%lu] offset to [%zu] stream_id[%ld] " | |||||
| "size[%ld] real_size[%ld].", | |||||
| compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), workspace_index, memory_offset_[0].mem_offset_, | |||||
| op_desc->GetStreamId(), workspace_size, workspace_size); | |||||
| memory_offset_[0].mem_offset_ += workspace_size; | memory_offset_[0].mem_offset_ += workspace_size; | ||||
| } | } | ||||
| @@ -958,6 +985,11 @@ Status GraphMemoryAssigner::AssignFusionAtomicWorkspaceMemory(const ge::OpDescPt | |||||
| auto workspace_size = info_iter.second; | auto workspace_size = info_iter.second; | ||||
| size_t workspace_offset = memory_offset_[0].mem_offset_; | size_t workspace_offset = memory_offset_[0].mem_offset_; | ||||
| GELOGI( | |||||
| "[IMAS]Atomic fusion workspace : Set %s name[%s] workspace[%lu] offset to [%zu] stream_id[%ld] size[%ld] " | |||||
| "real_size[%ld].", | |||||
| compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), workspace_index, memory_offset_[0].mem_offset_, | |||||
| op_desc->GetStreamId(), workspace_size, workspace_size); | |||||
| memory_offset_[0].mem_offset_ += workspace_size; | memory_offset_[0].mem_offset_ += workspace_size; | ||||
| index_offset.insert(std::make_pair(workspace_index, workspace_offset)); | index_offset.insert(std::make_pair(workspace_index, workspace_offset)); | ||||
| @@ -1005,7 +1037,8 @@ ge::Status GraphMemoryAssigner::SetInputOffset() { | |||||
| GELOGE(FAILED, "memory_offset_ is empty."); | GELOGE(FAILED, "memory_offset_ is empty."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GEEVENT("[IMAS]AfterAssignMemory : %s", compute_graph_->GetName().c_str()); | |||||
| GEEVENT("[IMAS]AfterAssignMemory : %s memoffset[%zu]", compute_graph_->GetName().c_str(), | |||||
| memory_offset_[0].mem_offset_); | |||||
| for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { | for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { | ||||
| if (UpdateOpInputOffset(node) != ge::SUCCESS) { | if (UpdateOpInputOffset(node) != ge::SUCCESS) { | ||||
| GELOGE(ge::FAILED, "Update op input offset failed"); | GELOGE(ge::FAILED, "Update op input offset failed"); | ||||
| @@ -1166,6 +1199,12 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, int64_t ato | |||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_MEM_SIZE, mem_size_vector), | GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_MEM_SIZE, mem_size_vector), | ||||
| GELOGE(FAILED, "SetListInt failed."); | GELOGE(FAILED, "SetListInt failed."); | ||||
| return FAILED); | return FAILED); | ||||
| GELOGI( | |||||
| "[IMAS]SetAtomicCleanAttr : Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%ld] " | |||||
| "realsize[%ld].", | |||||
| node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), 0, atomic_mem_start, | |||||
| node->GetOpDesc()->GetStreamId(), atomic_mem_size, atomic_mem_size); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -28,6 +28,7 @@ | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
| #include "graph/ge_context.h" | |||||
| #include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| @@ -39,7 +40,6 @@ | |||||
| #include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/ge_context.h" | |||||
| #include "init/gelib.h" | #include "init/gelib.h" | ||||
| #include "memory/memory_assigner.h" | #include "memory/memory_assigner.h" | ||||
| #include "omg/version.h" | #include "omg/version.h" | ||||
| @@ -78,15 +78,16 @@ bool IsGeLocalOp(const ge::ConstOpDescPtr &op_desc) { | |||||
| ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); | ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); | ||||
| return !(output_desc.GetDataType() == ge::DT_STRING); | return !(output_desc.GetDataType() == ge::DT_STRING); | ||||
| } | } | ||||
| const set<string> ge_local_set = { | |||||
| ge::STREAMMERGE, ge::MEMCPYASYNC, ge::STREAMACTIVE, ge::STREAMSWITCH, ge::VARIABLE, ge::NOOP, ge::CONSTANT, | |||||
| ge::ENTER, ge::REFENTER, ge::LOOPCOND, ge::NEXTITERATION, ge::REFNEXTITERATION, ge::EXIT, ge::REFEXIT}; | |||||
| const set<string> ge_local_set = {ge::STREAMMERGE, ge::MEMCPYASYNC, ge::STREAMACTIVE, ge::STREAMSWITCH, | |||||
| ge::VARIABLE, ge::NOOP, ge::CONSTANT, ge::ENTER, | |||||
| ge::REFENTER, ge::LOOPCOND, ge::NEXTITERATION, ge::REFNEXTITERATION, | |||||
| ge::EXIT, ge::REFEXIT, ge::MEMCPYADDRASYNC}; | |||||
| return (ge_local_set.find(type) != ge_local_set.end()); | return (ge_local_set.find(type) != ge_local_set.end()); | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const vector<SubGraphInfoPtr> &subgraphs, | |||||
| ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, | |||||
| const map<string, int> &stream_max_parallel_num, bool hcom_parallel, int mode) | const map<string, int> &stream_max_parallel_num, bool hcom_parallel, int mode) | ||||
| : mem_offset_(0), | : mem_offset_(0), | ||||
| weight_offset_(kWeightsStartOffset), | weight_offset_(kWeightsStartOffset), | ||||
| @@ -225,6 +226,25 @@ Status ModelBuilder::SetInputOutputDesc() { | |||||
| if (!is_loop_graph_ && node_op_desc->GetType() == LOOPCOND) { | if (!is_loop_graph_ && node_op_desc->GetType() == LOOPCOND) { | ||||
| is_loop_graph_ = true; | is_loop_graph_ = true; | ||||
| } | } | ||||
| // If the user sets the input node format to ND, the data and netoutput nodes are expected | |||||
| // to have format ND in the final graph. | |||||
| if ((domi::GetContext().format == domi::DOMI_TENSOR_ND) && | |||||
| ((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { | |||||
| GELOGI("The node [%s] format should be set ND.", node_op_desc->GetName().c_str()); | |||||
| auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); | |||||
| auto outputDescsPtr = node_op_desc->GetAllOutputsDescPtr(); | |||||
| ge::Format format = ge::FORMAT_ND; | |||||
| for (auto &inputDescPtr : inputDescsPtr) { | |||||
| GE_CHECK_NOTNULL(inputDescPtr); | |||||
| inputDescPtr->SetFormat(format); | |||||
| inputDescPtr->SetOriginFormat(format); | |||||
| } | |||||
| for (auto &outputDescPtr : outputDescsPtr) { | |||||
| GE_CHECK_NOTNULL(outputDescPtr); | |||||
| outputDescPtr->SetFormat(format); | |||||
| outputDescPtr->SetOriginFormat(format); | |||||
| } | |||||
| } | |||||
| if (node_op_desc->GetType() == DATA_TYPE || node_op_desc->GetType() == AIPP_DATA_TYPE) { | if (node_op_desc->GetType() == DATA_TYPE || node_op_desc->GetType() == AIPP_DATA_TYPE) { | ||||
| GELOGD("Data node: %s.", n->GetName().c_str()); | GELOGD("Data node: %s.", n->GetName().c_str()); | ||||
| @@ -37,7 +37,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class ModelBuilder { | class ModelBuilder { | ||||
| public: | public: | ||||
| ModelBuilder(ge::ComputeGraphPtr whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs, | |||||
| ModelBuilder(ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, | |||||
| const std::map<std::string, int> &stream_max_parallel_num, bool hcom_parallel, | const std::map<std::string, int> &stream_max_parallel_num, bool hcom_parallel, | ||||
| int mode = static_cast<int>(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); | int mode = static_cast<int>(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); | ||||
| @@ -85,7 +85,7 @@ class ModelBuilder { | |||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| const std::vector<SubGraphInfoPtr> &subgraphs_; | |||||
| const Graph2SubGraphInfoList &subgraphs_; | |||||
| int64_t stream_num_; | int64_t stream_num_; | ||||
| @@ -164,6 +164,9 @@ Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &gra | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GELOGI("CreateRunContext: data_mem_base_ = %p, weight_mem_base_ = %p, memory_size = %lu, weight_size = %lu", | |||||
| data_mem_base_, weight_mem_base_, data_mem_size_, weight_mem_size_); | |||||
| run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_, | run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_, | ||||
| weight_mem_base_, buffer, stream_list_, event_list_, label_list_}; | weight_mem_base_, buffer, stream_list_, event_list_, label_list_}; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -40,7 +40,7 @@ const uint32_t kMaxSwitchStreamNum = 1; | |||||
| namespace ge { | namespace ge { | ||||
| Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &max_parallel_num, bool hcom_parallel) { | Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &max_parallel_num, bool hcom_parallel) { | ||||
| GELOGI("AssignLogicalStreams start."); | |||||
| GELOGI("Assign logical streams start."); | |||||
| GE_CHECK_NOTNULL(whole_graph_); | GE_CHECK_NOTNULL(whole_graph_); | ||||
| GraphUtils::DumpGEGraph(whole_graph_, "BeforeAssignedLogicalStreams"); | GraphUtils::DumpGEGraph(whole_graph_, "BeforeAssignedLogicalStreams"); | ||||
| GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); | GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); | ||||
| @@ -52,7 +52,6 @@ Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &m | |||||
| } | } | ||||
| const map<string, SchedulerConf> &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); | const map<string, SchedulerConf> &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); | ||||
| LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num, hcom_parallel); | LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num, hcom_parallel); | ||||
| Status status = logical_allocator.Assign(whole_graph_, subgraphs_, stream_num_); | Status status = logical_allocator.Assign(whole_graph_, subgraphs_, stream_num_); | ||||
| if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
| @@ -62,7 +61,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &m | |||||
| GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); | GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); | ||||
| GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); | GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); | ||||
| GELOGI("AssignLogicalStreams success."); | |||||
| GELOGI("Assign logical streams success."); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -136,7 +135,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu | |||||
| GELOGI("None of nodes need to assign stream, stream num is 0, it will cause error, so change it to 1"); | GELOGI("None of nodes need to assign stream, stream num is 0, it will cause error, so change it to 1"); | ||||
| stream_num_ = 1; | stream_num_ = 1; | ||||
| } | } | ||||
| GELOGI("stream_num_: %ld, event_num_: %u.", stream_num_, event_num_); | |||||
| GELOGI("stream num: %ld, event num: %u.", stream_num_, event_num_); | |||||
| GELOGI("RefreshRealStream successfully."); | GELOGI("RefreshRealStream successfully."); | ||||
| stream_num = stream_num_; | stream_num = stream_num_; | ||||
| @@ -148,7 +147,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu | |||||
| // Split the stream according to the maximum number of nodes in the stream. | // Split the stream according to the maximum number of nodes in the stream. | ||||
| Status StreamAllocator::SplitStreams() { | Status StreamAllocator::SplitStreams() { | ||||
| if (stream_num_ == 0) { | if (stream_num_ == 0) { | ||||
| GELOGI("stream_num_ is 0"); | |||||
| GELOGI("The number of streams is 0 and no need to split."); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -30,7 +30,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class StreamAllocator { | class StreamAllocator { | ||||
| public: | public: | ||||
| StreamAllocator(ComputeGraphPtr whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs) | |||||
| StreamAllocator(ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs) | |||||
| : whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) {} | : whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) {} | ||||
| StreamAllocator(const StreamAllocator &) = delete; | StreamAllocator(const StreamAllocator &) = delete; | ||||
| StreamAllocator &operator=(const StreamAllocator &) = delete; | StreamAllocator &operator=(const StreamAllocator &) = delete; | ||||
| @@ -75,7 +75,7 @@ class StreamAllocator { | |||||
| bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; | bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; | ||||
| ComputeGraphPtr whole_graph_; | ComputeGraphPtr whole_graph_; | ||||
| const std::vector<SubGraphInfoPtr> &subgraphs_; | |||||
| const Graph2SubGraphInfoList &subgraphs_; | |||||
| int64_t stream_num_{0}; | int64_t stream_num_{0}; | ||||
| uint32_t event_num_{0}; | uint32_t event_num_{0}; | ||||
| @@ -29,19 +29,21 @@ static const int64_t kInvalidStream = -1; | |||||
| namespace ge { | namespace ge { | ||||
| StreamGraphOptimizer::~StreamGraphOptimizer() {} | StreamGraphOptimizer::~StreamGraphOptimizer() {} | ||||
| void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, vector<SubGraphInfoPtr> &subgraph_infos) { | |||||
| void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { | |||||
| size_t node_size = comp_graph->GetDirectNodesSize(); | size_t node_size = comp_graph->GetDirectNodesSize(); | ||||
| GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); | GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); | ||||
| for (const auto &sub_graph_info : subgraph_infos) { | |||||
| ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); | |||||
| if (sub_graph == nullptr) { | |||||
| continue; | |||||
| } | |||||
| for (ge::NodePtr &node : sub_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); | |||||
| if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { | |||||
| node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); | |||||
| node_size++; | |||||
| for (const auto &subgraph_pair : subgraph_map) { | |||||
| for (const auto &subgraph_info : subgraph_pair.second) { | |||||
| ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); | |||||
| if (subgraph == nullptr) { | |||||
| continue; | |||||
| } | |||||
| for (ge::NodePtr &node : subgraph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); | |||||
| if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { | |||||
| node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); | |||||
| node_size++; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -71,67 +73,71 @@ bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { | |||||
| } | } | ||||
| Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, | Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, | ||||
| vector<SubGraphInfoPtr> &subgraph_infos, | |||||
| Graph2SubGraphInfoList &subgraph_map, | |||||
| struct RunContext &run_context) { | struct RunContext &run_context) { | ||||
| Status ret = SUCCESS; | |||||
| GELOGI("Begin to Get optimize streamed subgraph."); | |||||
| GELOGI("Optimize streamed subgraph start."); | |||||
| RefreshNodeId(comp_graph, subgraph_infos); | |||||
| RefreshNodeId(comp_graph, subgraph_map); | |||||
| std::shared_ptr<GELib> instance = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance = ge::GELib::GetInstance(); | ||||
| GE_CHECK_NOTNULL(instance); | GE_CHECK_NOTNULL(instance); | ||||
| for (auto &sub_graph_info : subgraph_infos) { | |||||
| ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); | |||||
| if (sub_graph == nullptr) { | |||||
| continue; | |||||
| } | |||||
| for (const auto &subgraph_pair : subgraph_map) { | |||||
| for (const auto &subgraph_info : subgraph_pair.second) { | |||||
| ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); | |||||
| GE_CHECK_NOTNULL(subgraph); | |||||
| std::string engine_name = sub_graph_info->GetEngineName(); | |||||
| GELOGI("Optimize subgraph %s", subgraph->GetName().c_str()); | |||||
| vector<GraphOptimizerPtr> graph_optimizers; | |||||
| if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { | |||||
| instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); | |||||
| GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, subgraph num: %zu, graph Optimizer num: %zu.", | |||||
| sub_graph->GetName().c_str(), engine_name.c_str(), subgraph_infos.size(), graph_optimizers.size()); | |||||
| std::string engine_name = subgraph_info->GetEngineName(); | |||||
| auto nodes = sub_graph->GetDirectNode(); | |||||
| if (nodes.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (!IsSameStreamId(sub_graph)) { | |||||
| GELOGI("There are more than one stream in subgraph %s", sub_graph->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| int64_t stream_id = op_desc->GetStreamId(); | |||||
| if (static_cast<size_t>(stream_id) >= run_context.graphStreamList.size()) { | |||||
| GELOGE(FAILED, "stream_id is bigger than run_context.graphStreamList.size()"); | |||||
| return FAILED; | |||||
| } | |||||
| run_context.stream = run_context.graphStreamList[stream_id]; | |||||
| GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", | |||||
| sub_graph->GetName().c_str(), engine_name.c_str(), stream_id, | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream))); | |||||
| for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { | |||||
| GE_CHECK_NOTNULL(*iter); | |||||
| ret = (*iter)->OptimizeStreamGraph(*sub_graph, run_context); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, | |||||
| "[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " | |||||
| "Optimizer num: %zu, ret: %u", | |||||
| sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); | |||||
| return ret; | |||||
| vector<GraphOptimizerPtr> graph_optimizers; | |||||
| if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { | |||||
| instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); | |||||
| GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, graph Optimizer num: %zu.", | |||||
| subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); | |||||
| auto nodes = subgraph->GetDirectNode(); | |||||
| if (nodes.empty()) { | |||||
| continue; | |||||
| } | |||||
| if (!IsSameStreamId(subgraph)) { | |||||
| GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| int64_t stream_id = op_desc->GetStreamId(); | |||||
| if (static_cast<size_t>(stream_id) >= run_context.graphStreamList.size()) { | |||||
| GELOGE(FAILED, "stream_id %ld is bigger than run_context.graphStreamList.size() %zu", stream_id, | |||||
| run_context.graphStreamList.size()); | |||||
| return FAILED; | |||||
| } | |||||
| run_context.stream = run_context.graphStreamList[stream_id]; | |||||
| GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", | |||||
| subgraph->GetName().c_str(), engine_name.c_str(), stream_id, | |||||
| static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream))); | |||||
| for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { | |||||
| GE_CHECK_NOTNULL(*iter); | |||||
| Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE( | |||||
| ret, | |||||
| "[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " | |||||
| "Optimizer num: %zu, ret: %u", | |||||
| subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGI( | |||||
| "[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " | |||||
| "Optimizer num: %zu!", | |||||
| subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); | |||||
| } | } | ||||
| GELOGI( | |||||
| "[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " | |||||
| "Optimizer num: %zu!", | |||||
| sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| return ret; | |||||
| GELOGI("Optimize streamed subgraph success."); | |||||
| return SUCCESS; | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
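Editor's note: OptimizeStreamedSubGraph hands a subgraph to its engine's optimizers only when IsSameStreamId reports a single stream. The body of IsSameStreamId is outside these hunks; the check implied by its call sites can be sketched as follows (an assumption — the real implementation may also special-case invalid stream ids):

#include <cstdint>
#include <set>

// Sketch: true when every direct node of the graph carries the same stream id.
bool AllNodesOnOneStream(const ge::ComputeGraphPtr &graph) {
  std::set<int64_t> ids;
  for (const ge::NodePtr &node : graph->GetDirectNode()) {
    auto op_desc = node->GetOpDesc();
    if (op_desc != nullptr) {
      ids.insert(op_desc->GetStreamId());
    }
  }
  return ids.size() == 1;
}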
| @@ -35,11 +35,11 @@ class StreamGraphOptimizer { | |||||
| virtual ~StreamGraphOptimizer(); | virtual ~StreamGraphOptimizer(); | ||||
| Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
| Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, | |||||
| struct RunContext &run_context); | struct RunContext &run_context); | ||||
| private: | private: | ||||
| void RefreshNodeId(const ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list); | |||||
| void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); | |||||
| bool IsSameStreamId(const ComputeGraphPtr &comp_graph); | bool IsSameStreamId(const ComputeGraphPtr &comp_graph); | ||||
| }; | }; | ||||
| @@ -221,10 +221,8 @@ Status TaskGenerator::SaveL1fusionNodes(map<int64_t, std::vector<NodePtr>> &l1_f | |||||
| if (call_check) { | if (call_check) { | ||||
| auto input_group_id = *input_group_ids.begin(); | auto input_group_id = *input_group_ids.begin(); | ||||
| if (group_id != input_group_id) { | if (group_id != input_group_id) { | ||||
| GELOGE(INTERNAL_ERROR, | |||||
| "L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", | |||||
| GELOGW("L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", | |||||
| name.c_str(), type.c_str(), group_id, input_group_id); | name.c_str(), type.c_str(), group_id, input_group_id); | ||||
| return INTERNAL_ERROR; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -172,7 +172,7 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st | |||||
| GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
| (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
| NodePtr label_set = graph->AddNodeFront(op_desc); | |||||
| NodePtr label_set = graph->AddNode(op_desc); | |||||
| GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | ||||
| // Link control edge to graph tail. | // Link control edge to graph tail. | ||||
| @@ -202,7 +202,7 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTO); | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdEnter(graph, op_desc); | SetStreamIdEnter(graph, op_desc); | ||||
| @@ -238,7 +238,7 @@ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::s | |||||
| const NodePtr &node = *it; | const NodePtr &node = *it; | ||||
| GE_CHECK_NOTNULL_EXEC(node, return nullptr); | GE_CHECK_NOTNULL_EXEC(node, return nullptr); | ||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTO); | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| SetStreamIdLeave(graph, op_desc); | SetStreamIdLeave(graph, op_desc); | ||||
| @@ -366,6 +366,7 @@ NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std: | |||||
| OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | ||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
| op_desc->SetStreamId(kInvalidStreamId); | |||||
| GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | ||||
| if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) { | ||||
| @@ -20,11 +20,11 @@ | |||||
| namespace { | namespace { | ||||
| const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch | const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch | ||||
| const char *const kCpuTaskModelEnqueue = "modelEnqueue"; | const char *const kCpuTaskModelEnqueue = "modelEnqueue"; | ||||
| const char *const kCpuTaskPrepareInput = "modelPrepareInput"; | |||||
| const char *const kCpuTaskWaitEndGraph = "modelWaitEndGraph"; | const char *const kCpuTaskWaitEndGraph = "modelWaitEndGraph"; | ||||
| const char *const kCpuTaskPrepareOutput = "modelPrepareOutput"; | |||||
| const char *const kCpuTaskPrepareOutput = "bufferPrepareOutput"; | |||||
| const char *const kCpuTaskModelDequeue = "modelDequeue"; | const char *const kCpuTaskModelDequeue = "modelDequeue"; | ||||
| const char *const kCpuTaskModelRepeat = "modelRepeat"; | const char *const kCpuTaskModelRepeat = "modelRepeat"; | ||||
| const char *const kCpuTaskZeroCopy = "zeroCpy"; | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| @@ -93,19 +93,19 @@ Status CpuTaskModelDequeue::Distribute() { | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief definiteness queue schedule, bind output queue to task. | |||||
| /// @param [in] addr: NetOutput Op input tensor address. | |||||
| /// @param [in] size: NetOutput Op input tensor size. | |||||
| /// @param [in] in_mbuf: input mbuf addr for input data. | |||||
| /// @brief definiteness queue schedule, zero copy. | |||||
| /// @param [in] mbuf_list: input/output mbuf addr list for input/output data. | |||||
| /// @param [in] outside_addrs: model input/output memory addr | |||||
| /// @return: 0 for success / others for failed | /// @return: 0 for success / others for failed | ||||
| /// | /// | ||||
| Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf) { | |||||
| Status CpuTaskZeroCopy::Init(std::vector<uintptr_t> &mbuf_list, | |||||
| std::map<const void *, std::vector<void *>> &outside_addrs) { | |||||
| if ((args_ != nullptr) || (args_size_ > 0)) { | if ((args_ != nullptr) || (args_size_ > 0)) { | ||||
| GELOGE(FAILED, "Task already initialized, size: %u", args_size_); | GELOGE(FAILED, "Task already initialized, size: %u", args_size_); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| args_size_ = sizeof(PrepareInputInfo); | |||||
| args_size_ = sizeof(AddrMapInfo); | |||||
| rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | ||||
| if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | ||||
| @@ -113,36 +113,99 @@ Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbu | |||||
| } | } | ||||
| GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) | GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) | ||||
| PrepareInputInfo prepare; | |||||
| prepare.in_mbuf = in_mbuf; | |||||
| prepare.mbuf_offset = 0; | |||||
| prepare.data_size = size; | |||||
| prepare.data_addr = addr; | |||||
| status = rtMemcpy(args_, args_size_, &prepare, args_size_, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| AddrMapInfo addr_map_info; | |||||
| for (const auto &addrs : outside_addrs) { | |||||
| addr_map_info.addr_num += static_cast<uint32_t>(addrs.second.size()); | |||||
| } | |||||
| GELOGI("addr_map_info.addr_num is %zu", addr_map_info.addr_num); | |||||
| // init src_addrs/dst_addrs | |||||
| size_t index = 0; | |||||
| vector<uint64_t> src_addrs; | |||||
| vector<uint64_t> dst_addrs; | |||||
| for (const auto &addrs : outside_addrs) { | |||||
| for (size_t i = 0; i < addrs.second.size(); ++i) { | |||||
| src_addrs.push_back(mbuf_list.at(index)); | |||||
| dst_addrs.push_back(reinterpret_cast<uint64_t>(addrs.second.at(i))); | |||||
| } | |||||
| index++; | |||||
| } | |||||
| // malloc mem for src_addrs/dst_addrs, and copy data of src_addrs/dst_addrs | |||||
| status = rtMalloc(&src_addr_, src_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
| return RT_FAILED; | |||||
| } | |||||
| status = rtMemcpy(src_addr_, src_addrs.size() * sizeof(uint64_t), src_addrs.data(), | |||||
| src_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | ||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| status = rtMalloc(&dst_addr_, dst_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
| return RT_FAILED; | |||||
| } | |||||
| status = rtMemcpy(dst_addr_, dst_addrs.size() * sizeof(uint64_t), dst_addrs.data(), | |||||
| dst_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
| return RT_FAILED; | |||||
| } | |||||
| // src_addr_list is initialized to src_addr_, which points to the device-side copy of src_addrs | |||||
| if (!src_addrs.empty() && !dst_addrs.empty()) { | |||||
| addr_map_info.src_addr_list = reinterpret_cast<uint64_t>(src_addr_); | |||||
| addr_map_info.dst_addr_list = reinterpret_cast<uint64_t>(dst_addr_); | |||||
| GELOGI("src_addr_list is %lu, dst_addr_list is %lu", addr_map_info.src_addr_list, addr_map_info.dst_addr_list); | |||||
| } | |||||
| status = rtMemcpy(args_, args_size_, &addr_map_info, sizeof(AddrMapInfo), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
| return RT_FAILED; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
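Note the pairing contract in the loop above: the mbuf index advances once per outside_addrs entry while the same mbuf is fanned out to every device slot recorded for that entry, so mbuf_list is assumed to hold exactly one element per map entry, in iteration order. A self-contained sketch of the flattening (all addresses made up; only the shape matters):

#include <cstdint>
#include <map>
#include <vector>

int main() {
  int key0 = 0, key1 = 0, slot_a = 0, slot_b = 0, slot_c = 0;
  std::vector<uintptr_t> mbuf_list = {0x1000, 0x2000};  // one mbuf per model input/output
  std::map<const void *, std::vector<void *>> outside_addrs = {
      {&key0, {&slot_a, &slot_b}},  // tensor referenced by two task-arg slots
      {&key1, {&slot_c}}};          // tensor referenced by one slot
  std::vector<uint64_t> src_addrs, dst_addrs;
  size_t index = 0;
  for (const auto &addrs : outside_addrs) {
    for (void *slot : addrs.second) {
      src_addrs.push_back(mbuf_list.at(index));  // same mbuf repeated for every slot
      dst_addrs.push_back(reinterpret_cast<uint64_t>(slot));
    }
    ++index;
  }
  // If the map iterates {key0, key1}: src_addrs = {0x1000, 0x1000, 0x2000} and
  // dst_addrs = {&slot_a, &slot_b, &slot_c} as integers.
  return 0;
}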
| Status CpuTaskPrepareInput::Distribute() { | |||||
| Status CpuTaskZeroCopy::Distribute() { | |||||
| if ((args_ == nullptr) || (args_size_ == 0) || (stream_ == nullptr)) { | if ((args_ == nullptr) || (args_size_ == 0) || (stream_ == nullptr)) { | ||||
| GELOGE(FAILED, "Task not initialized, distribute failed, size: %u", args_size_); | GELOGE(FAILED, "Task not initialized, distribute failed, size: %u", args_size_); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskPrepareInput, kCoreDim, args_, args_size_, nullptr, stream_); | |||||
| rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskZeroCopy, kCoreDim, args_, args_size_, nullptr, stream_); | |||||
| if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rt CpuKernelLaunch PrepareInput failed, status: 0x%X", status); | |||||
| GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ZeroCopy failed, status: 0x%X", status); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| GELOGI("Cpu kernel launch prepare input task success."); | |||||
| GELOGI("Cpu kernel launch zero copy task success."); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| CpuTaskZeroCopy::~CpuTaskZeroCopy() { | |||||
| if (src_addr_ == nullptr && dst_addr_ == nullptr) { | |||||
| return; | |||||
| } | |||||
| if (src_addr_ != nullptr) { | |||||
| rtError_t status = rtFree(src_addr_); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGW("Call rt free failed, status: 0x%x", status); | |||||
| } | |||||
| } | |||||
| if (dst_addr_ != nullptr) { | |||||
| rtError_t status = rtFree(dst_addr_); | |||||
| if (status != RT_ERROR_NONE) { | |||||
| GELOGW("Call rt free failed, status: 0x%x", status); | |||||
| } | |||||
| } | |||||
| src_addr_ = nullptr; | |||||
| dst_addr_ = nullptr; | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief definiteness queue schedule, bind output queue to task. | /// @brief definiteness queue schedule, bind output queue to task. | ||||
| @@ -47,6 +47,13 @@ struct PrepareOutputInfo { | |||||
| uintptr_t out_mbuf; // output mbuf addr | uintptr_t out_mbuf; // output mbuf addr | ||||
| }; | }; | ||||
| // For AICPU task "modelZeroCopy" | |||||
| struct AddrMapInfo { | |||||
| uint32_t addr_num = 0; | |||||
| uint64_t src_addr_list; | |||||
| uint64_t dst_addr_list; | |||||
| }; | |||||
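Since Init copies this struct to the device byte-for-byte, the AICPU-side reader has to agree on its padding: on common LP64 ABIs the uint32_t is followed by four padding bytes before src_addr_list. A host-side sanity check of that assumption (the asserted values reflect typical ABIs, not a guarantee from this code):

#include <cstddef>
#include <cstdint>

struct AddrMapInfo {  // mirrors the struct above
  uint32_t addr_num = 0;
  uint64_t src_addr_list;
  uint64_t dst_addr_list;
};
static_assert(offsetof(AddrMapInfo, src_addr_list) == 8, "4 padding bytes expected after addr_num");
static_assert(sizeof(AddrMapInfo) == 24, "host and AICPU views of AddrMapInfo must match");
int main() { return 0; }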
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief CpuTask base, inherit from TaskInfo used for manage. | /// @brief CpuTask base, inherit from TaskInfo used for manage. | ||||
| @@ -78,17 +85,21 @@ class CpuTaskModelDequeue : public CpuTaskInfo { | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief definiteness queue schedule, bind output queue to task. | |||||
| /// @brief definiteness queue schedule, zero copy. | |||||
| /// | /// | ||||
| class CpuTaskPrepareInput : public CpuTaskInfo { | |||||
| class CpuTaskZeroCopy : public CpuTaskInfo { | |||||
| public: | public: | ||||
| explicit CpuTaskPrepareInput(rtStream_t stream) : CpuTaskInfo(stream) {} | |||||
| ~CpuTaskPrepareInput() override {} | |||||
| explicit CpuTaskZeroCopy(rtStream_t stream) : CpuTaskInfo(stream) {} | |||||
| ~CpuTaskZeroCopy() override; | |||||
| Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } | Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } | ||||
| Status Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf); | |||||
| Status Init(std::vector<uintptr_t> &mbuf_list, std::map<const void *, std::vector<void *>> &outside_addrs); | |||||
| Status Distribute() override; | Status Distribute() override; | ||||
| private: | |||||
| void *src_addr_ = nullptr; | |||||
| void *dst_addr_ = nullptr; | |||||
| }; | }; | ||||
| /// | /// | ||||
| @@ -340,13 +340,6 @@ class DavinciModel { | |||||
| vector<InputOutputDescInfo> &output_desc, | vector<InputOutputDescInfo> &output_desc, | ||||
| std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats); | std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats); | ||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief copy input data to model | |||||
| /// @return Status | |||||
| /// | |||||
| Status CopyInputDataToModel(const std::vector<DataBuffer> &data, uint32_t data_op_index, bool device_data); | |||||
| Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data); | Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data); | ||||
| Status ReturnNoOutput(uint32_t data_id); | Status ReturnNoOutput(uint32_t data_id); | ||||
| @@ -413,20 +406,6 @@ class DavinciModel { | |||||
| /// | /// | ||||
| uint32_t GetDeviceId() const { return device_id_; } | uint32_t GetDeviceId() const { return device_id_; } | ||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief Set Train Mode | |||||
| /// @return void | |||||
| /// | |||||
| void SetTrainMode(bool mode) { is_train_mode_ = mode; } | |||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief Get Train Mode | |||||
| /// @return bool true | |||||
| /// | |||||
| bool GetTrainMode() { return is_train_mode_; } | |||||
| GeModelPtr GetGeModel() { return ge_model_; } | GeModelPtr GetGeModel() { return ge_model_; } | ||||
| const RuntimeParam &GetRuntimeParam() { return runtime_param_; } | const RuntimeParam &GetRuntimeParam() { return runtime_param_; } | ||||
| @@ -519,15 +498,14 @@ class DavinciModel { | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Copy Data addr to model for direct use. | /// @brief Copy Data addr to model for direct use. | ||||
| /// @param [in] const vector<void *> &addrs: model input memory addr list. | |||||
| /// @param [in] const vector<uint32_t> &sizes: model input memory size list. | |||||
| /// @param [in] const std::map<uint32_t, std::pair<int64_t, void *>> &data_info: model memory addr/size list. | |||||
| /// @param [in] const std::vector<DataBuffer> &blobs: user input data list. | /// @param [in] const std::vector<DataBuffer> &blobs: user input data list. | ||||
| /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input | /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input | ||||
| /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy | /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy | ||||
| /// @param [in] string batch_label: batch label for multi-batch scenes | /// @param [in] string batch_label: batch label for multi-batch scenes | ||||
| /// @return SUCCESS handle successfully / others handle failed | /// @return SUCCESS handle successfully / others handle failed | ||||
| /// | /// | ||||
| Status ZeroCopyBlobs(const std::vector<void *> &addr_list, const std::vector<int64_t> &size_list, | |||||
| Status ZeroCopyBlobs(const std::map<uint32_t, std::pair<int64_t, void *>> &data_info, | |||||
| const std::vector<DataBuffer> &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, | const std::vector<DataBuffer> &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, | ||||
| string batch_label); | string batch_label); | ||||
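The map that replaces the old parallel addr/size vectors keys each Data or NetOutput index to its {size, address} pair. A minimal sketch of the shape (buffers and indices are illustrative):

#include <cstdint>
#include <map>
#include <utility>

int main() {
  char buf0[16], buf1[32];  // stand-ins for device memory
  std::map<uint32_t, std::pair<int64_t, void *>> data_info;
  data_info[0] = {static_cast<int64_t>(sizeof(buf0)), buf0};  // index 0 -> {size, addr}
  data_info[1] = {static_cast<int64_t>(sizeof(buf1)), buf1};
  for (const auto &item : data_info) {
    const int64_t expect_size = item.second.first;  // checked against user DataBuffer.length
    void *const dst_addr = item.second.second;      // destination candidate for zero copy
    (void)expect_size;
    (void)dst_addr;
  }
  return 0;
}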
| @@ -610,11 +588,9 @@ class DavinciModel { | |||||
| /// @brief Data Op Initialize. | /// @brief Data Op Initialize. | ||||
| /// @param [in] NodePtr: Data Op. | /// @param [in] NodePtr: Data Op. | ||||
| /// @param [in/out] data_op_index: NetOutput addr size info. | /// @param [in/out] data_op_index: NetOutput addr size info. | ||||
| /// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, | |||||
| std::map<uint32_t, std::pair<int64_t, void *>> &input_data_info); | |||||
| Status InitDataOp(const NodePtr &node, uint32_t &data_op_index); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| @@ -632,20 +608,28 @@ class DavinciModel { | |||||
| /// | /// | ||||
| Status InitNetOutput(const OpDescPtr &op_desc); | Status InitNetOutput(const OpDescPtr &op_desc); | ||||
| /// | |||||
| /// @ingroup ge | |||||
| /// @brief Make Input and Output addr for feature use. | |||||
| /// @param [in] input_data_info: Data index and addr info {index, {size, addr}}. | |||||
| /// @return Status | |||||
| /// | |||||
| Status CombineDataInfo(const std::map<uint32_t, std::pair<int64_t, void *>> &input_data_info); | |||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| /// @brief Constant Op Init. | /// @brief Constant Op Init. | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status InitConstant(const ConstOpDescPtr &op_desc) const; | |||||
| Status InitConstant(const OpDescPtr &op_desc); | |||||
| Status InitVariable(const OpDescPtr &op_desc); | |||||
| Status InitEndGraph(const OpDescPtr &op_desc); | |||||
| /// @ingroup ge | |||||
| /// @brief LabelSet Op Initialize. | |||||
| /// @param [in] op_desc: LabelSet Op descriptor. | |||||
| /// @return Status | |||||
| Status InitLabelSet(const OpDescPtr &op_desc); | |||||
| Status InitStreamSwitch(const OpDescPtr &op_desc); | |||||
| Status InitStreamActive(const OpDescPtr &op_desc); | |||||
| Status InitStreamSwitchN(const OpDescPtr &op_desc); | |||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| @@ -662,7 +646,7 @@ class DavinciModel { | |||||
| /// @brief Init model stream for NN model. | /// @brief Init model stream for NN model. | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status InitModelStream(rtStream_t stream, bool async_mode); | |||||
| Status InitModelStream(rtStream_t stream); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| @@ -678,12 +662,16 @@ class DavinciModel { | |||||
| /// | /// | ||||
| Status BindInputQueue(); | Status BindInputQueue(); | ||||
| Status CpuTaskModelZeroCopy(std::vector<uintptr_t> &mbuf_list, | |||||
| std::map<const void *, std::vector<void *>> &outside_addrs); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief ACL, Bind NetOutput Op addr to output queue. | /// @brief ACL, Bind NetOutput Op addr to output queue. | ||||
| /// @return: 0 for success / others for fail | /// @return: 0 for success / others for fail | ||||
| /// | /// | ||||
| Status BindOutputQueue(); | Status BindOutputQueue(); | ||||
| Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| @@ -692,13 +680,6 @@ class DavinciModel { | |||||
| /// | /// | ||||
| Status BindActiveStream(); | Status BindActiveStream(); | ||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief insert active_stream_indication_ | |||||
| /// @return Status | |||||
| /// | |||||
| Status MarkActiveStream(const OpDescPtr &op_desc); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief definiteness queue schedule, bind input queue to task. | /// @brief definiteness queue schedule, bind input queue to task. | ||||
| @@ -707,7 +688,7 @@ class DavinciModel { | |||||
| /// @param [in] size: Data Op output tensor size. | /// @param [in] size: Data Op output tensor size. | ||||
| /// @return: 0 for success / others for fail | /// @return: 0 for success / others for fail | ||||
| /// | /// | ||||
| Status CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t size); | |||||
| Status CpuModelDequeue(uint32_t queue_id); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| @@ -734,6 +715,8 @@ class DavinciModel { | |||||
| /// | /// | ||||
| Status CpuWaitEndGraph(); | Status CpuWaitEndGraph(); | ||||
| Status BindEnqueue(); | |||||
| Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief definiteness queue schedule, repeat run model. | /// @brief definiteness queue schedule, repeat run model. | ||||
| @@ -783,10 +766,8 @@ class DavinciModel { | |||||
| vector<OpDescPtr> variable_op_list_; | vector<OpDescPtr> variable_op_list_; | ||||
| vector<int64_t> output_size_list_; // Init by NetOutput Input Tensor | |||||
| vector<void *> output_addr_list_; // Init by NetOutput Input Tensor | |||||
| vector<int64_t> input_size_list_; // Init by Data Output Tensor | |||||
| vector<void *> input_addr_list_; // Init by Data Output Tensor | |||||
| std::map<uint32_t, std::pair<int64_t, void *>> input_data_info_; // Init by Data Output Tensor | |||||
| std::map<uint32_t, std::pair<int64_t, void *>> output_data_info_; // Init by NetOutput Input Tensor | |||||
| // output op: save cce op actual needed memory size | // output op: save cce op actual needed memory size | ||||
| vector<int64_t> output_memory_size_list_; | vector<int64_t> output_memory_size_list_; | ||||
| @@ -813,6 +794,7 @@ class DavinciModel { | |||||
| vector<rtEvent_t> event_list_; | vector<rtEvent_t> event_list_; | ||||
| vector<rtLabel_t> label_list_; | vector<rtLabel_t> label_list_; | ||||
| set<uint32_t> label_id_indication_; | |||||
| std::mutex outside_addrs_mutex_; | std::mutex outside_addrs_mutex_; | ||||
| std::map<const void *, std::vector<void *>> input_outside_addrs_; | std::map<const void *, std::vector<void *>> input_outside_addrs_; | ||||
| @@ -830,6 +812,8 @@ class DavinciModel { | |||||
| bool is_inner_model_stream_; | bool is_inner_model_stream_; | ||||
| bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. | |||||
| // ACL queue schedule, save queue ids for Init. | // ACL queue schedule, save queue ids for Init. | ||||
| std::vector<TaskInfoPtr> cpu_task_list_; | std::vector<TaskInfoPtr> cpu_task_list_; | ||||
| std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller. | std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller. | ||||
| @@ -847,8 +831,6 @@ class DavinciModel { | |||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| bool is_train_mode_; | |||||
| std::mutex flowctrl_op_index_internal_map_mutex_; | std::mutex flowctrl_op_index_internal_map_mutex_; | ||||
| std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_; | std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_; | ||||
| std::set<uint32_t> active_stream_indication_; | std::set<uint32_t> active_stream_indication_; | ||||
| @@ -358,26 +358,17 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<Tensor | |||||
| input_data.timestamp = 0; | input_data.timestamp = 0; | ||||
| input_data.index = 0; | input_data.index = 0; | ||||
| std::size_t index = 0; | |||||
| for (const auto &op : model->GetDataList()) { | |||||
| GE_CHECK_NOTNULL(op); | |||||
| GE_CHECK_GE(inputs.size(), 1); | |||||
| GE_CHECK_GE(inputs.size() - 1, index); | |||||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||||
| DataBuffer data; | DataBuffer data; | ||||
| data.data = inputs[index].data.data; | |||||
| data.length = inputs[index].data.length; | |||||
| data.data = inputs[i].data.data; | |||||
| data.length = inputs[i].data.length; | |||||
| input_data.blobs.push_back(data); | input_data.blobs.push_back(data); | ||||
| index++; | |||||
| } | } | ||||
| CHECK_FALSE_EXEC(input_data.blobs.size() >= inputs.size(), | |||||
| GELOGW("cur_inputs size = %zu, inputs size = %zu.", input_data.blobs.size(), inputs.size());); | |||||
| OutputData output_data; | OutputData output_data; | ||||
| output_data.model_id = model_id; | output_data.model_id = model_id; | ||||
| output_data.index = 0; | output_data.index = 0; | ||||
| for (size_t i = 0; i < outputs.size(); i++) { | |||||
| for (size_t i = 0; i < outputs.size(); ++i) { | |||||
| DataBuffer data; | DataBuffer data; | ||||
| data.data = outputs[i].data.data; | data.data = outputs[i].data.data; | ||||
| data.length = outputs[i].data.length; | data.length = outputs[i].data.length; | ||||
| @@ -675,6 +666,15 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| break; | break; | ||||
| } | } | ||||
| davinci_model->SetId(model_id); | davinci_model->SetId(model_id); | ||||
| int32_t device_id = 0; | |||||
| rtError_t rt_ret = rtGetDevice(&device_id); | |||||
| if (rt_ret != RT_ERROR_NONE || device_id < 0) { | |||||
| GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); | |||||
| return FAILED; | |||||
| } | |||||
| davinci_model->SetDeviceId(device_id); | |||||
| ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); | ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); | ||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "DavinciInit failed."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "DavinciInit failed."); | ||||
| @@ -51,27 +51,6 @@ bool ModelUtils::IsOutput(ConstOpDescPtr op_desc) { | |||||
| return false; | return false; | ||||
| } | } | ||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief Check is the Input need trans code. | |||||
| /// @return bool | |||||
| /// | |||||
| bool ModelUtils::IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index) { | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return false); | |||||
| const auto &input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(tensor_index)); | |||||
| const auto &output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(tensor_index)); | |||||
| GE_CHECK_NOTNULL_EXEC(input_desc, return false); | |||||
| GE_CHECK_NOTNULL_EXEC(output_desc, return false); | |||||
| if ((output_desc->GetFormat() == FORMAT_NC1HWC0) && (output_desc->GetDataType() == DT_INT8)) { | |||||
| // AIPP input, add attribute in data op to tag aipp | |||||
| return false; | |||||
| } | |||||
| return (input_desc->GetFormat() != output_desc->GetFormat()) || | |||||
| (input_desc->GetDataType() != output_desc->GetDataType()); | |||||
| } | |||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| /// @brief Get input size. | /// @brief Get input size. | ||||
| @@ -398,6 +377,8 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
| GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); | GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); | ||||
| uint8_t *weight_addr = static_cast<uint8_t *>(weight_base + data_offset - logic_weight_base); | uint8_t *weight_addr = static_cast<uint8_t *>(weight_base + data_offset - logic_weight_base); | ||||
| v_input_data_addr.push_back(weight_addr); | v_input_data_addr.push_back(weight_addr); | ||||
| GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, | |||||
| op_desc->GetName().c_str(), i, weight_addr); | |||||
| }); | }); | ||||
| non_const_index++; | non_const_index++; | ||||
| continue; | continue; | ||||
| @@ -411,7 +392,10 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
| non_const_index++; | non_const_index++; | ||||
| GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), | GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), | ||||
| uint8_t *variable_addr = var_base + input_offset - logic_var_base; | uint8_t *variable_addr = var_base + input_offset - logic_var_base; | ||||
| v_input_data_addr.push_back(variable_addr); continue;); | |||||
| v_input_data_addr.push_back(variable_addr); | |||||
| GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]", | |||||
| model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); | |||||
| continue;); | |||||
| bool input_tensor = false; | bool input_tensor = false; | ||||
| GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, | GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, | ||||
| @@ -421,12 +405,14 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
| uint8_t *mem_addr = nullptr; | uint8_t *mem_addr = nullptr; | ||||
| // l1 fusion | // l1 fusion | ||||
| if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | ||||
| mem_addr = reinterpret_cast<uint8_t *>(input_offset); | |||||
| mem_addr = reinterpret_cast<uint8_t *>(reinterpret_cast<intptr_t>(input_offset)); | |||||
| v_input_data_addr.push_back(mem_addr); | v_input_data_addr.push_back(mem_addr); | ||||
| } else { | } else { | ||||
| mem_addr = static_cast<uint8_t *>(mem_base + input_offset - logic_mem_base); | mem_addr = static_cast<uint8_t *>(mem_base + input_offset - logic_mem_base); | ||||
| v_input_data_addr.push_back(mem_addr); | v_input_data_addr.push_back(mem_addr); | ||||
| } | } | ||||
| GELOGI("[IMAS]GetInputDataAddrs graph_%u type[F] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, | |||||
| op_desc->GetName().c_str(), i, mem_addr); | |||||
| } | } | ||||
| return v_input_data_addr; | return v_input_data_addr; | ||||
| @@ -487,12 +473,14 @@ vector<void *> ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C | |||||
| uint8_t *mem_addr = nullptr; | uint8_t *mem_addr = nullptr; | ||||
| // l1 fusion | // l1 fusion | ||||
| if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | ||||
| mem_addr = reinterpret_cast<uint8_t *>(v_output_offset[i]); | |||||
| mem_addr = reinterpret_cast<uint8_t *>(reinterpret_cast<intptr_t>(v_output_offset[i])); | |||||
| v_output_data_addr.push_back(mem_addr); | v_output_data_addr.push_back(mem_addr); | ||||
| } else { | } else { | ||||
| mem_addr = static_cast<uint8_t *>(mem_base + v_output_offset[i] - logic_mem_base); | mem_addr = static_cast<uint8_t *>(mem_base + v_output_offset[i] - logic_mem_base); | ||||
| v_output_data_addr.push_back(mem_addr); | v_output_data_addr.push_back(mem_addr); | ||||
| } | } | ||||
| GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[F] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, | |||||
| op_desc->GetName().c_str(), i, mem_addr); | |||||
| } | } | ||||
| return v_output_data_addr; | return v_output_data_addr; | ||||
| } | } | ||||
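The [IMAS] log lines added above tag each address with the branch that produced it; reading the emitting code gives this key:

// Key to the [IMAS] tags (inferred from the emitting branches):
//   type[C] - input resolved from a Const/weight address (weight_addr path)
//   type[V] - input resolved from a variable address (variable_addr path)
//   type[F] - input/output resolved from feature-map memory (mem_addr path)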
| @@ -530,7 +518,7 @@ vector<void *> ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param | |||||
| if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | ||||
| v_workspace_data_addr.push_back(reinterpret_cast<uint8_t *>(v_workspace_offset[i])); | v_workspace_data_addr.push_back(reinterpret_cast<uint8_t *>(v_workspace_offset[i])); | ||||
| GELOGI("L1Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, | GELOGI("L1Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, | ||||
| reinterpret_cast<uint8_t *>(v_workspace_offset[i])); | |||||
| reinterpret_cast<uint8_t *>(reinterpret_cast<intptr_t>(v_workspace_offset[i]))); | |||||
| } else { | } else { | ||||
| int64_t workspace_offset = v_workspace_offset[i]; | int64_t workspace_offset = v_workspace_offset[i]; | ||||
| int64_t workspace_bytes = v_workspace_bytes[i]; | int64_t workspace_bytes = v_workspace_bytes[i]; | ||||
| @@ -558,6 +546,7 @@ Status ModelUtils::ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uin | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| GELOGD("virtual_address=%p, physical_address=%p", virtual_address, physical_address); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -40,13 +40,6 @@ class ModelUtils { | |||||
| /// | /// | ||||
| static bool IsOutput(ConstOpDescPtr op_desc); | static bool IsOutput(ConstOpDescPtr op_desc); | ||||
| /// | |||||
| /// @ingroup domi_ome | |||||
| /// @brief Check is the Input need trans code. | |||||
| /// @return bool | |||||
| /// | |||||
| static bool IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index); | |||||
| /// | /// | ||||
| /// @ingroup domi_ome | /// @ingroup domi_ome | ||||
| /// @brief Get input size. | /// @brief Get input size. | ||||
| @@ -38,6 +38,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| } | } | ||||
| model_ = davinci_model->GetRtModelHandle(); | model_ = davinci_model->GetRtModelHandle(); | ||||
| GELOGI("InitEndGraphTaskInfo Init Success, model:%p, stream:%p", model_, stream_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -125,6 +125,7 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| GELOGD("hccl_stream addr is=%p", stream); | |||||
| hccl_stream_list_.push_back(stream); | hccl_stream_list_.push_back(stream); | ||||
| davinci_model->PushHcclStream(stream); | davinci_model->PushHcclStream(stream); | ||||
| } | } | ||||
| @@ -245,6 +246,8 @@ void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { | |||||
| GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); | GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); | ||||
| return; | return; | ||||
| } | } | ||||
| GELOGI("The first address of the custom info, privateDef=%p.", private_def_); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -41,6 +41,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| } | } | ||||
| auto kernel_ex_def = task_def.kernel_ex(); | auto kernel_ex_def = task_def.kernel_ex(); | ||||
| const RuntimeParam &rts_param = davinci_model->GetRuntimeParam(); | |||||
| // 1. Copy context from kernelExDef.private to workspace | // 1. Copy context from kernelExDef.private to workspace | ||||
| uint32_t op_index = kernel_ex_def.op_index(); | uint32_t op_index = kernel_ex_def.op_index(); | ||||
| @@ -50,12 +51,12 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| if (CopyTaskInfo(kernel_ex_def, davinci_model->GetRuntimeParam(), op_desc) != SUCCESS) { | |||||
| if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { | |||||
| GELOGE(FAILED, "copy task info to workspace failed."); | GELOGE(FAILED, "copy task info to workspace failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
| const vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); | |||||
| if (workspace_data_addrs.empty()) { | if (workspace_data_addrs.empty()) { | ||||
| GELOGE(FAILED, "workspace_data_addrs is empty."); | GELOGE(FAILED, "workspace_data_addrs is empty."); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -79,16 +80,16 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| uint64_t step_id_addr = 0; | uint64_t step_id_addr = 0; | ||||
| OpDescPtr step_id_node = davinci_model->GetVariableOp(NODE_NAME_GLOBAL_STEP); | OpDescPtr step_id_node = davinci_model->GetVariableOp(NODE_NAME_GLOBAL_STEP); | ||||
| if (step_id_node != nullptr) { | if (step_id_node != nullptr) { | ||||
| vector<void *> v_step_id_addr = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), step_id_node); | |||||
| vector<void *> v_step_id_addr = ModelUtils::GetOutputDataAddrs(rts_param, step_id_node); | |||||
| if (!v_step_id_addr.empty()) { | if (!v_step_id_addr.empty()) { | ||||
| step_id_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(v_step_id_addr[0])); | step_id_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(v_step_id_addr[0])); | ||||
| } | } | ||||
| } | } | ||||
| // 3. Set workspaceaddr, inputOutputDataAddr | // 3. Set workspaceaddr, inputOutputDataAddr | ||||
| uint64_t workspace_base_addr = reinterpret_cast<uint64_t>(workspace_data_addrs[0]); | |||||
| vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
| vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
| uint64_t workspace_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(workspace_data_addrs[0])); | |||||
| const vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
| const vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
| vector<void *> io_addrs; | vector<void *> io_addrs; | ||||
| io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | ||||
| io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | ||||
| @@ -132,7 +133,13 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast<void *>(&fwk_op_kernel), | rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast<void *>(&fwk_op_kernel), | ||||
| sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); | sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: 0x%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: 0x%X", rt_ret); return FAILED;) | ||||
| davinci_model->SetZeroCopyAddr(op_desc, io_addrs, input_output_addr_); | |||||
| vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | |||||
| const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
| const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); | |||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | |||||
| davinci_model->SetZeroCopyAddr(op_desc, virtual_io_addrs, input_output_addr_); | |||||
| kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | ||||
| davinci_model_ = davinci_model; | davinci_model_ = davinci_model; | ||||
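Here and in the TVM/AICPU paths below, SetZeroCopyAddr is now keyed by the addresses fetched with a trailing false argument, while the args keep the addresses actually written to the device. Assuming that flag means the virtual-to-physical conversion is skipped (the comment "use virtual address for zero copy key" suggests as much), a sketch of why the key must stay virtual:

#include <map>
#include <vector>

// User buffers arrive on virtual addresses, so lookups at execute time must
// use the same (virtual) key even when task args hold translated addresses.
std::map<const void *, std::vector<void *>> zero_copy_table;

void RecordZeroCopy(void *virtual_addr, void *args_slot) {
  zero_copy_table[virtual_addr].push_back(args_slot);  // slot is patched per run
}

int main() {
  int user_buf = 0, slot = 0;
  RecordZeroCopy(&user_buf, &slot);
  return zero_copy_table.count(&user_buf) == 1 ? 0 : 1;
}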
| @@ -25,6 +25,7 @@ class KernelExTaskInfo : public TaskInfo { | |||||
| public: | public: | ||||
| KernelExTaskInfo() | KernelExTaskInfo() | ||||
| : task_id_(0), | : task_id_(0), | ||||
| stream_id_(0), | |||||
| dump_flag_(RT_KERNEL_DEFAULT), | dump_flag_(RT_KERNEL_DEFAULT), | ||||
| kernel_buf_size_(0), | kernel_buf_size_(0), | ||||
| davinci_model_(nullptr), | davinci_model_(nullptr), | ||||
| @@ -221,13 +221,13 @@ Status KernelTaskInfo::SuperKernelLaunch() { | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| // Call the fuse API | // Call the fuse API | ||||
| skt::SuperKernel *superKernel; | |||||
| skt::SuperKernel *superKernel = nullptr; | |||||
| if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { | if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { | ||||
| GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); | GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); | ||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| // Launch a super kernel | // Launch a super kernel | ||||
| if (superKernel->Launch(skt_info_.last_stream, true) != SUCCESS) { | |||||
| if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { | |||||
| GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); | GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); | ||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| @@ -341,6 +341,7 @@ Status KernelTaskInfo::Distribute() { | |||||
| rtError_t rt_ret = RT_ERROR_NONE; | rtError_t rt_ret = RT_ERROR_NONE; | ||||
| char *skt_enable_env = getenv("SKT_ENABLE"); | char *skt_enable_env = getenv("SKT_ENABLE"); | ||||
| int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | ||||
| bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | |||||
| if (kernel_type_ == cce::ccKernelType::AI_CPU) { | if (kernel_type_ == cce::ccKernelType::AI_CPU) { | ||||
| // blockDim is reserved parameter, set to 1 | // blockDim is reserved parameter, set to 1 | ||||
| rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | ||||
| @@ -348,11 +349,10 @@ Status KernelTaskInfo::Distribute() { | |||||
| nullptr, stream_, dump_flag_); | nullptr, stream_, dump_flag_); | ||||
| } else { | } else { | ||||
| /* default: not skt launch */ | /* default: not skt launch */ | ||||
| bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | |||||
| GELOGI( | GELOGI( | ||||
| "KernelTaskInfo Distribute Start, sktenable:%ld taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " | |||||
| "KernelTaskInfo Distribute Start, sktenable:%d taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " | |||||
| "stubfunc:%p blockdim:%u stream:%p", | "stubfunc:%p blockdim:%u stream:%p", | ||||
| env_flag, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
| call_skt, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
| // l1 fusion enable and env flag open (kCloseSkt for skt debug) | // l1 fusion enable and env flag open (kCloseSkt for skt debug) | ||||
| if (call_skt && (env_flag != kCloseSkt)) { | if (call_skt && (env_flag != kCloseSkt)) { | ||||
| GE_RETURN_IF_ERROR(SuperKernelDistribute()); | GE_RETURN_IF_ERROR(SuperKernelDistribute()); | ||||
| @@ -371,7 +371,7 @@ Status KernelTaskInfo::Distribute() { | |||||
| GELOGI( | GELOGI( | ||||
| "KernelTaskInfo Distribute Success. sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " | "KernelTaskInfo Distribute Success. sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " | ||||
| "blockdim:%d stream:%p", | "blockdim:%d stream:%p", | ||||
| env_flag, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
| call_skt, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
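The hoisted call_skt is what both Distribute log lines now report. The gate itself reduces to an environment switch OR'd with the L1-fusion flag, with kCloseSkt kept as a debug escape; a self-contained sketch of the same logic:

#include <cstdlib>

bool CallSuperKernel(bool is_l1_fusion_enable) {
  const char *env = std::getenv("SKT_ENABLE");
  const long env_flag = (env != nullptr) ? std::strtol(env, nullptr, 10) : 0;
  return (env_flag != 0) || is_l1_fusion_enable;  // kCloseSkt still bypasses skt at the call site
}

int main() {
  (void)CallSuperKernel(true);
  return 0;
}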
| @@ -423,12 +423,12 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, | |||||
| stub_func_ = const_cast<char *>(bin_file_key); | stub_func_ = const_cast<char *>(bin_file_key); | ||||
| } | } | ||||
| const vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
| const vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
| const vector<void *> workspace_data_addrs = | |||||
| ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
| vector<void *> tensor_device_addrs; | |||||
| const RuntimeParam &rts_param = davinci_model->GetRuntimeParam(); | |||||
| const vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
| const vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
| const vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); | |||||
| vector<void *> tensor_device_addrs; | |||||
| tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | ||||
| tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); | tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); | ||||
| tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | ||||
| @@ -468,7 +468,13 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, | |||||
| reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + offset + sizeof(void *) * input_data_addrs.size()); | reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + offset + sizeof(void *) * input_data_addrs.size()); | ||||
| } | } | ||||
| davinci_model_->SetZeroCopyAddr(op_desc, tensor_device_addrs, static_cast<char *>(args_) + offset); | |||||
| vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | |||||
| const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
| const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); | |||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | |||||
| davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, static_cast<char *>(args_) + offset); | |||||
| // update origin l2 data | // update origin l2 data | ||||
| string sm_desc = kernel_def.sm_desc(); | string sm_desc = kernel_def.sm_desc(); | ||||
| char *sm_contrl = nullptr; | char *sm_contrl = nullptr; | ||||
| @@ -516,6 +522,7 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
| } | } | ||||
| auto op_desc = iter->second; | auto op_desc = iter->second; | ||||
| const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | |||||
| const domi::KernelContext &context = kernel_def.context(); | const domi::KernelContext &context = kernel_def.context(); | ||||
| const uint32_t kCustomAicpuArgsLen = 5; | const uint32_t kCustomAicpuArgsLen = 5; | ||||
| @@ -534,11 +541,8 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
| ctx_.argsOffset[i] = (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[i]; | ctx_.argsOffset[i] = (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[i]; | ||||
| } | } | ||||
| const std::vector<void *> input_data_addrs = | |||||
| ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
| const std::vector<void *> output_data_addrs = | |||||
| ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
| const std::vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
| const std::vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
| Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), | Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), | ||||
| ModelUtils::GetOutputDescs(op_desc)); | ModelUtils::GetOutputDescs(op_desc)); | ||||
| @@ -583,15 +587,15 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
| } | } | ||||
| } | } | ||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) = | ||||
| reinterpret_cast<uint64_t>(custom_info_.input_descs); // arg 0 | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0 | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) = | ||||
| reinterpret_cast<uint64_t>(custom_info_.input_addrs); // arg 1 | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1 | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) = | ||||
| reinterpret_cast<uint64_t>(custom_info_.output_descs); // arg 2 | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2 | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) = | ||||
| reinterpret_cast<uint64_t>(custom_info_.output_addrs); // arg 3 | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3 | |||||
| *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) = | ||||
| reinterpret_cast<uint64_t>(custom_info_.attr_handle); // arg 4 | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4 | |||||
| rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| @@ -606,8 +610,10 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| davinci_model_->SetZeroCopyAddr(op_desc, input_data_addrs, custom_info_.input_addrs); | |||||
| davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, custom_info_.output_addrs); | |||||
| const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
| const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
| davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, custom_info_.input_addrs); | |||||
| davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, custom_info_.output_addrs); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
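The five argsOffset writes above share one pattern: a 64-bit device pointer stored at a byte offset inside the args blob. The sketch below shows the pattern with std::memcpy, which also sidesteps a potentially misaligned store (an alternative formulation, not what this file does):

#include <cstdint>
#include <cstring>

// Writes one device pointer into the AICPU args blob at a given byte offset.
void PatchArg(uint8_t *args, uint16_t byte_offset, const void *device_ptr) {
  const uint64_t value = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(device_ptr));
  std::memcpy(args + byte_offset, &value, sizeof(value));
}

int main() {
  uint8_t args[64] = {};
  int dummy = 0;
  PatchArg(args, 8, &dummy);  // e.g. arg 1 at offset 8
  return 0;
}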
| @@ -714,8 +720,10 @@ Status KernelTaskInfo::InitAicpuTask(const std::map<uint32_t, OpDescPtr> &op_lis | |||||
| } | } | ||||
| OpDescPtr op_desc = iter->second; | OpDescPtr op_desc = iter->second; | ||||
| vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
| vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
| const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | |||||
| vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
| vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
| vector<void *> io_addrs; | vector<void *> io_addrs; | ||||
| io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | ||||
| io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | ||||
| @@ -752,7 +760,13 @@ Status KernelTaskInfo::InitAicpuTask(const std::map<uint32_t, OpDescPtr> &op_lis | |||||
| sizeof(void *) * input_addrs.size()); | sizeof(void *) * input_addrs.size()); | ||||
| } | } | ||||
| davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead)); | |||||
| vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | |||||
| const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
| const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); | |||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | |||||
| davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, | |||||
| static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead)); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -977,7 +991,7 @@ Status KernelTaskInfo::SetFlowtable(std::string &flowtable, const domi::KernelDe | |||||
| *(reinterpret_cast<uint64_t *>( | *(reinterpret_cast<uint64_t *>( | ||||
| args + (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[0])) = | args + (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[0])) = | ||||
| reinterpret_cast<uint64_t>(flowtable_); | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(flowtable_)); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -0,0 +1,149 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/load/new_model_manager/davinci_model.h" | |||||
| namespace ge { | |||||
| Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||||
| GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); | |||||
| if (davinci_model == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "davinci_model is null!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| auto memcpy_async_def = task_def.memcpy_async(); | |||||
| uint64_t logic_dst = memcpy_async_def.dst(); | |||||
| uint64_t logic_src = memcpy_async_def.src(); | |||||
| dst_max_ = memcpy_async_def.dst_max(); | |||||
| uint64_t update_base_addr = 0; | |||||
| ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); | |||||
| if (ret != SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| src_ = reinterpret_cast<uint8_t *>(update_base_addr + logic_src); | |||||
| if (src_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "src_ is null!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| uint64_t mem_base = reinterpret_cast<uint64_t>(davinci_model->MemBase()); | |||||
| uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); | |||||
| dst_ = reinterpret_cast<uint8_t *>(mem_base + (logic_dst - logic_mem_base)); | |||||
| if (dst_ == nullptr) { | |||||
| GELOGE(PARAM_INVALID, "dst_ is null!"); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| count_ = memcpy_async_def.count(); | |||||
| kind_ = memcpy_async_def.kind(); | |||||
| // malloc args memory | |||||
| size_t args_size = sizeof(void *); | |||||
| rtError_t rt_ret = rtMalloc(&args_, args_size * 2, RT_MEMORY_HBM); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| // copy original src | |||||
| GELOGI("src_args:%p, destMax:%zu, src_:%p, count=%zu, kind=%u", args_, args_size, src_, args_size, | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| rt_ret = rtMemcpy(args_, args_size, &src_, args_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api for src failed, ret: 0x%X", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| // copy original dst | |||||
| GELOGI("dst_args:%p, destMax:%zu, dst_:%p, count=%zu, kind=%u", | |||||
| reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + args_size), args_size, dst_, args_size, | |||||
| RT_MEMCPY_HOST_TO_DEVICE); | |||||
| rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + args_size), args_size, &dst_, | |||||
| args_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api for dst failed, ret: 0x%X", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", | |||||
| reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_src)), | |||||
| reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_dst)), src_, dst_, args_, | |||||
| reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + args_size)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MemcpyAddrAsyncTaskInfo::Distribute() { | |||||
| GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); | |||||
| GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); | |||||
| rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + sizeof(void *)), | |||||
| dst_max_, args_, count_, static_cast<rtMemcpyKind_t>(kind_), stream_); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
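Init lays args_ out as two consecutive pointer-sized device slots, and Distribute hands rtMemcpyAsync the slot addresses rather than the raw src/dst; how the runtime dereferences them is determined by the memcpy kind, which this file does not define. The layout, for reference:

// args_ + 0              : holds the value of src_ (original source address)
// args_ + sizeof(void *) : holds the value of dst_ (original destination address)
// rtMemcpyAsync(dst_slot, dst_max_, src_slot, count_, kind_, stream_) then operates
// through these slots, presumably so the addresses can be updated in device memory
// without rebuilding the task.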
| Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, | |||||
| uint64_t &base_addr) { | |||||
| GE_CHECK_NOTNULL(davinci_model); | |||||
| uint64_t data_base_addr = | |||||
| reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); | |||||
| uint64_t weight_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(davinci_model->WeightsMemBase())) - | |||||
| davinci_model->GetRtWeightAddr(); | |||||
| uint64_t var_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(davinci_model->VarMemBase())) - | |||||
| davinci_model->GetRtVarAddr(); | |||||
| uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); | |||||
| uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); | |||||
| uint64_t weight_base_addr_start = davinci_model->GetRtWeightAddr(); | |||||
| uint64_t weight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); | |||||
| uint64_t variable_base_addr_start = davinci_model->GetRtVarAddr(); | |||||
| uint64_t variable_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); | |||||
| if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { | |||||
| base_addr = data_base_addr; | |||||
| GELOGI("The update_addr is a data address."); | |||||
| } else if ((weight_base_addr_start <= update_addr) && (update_addr <= weight_base_addr_end)) { | |||||
| base_addr = weight_base_addr; | |||||
| GELOGI("The update_addr is a weight address."); | |||||
| } else if ((variable_base_addr_start <= update_addr) && (update_addr <= variable_base_addr_end)) { | |||||
| base_addr = var_base_addr; | |||||
| GELOGI("The update_addr is a variable address."); | |||||
| } else if (update_addr != 0) { | |||||
| base_addr = 0; | |||||
| GELOGE(PARAM_INVALID, "The update_addr 0x%lx is outside all known address ranges.", update_addr); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
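GetUpdateBaseAddr classifies a logic address by the range it falls into and returns the matching base delta, which Init then adds back onto the logic address. A worked numeric sketch with made-up values:

#include <cstdint>

int main() {
  uint64_t logic_base = 0x100000;        // davinci_model->GetRtBaseAddr()
  uint64_t real_base = 0x7f0000000000;   // (uint64_t)davinci_model->MemBase()
  uint64_t base_delta = real_base - logic_base;  // what GetUpdateBaseAddr returns
  uint64_t logic_src = 0x100040;         // falls inside the data range
  uint64_t real_src = base_delta + logic_src;    // src_ computed in Init
  return real_src == real_base + 0x40 ? 0 : 1;   // lands 0x40 bytes into real memory
}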
| REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,55 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ | |||||
| #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ | |||||
| #include "graph/load/new_model_manager/task_info/task_info.h" | |||||
| namespace ge { | |||||
| class MemcpyAddrAsyncTaskInfo : public TaskInfo { | |||||
| public: | |||||
| MemcpyAddrAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), args_(nullptr), count_(0), kind_(0) {} | |||||
| ~MemcpyAddrAsyncTaskInfo() override { | |||||
| src_ = nullptr; | |||||
| dst_ = nullptr; | |||||
| if (args_ != nullptr) { | |||||
| rtError_t ret = rtFree(args_); | |||||
| if (ret != RT_ERROR_NONE) { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||||
| } | |||||
| } | |||||
| args_ = nullptr; | |||||
| } | |||||
| Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
| Status Distribute() override; | |||||
| private: | |||||
| Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); | |||||
| void *dst_; | |||||
| uint64_t dst_max_; | |||||
| void *src_; | |||||
| void *args_; | |||||
| uint64_t count_; | |||||
| uint32_t kind_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ | |||||
| @@ -51,6 +51,9 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da | |||||
| count_ = memcpy_async_def.count(); | count_ = memcpy_async_def.count(); | ||||
| kind_ = memcpy_async_def.kind(); | kind_ = memcpy_async_def.kind(); | ||||
| GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", | |||||
| reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_src)), | |||||
| reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_dst)), src_, dst_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -63,6 +63,8 @@ Status StreamActiveTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d | |||||
| active_stream_ = davinci_model->GetStreamList()[active_stream_index_list[internal_index]]; | active_stream_ = davinci_model->GetStreamList()[active_stream_index_list[internal_index]]; | ||||
| active_stream_id_ = stream_active_def.active_stream_id(); | active_stream_id_ = stream_active_def.active_stream_id(); | ||||
| GELOGI("InitStreamActiveTaskInfo Init Success, index:%u, activeStream:%p, activeStreamID:%u.", internal_index, | |||||
| active_stream_, active_stream_id_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -74,6 +76,8 @@ Status StreamActiveTaskInfo::Distribute() { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| GELOGI("StreamActiveTaskInfo Distribute Success. activeStreamID:%p.", active_stream_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -95,6 +95,10 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d | |||||
| } | } | ||||
| data_type_ = static_cast<rtSwitchDataType_t>(data_type); | data_type_ = static_cast<rtSwitchDataType_t>(data_type); | ||||
| } | } | ||||
| GELOGI("InitStreamSwitchTaskInfo Init Success, cond:%d, trueStream:%p, trueStreamID:%u, datatype:%d.", cond_, | |||||
| true_stream_, true_stream_id_, data_type_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -105,6 +109,8 @@ Status StreamSwitchTaskInfo::Distribute() { | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| GELOGI("StreamSwitchTaskInfo Distribute Success. cond:%d, stream:%p, datatype:%d.", cond_, true_stream_, data_type_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -19,17 +19,17 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace skt { | namespace skt { | ||||
| Status SuperKernel::Launch(rtStream_t stream, bool dump_flag) { | |||||
| Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { | |||||
| const void *func_stub_ = this->GetFuncStub(); | const void *func_stub_ = this->GetFuncStub(); | ||||
| const void *args[] = {this->GetNavTablePtr(), (const void *)this->GetNavTableSize()}; | |||||
| const void *args[] = {this->GetNavTablePtr(), | |||||
| reinterpret_cast<const void *>(reinterpret_cast<uintptr_t>(this->GetNavTableSize()))}; | |||||
| void *device_args_addr = nullptr; | |||||
| rtError_t rt_ret = rtMalloc((void **)&(device_args_addr), sizeof(args), RT_MEMORY_HBM); | |||||
| rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | ||||
| rt_ret = rtMemcpy((void *)device_args_addr, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | ||||
| rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr, sizeof(args), NULL, stream, | |||||
| rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, | |||||
| dump_flag); | dump_flag); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); | ||||
| return FAILED;) | return FAILED;) | ||||
| @@ -25,6 +25,7 @@ namespace ge { | |||||
| namespace skt { | namespace skt { | ||||
| class SuperKernel { | class SuperKernel { | ||||
| private: | private: | ||||
| void *device_args_addr_ = nullptr; | |||||
| const void *func_stub_; | const void *func_stub_; | ||||
| void *dev_nav_table_; | void *dev_nav_table_; | ||||
| uint64_t nav_table_size_; | uint64_t nav_table_size_; | ||||
| @@ -33,8 +34,18 @@ class SuperKernel { | |||||
| public: | public: | ||||
| SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) | SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) | ||||
| : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} | : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} | ||||
| ~SuperKernel() {} | |||||
| Status Launch(rtStream_t stream, bool dump_flag); | |||||
| ~SuperKernel() { | |||||
| // free device memory owned by the super kernel when it is released | |||||
| if (device_args_addr_ != nullptr) { | |||||
| GE_CHK_RT(rtFree(device_args_addr_)); | |||||
| GELOGI("SKT: super_kernel args addr free."); | |||||
| } | |||||
| if (dev_nav_table_ != nullptr) { | |||||
| GE_CHK_RT(rtFree(dev_nav_table_)); | |||||
| GELOGI("SKT: super_kernel args addr free."); | |||||
| } | |||||
| } | |||||
| Status Launch(rtStream_t stream, uint32_t dump_flag); | |||||
| const void *GetFuncStub() const { return func_stub_; } | const void *GetFuncStub() const { return func_stub_; } | ||||
| const void *GetNavTablePtr() const { return dev_nav_table_; } | const void *GetNavTablePtr() const { return dev_nav_table_; } | ||||
| uint64_t GetNavTableSize() const { return nav_table_size_; } | uint64_t GetNavTableSize() const { return nav_table_size_; } | ||||
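
With this change the SuperKernel instance owns its device allocations: the destructor above frees device_args_addr_ and dev_nav_table_ via rtFree. Below is a small self-contained sketch of that ownership rule, with a plain function pointer standing in for rtFree; it is illustrative only, not the GE runtime API.

```cpp
#include <cstdio>

// Stand-in for rtFree; returns 0 on success like the runtime calls used above.
using FreeFn = int (*)(void *);

// Owns a single device allocation and releases it on destruction, mirroring
// what ~SuperKernel now does for its args buffer and nav table.
class ScopedDeviceBuffer {
 public:
  ScopedDeviceBuffer(void *addr, FreeFn free_fn) : addr_(addr), free_fn_(free_fn) {}
  ~ScopedDeviceBuffer() {
    if (addr_ != nullptr && free_fn_(addr_) != 0) {
      std::fprintf(stderr, "device free failed\n");  // parallels the GE_CHK_RT check
    }
  }
  ScopedDeviceBuffer(const ScopedDeviceBuffer &) = delete;
  ScopedDeviceBuffer &operator=(const ScopedDeviceBuffer &) = delete;
  void *get() const { return addr_; }

 private:
  void *addr_;
  FreeFn free_fn_;
};
```
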
| @@ -30,26 +30,26 @@ Status SuperKernelFactory::Init() { | |||||
| rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); | rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, | ||||
| "rtGetFunctionByName " | "rtGetFunctionByName " | ||||
| "failied. stub_func: %s", | |||||
| "failed. stub_func: %s", | |||||
| this->sk_stub_name_.c_str()); | this->sk_stub_name_.c_str()); | ||||
| return FAILED;) | return FAILED;) | ||||
| rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); | rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| if (this->use_physical_address_ != nullptr) { | if (this->use_physical_address_ != nullptr) { | ||||
| void *skt_func = nullptr; | void *skt_func = nullptr; | ||||
| rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); | rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| GELOGD( | GELOGD( | ||||
| "SKT: fuseKernels super_kernel_template subFunc %p, device func " | "SKT: fuseKernels super_kernel_template subFunc %p, device func " | ||||
| "address %p, device physic PC %p", | "address %p, device physic PC %p", | ||||
| (uint64_t)this->func_stub_, (uint64_t)this->func_ptr_, (uint64_t)skt_func); | |||||
| this->func_stub_, this->func_ptr_, skt_func); | |||||
| } else { | } else { | ||||
| GELOGD( | GELOGD( | ||||
| "SKT: fuseKernels super_kernel_template subFunc %p, device func " | "SKT: fuseKernels super_kernel_template subFunc %p, device func " | ||||
| "address %p", | "address %p", | ||||
| (uint64_t)this->func_stub_, (uint64_t)this->func_ptr_); | |||||
| this->func_stub_, this->func_ptr_); | |||||
| } | } | ||||
| } | } | ||||
| is_init_ = true; | is_init_ = true; | ||||
| @@ -94,63 +94,66 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list | |||||
| uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); | uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); | ||||
| rtError_t rt_ret; | rtError_t rt_ret; | ||||
| void *hbm_nav_table_addr = nullptr; | |||||
| if (this->use_physical_address_ != nullptr) { | if (this->use_physical_address_ != nullptr) { | ||||
| for (unsigned i = 0; i < stub_func_list.size(); i++) { | for (unsigned i = 0; i < stub_func_list.size(); i++) { | ||||
| void *sub_device_func = nullptr; | void *sub_device_func = nullptr; | ||||
| rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| void *sub_device_func_pys = nullptr; | void *sub_device_func_pys = nullptr; | ||||
| void *args_addr_pys = nullptr; | void *args_addr_pys = nullptr; | ||||
| rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); | rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); | rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| GELOGD( | GELOGD( | ||||
| "SKT: fuseKernels subFunc %p, device func address %p, device " | "SKT: fuseKernels subFunc %p, device func address %p, device " | ||||
| "physic func address %p", | "physic func address %p", | ||||
| stub_func_list[i], (uint64_t)sub_device_func, (uint64_t)sub_device_func_pys); | |||||
| nav_table[i * 2] = (uint64_t)sub_device_func_pys / 4; | |||||
| GELOGD("SKT: CALL offet %p", nav_table[i * 2]); | |||||
| nav_table[i * 2 + 1] = (uint64_t)args_addr_pys; | |||||
| stub_func_list[i], sub_device_func, sub_device_func_pys); | |||||
| // store two uint64_t values per sub-kernel: the call offset and the args address | |||||
| // the device address is divided by 4 because of the 32-bit call encoding; the offset is multiplied by 4 when the call is evaluated | |||||
| nav_table[i * 2] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func_pys)) / 4; | |||||
| GELOGD("SKT: CALL offset %p", nav_table[i * 2]); | |||||
| nav_table[i * 2 + 1] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_pys)); | |||||
| GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | ||||
| } | } | ||||
| void *hbm_nav_table_addr = nullptr; | |||||
| void *hbm_nav_table_addr_pys = nullptr; | void *hbm_nav_table_addr_pys = nullptr; | ||||
| rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) | |||||
| rt_ret = | rt_ret = | ||||
| rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) | |||||
| rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); | rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", (uint64_t)hbm_nav_table_addr, | |||||
| (uint64_t)hbm_nav_table_addr_pys); | |||||
| GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); | |||||
| // Create the necessary metadata for the super kernel | // Create the necessary metadata for the super kernel | ||||
| h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); | h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); | ||||
| } else { | } else { | ||||
| for (unsigned i = 0; i < stub_func_list.size(); i++) { | for (unsigned i = 0; i < stub_func_list.size(); i++) { | ||||
| void *sub_device_func = nullptr; | void *sub_device_func = nullptr; | ||||
| rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); | |||||
| return FAILED;) | return FAILED;) | ||||
| GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], (uint64_t)sub_device_func); | |||||
| nav_table[i * 2] = (uint64_t)sub_device_func / 4; | |||||
| GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); | |||||
| // store two uint64_t values per sub-kernel: the call offset and the args address | |||||
| // the device address is divided by 4 because of the 32-bit call encoding; the offset is multiplied by 4 when the call is evaluated | |||||
| nav_table[i * 2] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4; | |||||
| GELOGD("SKT: CALL offet %p", nav_table[i * 2]); | GELOGD("SKT: CALL offet %p", nav_table[i * 2]); | ||||
| nav_table[i * 2 + 1] = (uint64_t)args_addr_list[i]; | |||||
| nav_table[i * 2 + 1] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i])); | |||||
| GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | ||||
| } | } | ||||
| void *hbm_nav_table_addr = nullptr; | |||||
| rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) | |||||
| rt_ret = | rt_ret = | ||||
| rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) | |||||
| // Create the necessary metadata for the super kernel | // Create the necessary metadata for the super kernel | ||||
| h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); | h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); | ||||
| } | } | ||||
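
Both branches above build the same navigation table layout: entry 2*i holds the sub-kernel's device PC divided by 4 (the 32-bit call encoding re-expands the offset by 4 at dispatch), and entry 2*i+1 holds its args address. A standalone sketch of that layout follows; func_pcs and args_addrs are hypothetical device addresses, resolved in the real flow by rtGetAddrByFun.

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// Build the two-entries-per-kernel navigation table described above.
std::vector<uint64_t> BuildNavTable(const std::vector<uint64_t> &func_pcs,
                                    const std::vector<uint64_t> &args_addrs) {
  std::vector<uint64_t> nav_table(2 * func_pcs.size());
  for (size_t i = 0; i < func_pcs.size(); ++i) {
    nav_table[i * 2] = func_pcs[i] / 4;    // call offset, multiplied by 4 again at dispatch
    nav_table[i * 2 + 1] = args_addrs[i];  // args base address, used as-is
  }
  return nav_table;
}
```
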
| @@ -31,12 +31,12 @@ class SuperKernelFactory { | |||||
| const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); | const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); | ||||
| bool is_init_ = false; | bool is_init_ = false; | ||||
| SuperKernelFactory(){}; | SuperKernelFactory(){}; | ||||
| ~SuperKernelFactory(){}; | |||||
| public: | public: | ||||
| SuperKernelFactory(SuperKernelFactory const &) = delete; | SuperKernelFactory(SuperKernelFactory const &) = delete; | ||||
| void operator=(SuperKernelFactory const &) = delete; | void operator=(SuperKernelFactory const &) = delete; | ||||
| static SuperKernelFactory &GetInstance(); | static SuperKernelFactory &GetInstance(); | ||||
| SuperKernelFactory(const std::string &sk_stub_name_, const std::string &bin_file); | |||||
| Status Init(); | Status Init(); | ||||
| Status Uninitialize(); | Status Uninitialize(); | ||||
| Status FuseKernels(const std::vector<void *> &stub_func_list, const std::vector<void *> &args_addr_list, | Status FuseKernels(const std::vector<void *> &stub_func_list, const std::vector<void *> &args_addr_list, | ||||
| @@ -33,6 +33,7 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
| #include "graph/manager/util/rt_context_util.h" | |||||
| #include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| @@ -117,6 +118,7 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||||
| } | } | ||||
| graph_map_.clear(); | graph_map_.clear(); | ||||
| cache_helper_map_.clear(); | |||||
| init_flag_ = true; | init_flag_ = true; | ||||
| thread_run_flag_ = true; | thread_run_flag_ = true; | ||||
| @@ -180,6 +182,7 @@ Status GraphManager::Finalize() { | |||||
| } | } | ||||
| } | } | ||||
| graph_map_.clear(); | graph_map_.clear(); | ||||
| cache_helper_map_.clear(); | |||||
| // graph context | // graph context | ||||
| if (graph_context_ != nullptr) { | if (graph_context_ != nullptr) { | ||||
| @@ -426,6 +429,13 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
| sub_graph_list[0]->SetSubGraph(merged_compute_graph); | sub_graph_list[0]->SetSubGraph(merged_compute_graph); | ||||
| // set subgraphlist to graphnode | // set subgraphlist to graphnode | ||||
| graph_node->SetSubGraph(sub_graph_list); | graph_node->SetSubGraph(sub_graph_list); | ||||
| // when incremental build is enabled, save the om model and var manager | |||||
| auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), merged_compute_graph, ge_model); | |||||
| if (save_ret != SUCCESS) { | |||||
| GELOGW("Failed to save cache."); | |||||
| } | |||||
| // release the rts contexts created during graph generation | |||||
| RtContextUtil::GetInstance().DestroyrtContexts(); | |||||
| GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun"); | GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun"); | ||||
| GEEVENT("[GEPERFTRACE] GE PreRun End"); | GEEVENT("[GEPERFTRACE] GE PreRun End"); | ||||
| return ret; | return ret; | ||||
| @@ -444,10 +454,14 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| GeModelPtr ge_model = nullptr; | GeModelPtr ge_model = nullptr; | ||||
| ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); | |||||
| // check whether incremental build can be used. | |||||
| ret = IncreBuild(graph_node, ge_model); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "PreRun Failed."); | |||||
| return ret; | |||||
| ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "PreRun Failed."); | |||||
| return ret; | |||||
| } | |||||
| } | } | ||||
| ret = LoadGraph(ge_model, graph_node); | ret = LoadGraph(ge_model, graph_node); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -492,6 +506,90 @@ Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &g | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, | |||||
| GeModelPtr &ge_model) { | |||||
| auto graph_id = graph_node->GetGraphId(); | |||||
| auto ret = cache_helper->LoadOmModelFromCache(ge_model); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("Fail to load om model from cache."); | |||||
| if (cache_helper->ClearCache(graph_id) != SUCCESS) { | |||||
| GELOGW("Fail to clear cache of graph %u.", graph_id); | |||||
| } | |||||
| return FAILED; | |||||
| } | |||||
| ret = cache_helper->RecoverVarManagerFromCache(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("Fail to recover VarManager from cache."); | |||||
| if (cache_helper->ClearCache(graph_id) != SUCCESS) { | |||||
| GELOGW("Fail to clear cache of graph %u.", graph_id); | |||||
| } | |||||
| return FAILED; | |||||
| } | |||||
| ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
| if (compute_graph_in_model == nullptr) { | |||||
| GELOGW("Error occurred when get compute graph from om, abandon."); | |||||
| return FAILED; | |||||
| } else { | |||||
| graph_node->SetComputeGraph(compute_graph_in_model); | |||||
| graph_node->SetGeModel(ge_model); | |||||
| GELOGI("Load model and graph form cache om file."); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { | |||||
| auto ret = cache_helper->SaveCacheInfoToCache(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| ret = cache_helper->SaveVarManagerToCache(true); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGW("Fail to save var manager to cache."); | |||||
| cache_helper->ClearCache(graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| GELOGI("Cache files have been saved."); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { | |||||
| GELOGW("GELib not initialized."); | |||||
| return FAILED; | |||||
| } | |||||
| if (instance_ptr->IsIncreBuild()) { | |||||
| auto iter = cache_helper_map_.find(graph_id); | |||||
| if (iter == cache_helper_map_.end()) { | |||||
| GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); | |||||
| return FAILED; | |||||
| } else { | |||||
| ModelCacheHelperPtr cache_helper = iter->second; | |||||
| auto ret = cache_helper->RefreshComputeGraph(graph); | |||||
| if (ret != SUCCESS) { | |||||
| cache_helper->ClearCache(graph_id); | |||||
| GELOGW("Fail to refresh cache helper's compute graph"); | |||||
| return FAILED; | |||||
| } | |||||
| ret = cache_helper->SaveVarManagerToCache(false); | |||||
| if (ret != SUCCESS) { | |||||
| cache_helper->ClearCache(graph_id); | |||||
| GELOGW("Fail to save VarManager to cache"); | |||||
| return FAILED; | |||||
| } | |||||
| ret = cache_helper->SaveOmModelToCache(ge_model); | |||||
| if (ret != SUCCESS) { | |||||
| cache_helper->ClearCache(graph_id); | |||||
| GELOGW("Fail to save om model to cache"); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
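
LoadFromCache, SaveCacheBeforeBuild, and SaveCacheAfterBuild all follow one discipline: if any write or refresh step fails, the whole cache entry for the graph is cleared, so a later run can never load a half-written model. A generic sketch of that all-or-nothing pattern is below; the names are illustrative, not GE APIs.

```cpp
#include <functional>
#include <vector>

// Run the persistence steps in order; on the first failure, invalidate the
// cache entry and report failure so the caller falls back to a full build.
bool SaveAllOrClear(const std::vector<std::function<bool()>> &steps,
                    const std::function<void()> &clear_cache) {
  for (const auto &step : steps) {
    if (!step()) {
      clear_cache();  // partial writes must not survive
      return false;
    }
  }
  return true;
}
```
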
| Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, | Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, | ||||
| const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { | const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { | ||||
| Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); | Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); | ||||
| @@ -551,6 +649,9 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector<GeTenso | |||||
| GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[RunGraph] compute_graph_tmp is NULL, graph id = %u.", graph_id); | GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[RunGraph] compute_graph_tmp is NULL, graph id = %u.", graph_id); | ||||
| return GE_GRAPH_GRAPH_NODE_NULL;)) | return GE_GRAPH_GRAPH_NODE_NULL;)) | ||||
| // when incremental build is enabled, register a cache helper for this graph | |||||
| AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); | |||||
| std::vector<GeModelPtr> ge_models; | std::vector<GeModelPtr> ge_models; | ||||
| if (options_.local_fmk_op_flag) { | if (options_.local_fmk_op_flag) { | ||||
| @@ -583,7 +684,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector<GeTenso | |||||
| if (!all_sub_graph.empty()) { | if (!all_sub_graph.empty()) { | ||||
| auto checkPointGraph = all_sub_graph[0]->GetSubGraph(); | auto checkPointGraph = all_sub_graph[0]->GetSubGraph(); | ||||
| if (IsCheckpointGraph(checkPointGraph)) { | if (IsCheckpointGraph(checkPointGraph)) { | ||||
| ret = CheckpointHandle(graph_id, outputs); | |||||
| ret = CheckpointHandle(graph_id, checkPointGraph, outputs); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[RunGraph] CheckpointHandle failed!"); | GELOGE(ret, "[RunGraph] CheckpointHandle failed!"); | ||||
| } | } | ||||
| @@ -667,6 +768,15 @@ Status GraphManager::SaveParams(ge::GeModel &model, const std::string &type, con | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { | |||||
| auto iter = cache_helper_map_.find(graph_id); | |||||
| if (iter != cache_helper_map_.end()) { | |||||
| cache_helper_map_.erase(iter); | |||||
| } else { | |||||
| GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); | |||||
| } | |||||
| } | |||||
| Status GraphManager::RemoveGraph(const GraphId &graph_id) { | Status GraphManager::RemoveGraph(const GraphId &graph_id) { | ||||
| auto it = graph_map_.find(graph_id); | auto it = graph_map_.find(graph_id); | ||||
| if (it == graph_map_.end()) { | if (it == graph_map_.end()) { | ||||
| @@ -716,6 +826,9 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||||
| } | } | ||||
| var_acc_ctrl_.RemoveGraph(graph_id); | var_acc_ctrl_.RemoveGraph(graph_id); | ||||
| graph_map_.erase(it); | graph_map_.erase(it); | ||||
| RemoveModelCacheHelper(graph_id); | |||||
| auto ge_model = graph_node->GetGeModel(); | auto ge_model = graph_node->GetGeModel(); | ||||
| if (ge_model != nullptr) { | if (ge_model != nullptr) { | ||||
| GELOGI("Unload model %u.", ge_model->GetModelId()); | GELOGI("Unload model %u.", ge_model->GetModelId()); | ||||
| @@ -1106,21 +1219,15 @@ Status GraphManager::SummaryHandle(const GraphId &graph_id, std::vector<GeTensor | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphManager::CheckpointHandle(const GraphId &graph_id, const std::vector<GeTensor> &outputs) { | |||||
| Status GraphManager::CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, | |||||
| const std::vector<GeTensor> &outputs) { | |||||
| GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size()); | GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size()); | ||||
| std::vector<InputOutputDescInfo> outputs_desc = graph_executor_.GetOutputsDesc(); | std::vector<InputOutputDescInfo> outputs_desc = graph_executor_.GetOutputsDesc(); | ||||
| GELOGI("[GraphManager] CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size()); | GELOGI("[GraphManager] CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size()); | ||||
| // find graph | |||||
| GraphNodePtr graph_node = nullptr; | |||||
| Status ret = GetGraphNode(graph_id, graph_node); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[CheckpointHandle] graph not exist, graph_id = %u.", graph_id); | |||||
| return ret; | |||||
| } | |||||
| ComputeGraphPtr compute_graph_ptr = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); | |||||
| std::map<string, Tensor> save_results; | std::map<string, Tensor> save_results; | ||||
| NodePtr netoutput = nullptr; | NodePtr netoutput = nullptr; | ||||
| for (const auto &node : compute_graph_ptr->GetDirectNode()) { | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| if (node->GetType() == kNetOutput) { | if (node->GetType() == kNetOutput) { | ||||
| netoutput = node; | netoutput = node; | ||||
| break; | break; | ||||
| @@ -1248,6 +1355,8 @@ bool GraphManager::CheckTransOpForCheckpointGraph(NodePtr &node) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| static inline bool CheckConstantOpForCheckpointGraph(NodePtr &node) { return node->GetOutDataNodes().empty(); } | |||||
| bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { | bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { | ||||
| if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[IsCheckpointGraph] computeGraph is nullptr."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "[IsCheckpointGraph] computeGraph is nullptr."); | ||||
| @@ -1268,6 +1377,10 @@ bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { | |||||
| if (!CheckTransOpForCheckpointGraph(node)) { | if (!CheckTransOpForCheckpointGraph(node)) { | ||||
| return false; | return false; | ||||
| } | } | ||||
| } else if (op->GetType() == CONSTANTOP) { | |||||
| if (!CheckConstantOpForCheckpointGraph(node)) { | |||||
| return false; | |||||
| } | |||||
| } else if (op->GetType() != kSend && op->GetType() != kRecv) { | } else if (op->GetType() != kSend && op->GetType() != kRecv) { | ||||
| GELOGI("this node is not allow in checkpoint sub graph, node_type: %s, node_name: %s.", op->GetType().c_str(), | GELOGI("this node is not allow in checkpoint sub graph, node_type: %s, node_name: %s.", op->GetType().c_str(), | ||||
| op->GetName().c_str()); | op->GetName().c_str()); | ||||
| @@ -1439,8 +1552,6 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra | |||||
| names_to_passes.emplace_back("ReshapeRemovePass", &trans_op_nearby_allreduce_fusion_pass); | names_to_passes.emplace_back("ReshapeRemovePass", &trans_op_nearby_allreduce_fusion_pass); | ||||
| ReshapeRemovePass reshape_remove_pass; | ReshapeRemovePass reshape_remove_pass; | ||||
| names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | ||||
| ReplaceWithEmptyConstPass replace_with_empty_const_pass; | |||||
| names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass); | |||||
| ConstantFoldingPass constant_folding_pass; | ConstantFoldingPass constant_folding_pass; | ||||
| names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
| DimensionAdjustPass dimension_adjust_pass; | DimensionAdjustPass dimension_adjust_pass; | ||||
| @@ -1632,6 +1743,51 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector<ge | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void GraphManager::AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, | |||||
| ComputeGraphPtr &compute_graph) { | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { | |||||
| auto iter = cache_helper_map_.find(graph_id); | |||||
| if (iter == cache_helper_map_.end()) { | |||||
| ModelCacheHelperPtr cache_helper = MakeShared<ge::ModelCacheHelper>(session_id, graph_id, compute_graph); | |||||
| if (cache_helper != nullptr) { | |||||
| cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); | |||||
| } else { | |||||
| GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { | |||||
| std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { | |||||
| return FAILED; | |||||
| } | |||||
| const uint32_t graph_id = graph_node->GetGraphId(); | |||||
| auto iter = cache_helper_map_.find(graph_id); | |||||
| if (iter == cache_helper_map_.end()) { | |||||
| GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); | |||||
| return FAILED; | |||||
| } | |||||
| ModelCacheHelperPtr cache_helper = iter->second; | |||||
| if (cache_helper->IsModelCacheHit()) { | |||||
| GEEVENT("Model cache hit."); | |||||
| Status ret = LoadFromCache(graph_node, cache_helper, ge_model); | |||||
| if (ret == SUCCESS) { | |||||
| return SUCCESS; | |||||
| } else { | |||||
| GELOGW("Error occurred when load from cache, abandon."); | |||||
| } | |||||
| } else { | |||||
| GEEVENT("Model cache miss."); | |||||
| } | |||||
| if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { | |||||
| GELOGW("Error occurred when save cache."); | |||||
| } | |||||
| return FAILED; | |||||
| } | |||||
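
The contract of IncreBuild in this hunk: SUCCESS means the model (and VarManager state) came from the cache; FAILED means the caller should run the normal PreRun/build path, with SaveCacheBeforeBuild priming the cache for the next run. A self-contained sketch of that cache-or-build shape follows; the in-memory map stands in for the on-disk om cache, and none of these names are GE APIs.

```cpp
#include <cstdint>
#include <map>
#include <string>

struct Model { std::string blob; };        // simplified stand-in for GeModel

static std::map<uint32_t, Model> g_cache;  // stands in for the om cache on disk

Model BuildModel(uint32_t graph_id) {      // stands in for PreRun + compile
  return Model{"model-" + std::to_string(graph_id)};
}

Model GetModel(uint32_t graph_id) {
  auto it = g_cache.find(graph_id);
  if (it != g_cache.end()) {
    return it->second;                     // cache hit: skip the expensive build
  }
  Model m = BuildModel(graph_id);          // cache miss: full build
  g_cache.emplace(graph_id, m);            // persist so the next request hits
  return m;
}
```
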
| void GraphManager::PreRunThread(GraphManager *graph_manager) { | void GraphManager::PreRunThread(GraphManager *graph_manager) { | ||||
| if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | ||||
| GELOGW("Set thread name failed."); | GELOGW("Set thread name failed."); | ||||
| @@ -1685,6 +1841,8 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
| return; | return; | ||||
| } | } | ||||
| } | } | ||||
| // when set incre build, save cache helper. | |||||
| graph_manager->AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); | |||||
| std::vector<GeModelPtr> ge_models; | std::vector<GeModelPtr> ge_models; | ||||
| @@ -1707,12 +1865,15 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
| return; | return; | ||||
| } | } | ||||
| ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); | |||||
| if (ret != SUCCESS) { | |||||
| graph_node->SetRunFlag(false); | |||||
| ReturnError(graph_manager, args.callback, ret, "PreRun failed, thread exit."); | |||||
| graph_node->Unlock(); | |||||
| return; | |||||
| // check whether incremental build can be used. | |||||
| if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||||
| ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); | |||||
| if (ret != SUCCESS) { | |||||
| graph_node->SetRunFlag(false); | |||||
| ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | |||||
| graph_node->Unlock(); | |||||
| return; | |||||
| } | |||||
| } | } | ||||
| graph_node->SetBuildFlag(true); | graph_node->SetBuildFlag(true); | ||||
| graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | ||||
| @@ -27,6 +27,7 @@ | |||||
| #include "common/blocking_queue.h" | #include "common/blocking_queue.h" | ||||
| #include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
| #include "common/helper/model_cache_helper.h" | |||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| #include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
| #include "graph/build/graph_builder.h" | #include "graph/build/graph_builder.h" | ||||
| @@ -211,7 +212,8 @@ class GraphManager { | |||||
| Status SummaryHandle(const GraphId &graph_id, std::vector<GeTensor> &outputs); | Status SummaryHandle(const GraphId &graph_id, std::vector<GeTensor> &outputs); | ||||
| Status CheckpointHandle(const GraphId &graph_id, const std::vector<GeTensor> &outputs); | |||||
| Status CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, | |||||
| const std::vector<GeTensor> &outputs); | |||||
| // call the callback function of ME to push summary result data to ME | // call the callback function of ME to push summary result data to ME | ||||
| Status PushSummaryData2ME(const GraphId &graph_id, const std::map<std::string, ge::Tensor> &summary_data); | Status PushSummaryData2ME(const GraphId &graph_id, const std::map<std::string, ge::Tensor> &summary_data); | ||||
| @@ -260,6 +262,13 @@ class GraphManager { | |||||
| bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | ||||
| Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); | |||||
| Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); | |||||
| Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); | |||||
| void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); | |||||
| Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); | |||||
| void RemoveModelCacheHelper(const GraphId &graph_id); | |||||
| static void PreRunThread(GraphManager *graph_manager); | static void PreRunThread(GraphManager *graph_manager); | ||||
| static void RunThread(GraphManager *graph_manager); | static void RunThread(GraphManager *graph_manager); | ||||
| static void StopQueue(GraphManager *graph_manager); | static void StopQueue(GraphManager *graph_manager); | ||||
| @@ -274,6 +283,8 @@ class GraphManager { | |||||
| std::map<GraphId, GraphNodePtr> graph_map_; | std::map<GraphId, GraphNodePtr> graph_map_; | ||||
| std::map<GraphId, ModelCacheHelperPtr> cache_helper_map_; | |||||
| // for run graph synchronous return | // for run graph synchronous return | ||||
| std::mutex sync_run_mutex_; | std::mutex sync_run_mutex_; | ||||
| std::condition_variable condition_; | std::condition_variable condition_; | ||||
| @@ -64,6 +64,10 @@ ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTens | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) { | |||||
| var_addr_mgr_map = var_addr_mgr_map_; | |||||
| } | |||||
| void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ||||
| rtMemType_t memory_type) { | rtMemType_t memory_type) { | ||||
| std::string var_key = VarKey(var_name, tensor_desc); | std::string var_key = VarKey(var_name, tensor_desc); | ||||
| @@ -170,6 +174,14 @@ void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &b | |||||
| var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info; | var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info; | ||||
| } | } | ||||
| ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { | |||||
| if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) { | |||||
| return FAILED; | |||||
| } | |||||
| broad_cast_info = var_broad_cast_info_[graph_id][var_name]; | |||||
| return SUCCESS; | |||||
| } | |||||
| ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ||||
| const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { | const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { | ||||
| if (var_op_desc == nullptr) { | if (var_op_desc == nullptr) { | ||||
| @@ -282,11 +294,17 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin | |||||
| // align 512 BYTE | // align 512 BYTE | ||||
| var_mem_size_ = var_mem_size_ + kSessionMemAlignSize; | var_mem_size_ = var_mem_size_ + kSessionMemAlignSize; | ||||
| GELOGI( | |||||
| "[IMAS]AssignVarMem Set session_%lu name[%s] output[%d] " | |||||
| "offset to [%zu] size[%lu] realsize[%lu].", | |||||
| session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| int64_t MemResource::GetVarMemSize() const { return var_mem_size_; } | int64_t MemResource::GetVarMemSize() const { return var_mem_size_; } | ||||
| void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; } | |||||
| VarManager::VarManager(uint64_t session_id) | VarManager::VarManager(uint64_t session_id) | ||||
| : version_(SessionVersion::OTHER_VERSION), | : version_(SessionVersion::OTHER_VERSION), | ||||
| session_id_(session_id), | session_id_(session_id), | ||||
| @@ -363,6 +381,21 @@ ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTenso | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, | |||||
| rtMemType_t memory_type) { | |||||
| GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| if (var_resource_ == nullptr) { | |||||
| GELOGW("VarManager has not been init."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type); | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
| rtMemType_t &memory_type) { | rtMemType_t &memory_type) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| @@ -388,6 +421,10 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso | |||||
| return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); | return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); | ||||
| } | } | ||||
| void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) { | |||||
| // guard against use before Init, consistent with the other accessors | |||||
| if (var_resource_ == nullptr) { return; } | |||||
| var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); | |||||
| } | |||||
| int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| MemResource *mem_resource = nullptr; | MemResource *mem_resource = nullptr; | ||||
| @@ -405,14 +442,36 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | |||||
| return mem_resource->GetVarMemSize(); | return mem_resource->GetVarMemSize(); | ||||
| } | } | ||||
| Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| MemResource *mem_resource = nullptr; | |||||
| auto iter = mem_resource_map_.find(memory_type); | |||||
| if (iter == mem_resource_map_.end()) { | |||||
| mem_resource = new (std::nothrow) MemResource(); | |||||
| if (mem_resource == nullptr) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } else { | |||||
| mem_resource_map_[memory_type] = mem_resource; | |||||
| } | |||||
| } else { | |||||
| mem_resource = iter->second; | |||||
| } | |||||
| if (mem_resource == nullptr) { | |||||
| GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid."); | |||||
| return FAILED; | |||||
| } | |||||
| mem_resource->UpdateVarMemSize(mem_size); | |||||
| return SUCCESS; | |||||
| } | |||||
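
UpdateVarMemSize above repeats a pattern used throughout VarManager: look up the per-memory-type resource and lazily create it with new (std::nothrow) on first use. A reduced sketch of that find-or-create step is below; MemResource here is a simplified stand-in for the real class.

```cpp
#include <cstdint>
#include <map>
#include <new>

struct MemResource { int64_t var_mem_size_ = 0; };  // simplified stand-in

// Return the resource for memory_type, creating it on first use;
// nullptr reports allocation failure to the caller.
MemResource *FindOrCreate(std::map<uint32_t, MemResource *> &pool, uint32_t memory_type) {
  auto iter = pool.find(memory_type);
  if (iter != pool.end()) {
    return iter->second;
  }
  MemResource *res = new (std::nothrow) MemResource();  // no exception on OOM
  if (res != nullptr) {
    pool[memory_type] = res;
  }
  return res;
}
```
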
| ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, | ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, | ||||
| rtMemType_t memory_type) { | rtMemType_t memory_type) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| GELOGI( | |||||
| "VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = " | |||||
| "%s.", | |||||
| var_name.c_str(), ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); | |||||
| GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); | |||||
| int64_t tensor_desc_size = 0; | int64_t tensor_desc_size = 0; | ||||
| size_t mem_offset = 0; | size_t mem_offset = 0; | ||||
| @@ -475,14 +534,13 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen | |||||
| if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() || | if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() || | ||||
| cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() || | cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() || | ||||
| cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) { | cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) { | ||||
| GELOGI( | |||||
| "var %s assigned new memory (format, data type, shape) (%s, %s, " | |||||
| "%zu) from (%s, %s, %zu)", | |||||
| var_name.c_str(), ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), tensor_desc.GetShape().GetDims().size(), | |||||
| ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(), | |||||
| cur_tensor_desc.GetShape().GetDims().size()); | |||||
| GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), | |||||
| tensor_desc.GetShape().GetDims().size(), | |||||
| ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(), | |||||
| ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(), | |||||
| cur_tensor_desc.GetShape().GetDims().size()); | |||||
| var_resource_->SetVarAddr(var_name, tensor_desc, | var_resource_->SetVarAddr(var_name, tensor_desc, | ||||
| reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(mem_offset)), memory_type); | reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(mem_offset)), memory_type); | ||||
| } | } | ||||
| @@ -550,6 +608,16 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| if (var_resource_ == nullptr) { | |||||
| GELOGW("VarManager has not been init."); | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); | |||||
| } | |||||
| ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { | ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); | GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); | ||||
| @@ -672,6 +740,7 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); | GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); | ||||
| return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); | |||||
| } | } | ||||
| it = options.find(VARIABLE_MEMORY_MAX_SIZE); | it = options.find(VARIABLE_MEMORY_MAX_SIZE); | ||||
| @@ -101,6 +101,8 @@ class VarResource { | |||||
| ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
| rtMemType_t &memory_type); | rtMemType_t &memory_type); | ||||
| void GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map); | |||||
| void SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | void SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ||||
| rtMemType_t memory_type); | rtMemType_t memory_type); | ||||
| @@ -113,6 +115,8 @@ class VarResource { | |||||
| void SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | void SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
| ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||||
| ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ||||
| const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr); | const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr); | ||||
| @@ -175,6 +179,8 @@ class MemResource { | |||||
| int64_t GetVarMemSize() const; | int64_t GetVarMemSize() const; | ||||
| void UpdateVarMemSize(int64_t mem_size); | |||||
| private: | private: | ||||
| uint64_t total_size_; | uint64_t total_size_; | ||||
| uint64_t var_mem_size_; | uint64_t var_mem_size_; | ||||
| @@ -196,9 +202,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| ge::Status SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ge::Status SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ||||
| rtMemType_t memory_type); | rtMemType_t memory_type); | ||||
| ge::Status SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, | |||||
| rtMemType_t memory_type); | |||||
| ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
| rtMemType_t &memory_type); | rtMemType_t &memory_type); | ||||
| void GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map); | |||||
| ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ||||
| ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ||||
| @@ -206,6 +217,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
| ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||||
| ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ||||
| uint8_t *base_ptr); | uint8_t *base_ptr); | ||||
| @@ -251,6 +264,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| int64_t GetVarMemSize(rtMemType_t memory_type); | int64_t GetVarMemSize(rtMemType_t memory_type); | ||||
| Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); | |||||
| bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); | bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); | ||||
| bool IsVarExist(const std::string &var_name); | bool IsVarExist(const std::string &var_name); | ||||
| @@ -238,6 +238,14 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); | GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); | ||||
| // refresh the engine assignment of every node in the merged graph | |||||
| GE_TIMESTAMP_START(MergeGraphEnginePlacerRun); | |||||
| graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); | |||||
| if (graph_info_.engine_placer_.Run() != SUCCESS) { | |||||
| GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); | |||||
| return FAILED; | |||||
| } | |||||
| GE_TIMESTAMP_END(MergeGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); | |||||
| GELOGI("Graph merge ends."); | GELOGI("Graph merge ends."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
@@ -200,7 +200,18 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) {
   vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType());
   for (const auto &op_info : op_info_vec) {
     if (op_info.isAtomic) {
-      GELOGI("Recognized atomic op %s from HCCL engine.", op_desc->GetName().c_str());
+      GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str());
+      // check whether any peer input node is DATA
+      for (auto &in_data_anchor : node->GetAllInDataAnchors()) {
+        if (in_data_anchor->GetPeerOutAnchor() != nullptr &&
+            in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) {
+          auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode();
+          if (peer_in_node->GetType() == DATA) {
+            GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str());
+            return false;
+          }
+        }
+      }
       hcom_node_vec_.push_back(node);
       return true;
     }
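The added loop is a routine anchor walk: for each in-data anchor, follow the peer out-anchor to its owner node and test the node type, skipping the atomic treatment when the producer is a DATA node. A compressed sketch of that traversal under simplified stand-in types (GE's anchors are richer shared-pointer classes):

```cpp
#include <memory>
#include <string>
#include <vector>

struct Node;

struct OutAnchor {
  Node *owner = nullptr;  // node that produces this output
};

struct InAnchor {
  std::shared_ptr<OutAnchor> peer;  // peer output anchor, may be null
};

struct Node {
  std::string type;
  std::vector<InAnchor> inputs;
};

// Returns true when any direct input of `node` is produced by a DATA node,
// mirroring the walk the new check performs over GE's anchor graph.
bool HasDataInput(const Node &node) {
  for (const auto &in : node.inputs) {
    if (in.peer != nullptr && in.peer->owner != nullptr && in.peer->owner->type == "DATA") {
      return true;
    }
  }
  return false;
}
```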
@@ -49,9 +49,11 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstG
     GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr.");
     return PARAM_INVALID;
   }
   const uint8_t *src_data = const_weight_ptr->GetData().data();
-  if (op_desc_ptr == nullptr || src_data == nullptr) {
-    GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr or src_data is nullptr.");
+  // src_data == nullptr is supported
+  if (op_desc_ptr == nullptr) {
+    GELOGE(PARAM_INVALID, "Invalid parameter: input opDescPtr is nullptr.");
     return PARAM_INVALID;
   }
   GeTensorDesc op_desc = op_desc_ptr->GetOutputDesc(0);
@@ -73,7 +75,7 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstG
              TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(),
              TypeUtils::DataTypeToSerialString(data_type).c_str());
-  GE_CHECK_SIZE(const_weight_ptr->GetData().GetSize());
+  // const_weight_ptr->GetData().GetSize() == 0 is supported
   auto src_data_size = src_shape.GetShapeSize();
   if (src_data_size == 0 &&
       static_cast<int>(const_weight_ptr->GetData().GetSize()) == GetSizeByDataType(src_data_type)) {
@@ -113,7 +115,6 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstG
   }
   if (output_ptr->SetData(trans_result.data.get(), trans_result.length) != SUCCESS) {
     GELOGW("Compute: SetData failed");
-    return FAILED;
   }
   v_output.push_back(output_ptr);
   return SUCCESS;
@@ -113,12 +113,26 @@ bool KernelUtils::CheckSizeForTransOp(const ge::ConstGeTensorPtr &const_weight_p
   GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, cal_size,
          TypeUtils::DataTypeToSerialString(data_type).c_str());
-  if ((shape_size != 0) || (length != 0 && (data_size / static_cast<size_t>(length) != 1))) {
-    if (!(data_size == static_cast<size_t>(cal_size) && data_size != 0)) {
+  if (shape_size != 0) {
+    // standard tensor
+    if (data_size != static_cast<size_t>(cal_size) || data_size == 0) {
+      GELOGW("Const input data size is not equal with tensor desc shape");
+      return false;
+    }
+  } else if (data_shape.GetDimNum() != 0) {
+    // empty tensor: the shape vector contains a zero dim
+    if (data_size != 0) {
+      GELOGW("Const input data size is not equal with tensor desc shape");
+      return false;
+    }
+  } else {
+    // scalar tensor: holds exactly one element
+    if (length != 0 && (data_size / static_cast<size_t>(length) != 1)) {
       GELOGW("Const input data size is not equal with tensor desc shape");
       return false;
     }
   }
   return true;
 }
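The rewritten check replaces one opaque compound condition with three tensor classes: standard, empty (a zero dim in the shape), and scalar. Restated as a standalone predicate; names and parameters here are illustrative, not the GE signature, and `cal_size` is assumed to be the byte size computed from the shape:

```cpp
#include <cstddef>
#include <cstdint>
#include <vector>

// data_size: actual byte count of the const data; cal_size: byte size computed
// from the shape; elem_size: byte size of one element; shape_size: element count.
bool CheckConstDataSize(const std::vector<int64_t> &dims, size_t data_size, int64_t shape_size,
                        int64_t cal_size, int64_t elem_size) {
  if (shape_size != 0) {
    // standard tensor: data must match the computed byte size exactly
    return data_size == static_cast<size_t>(cal_size) && data_size != 0;
  }
  if (!dims.empty()) {
    // empty tensor (a zero dim in the shape): the data must also be empty
    return data_size == 0;
  }
  // scalar tensor: exactly one element of elem_size bytes
  return elem_size == 0 || data_size / static_cast<size_t>(elem_size) == 1;
}
```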
@@ -29,6 +29,7 @@ namespace ge {
 class KernelUtils {
  public:
   KernelUtils() = delete;
+  ~KernelUtils() = delete;
   static Status CheckDimensionNodeInfo(const NodePtr &node_ptr);
   static bool CheckFormatSupported(const NodePtr &node_ptr);
   static bool CheckSizeForTransOp(const ConstGeTensorPtr &const_weight_ptr, const OpDescPtr &op_desc_ptr);
@@ -41,7 +42,7 @@ class KernelUtils {
    * @param [out] output the tensor for save sequence of numbers
    * @author
    */
-  template<typename T>
+  template <typename T>
   static Status GenData(const int64_t data_num, const T value, const GeTensorPtr &output) {
     if (data_num > 0) {
       if (!CheckInt64MulOverflow(data_num, static_cast<int64_t>(sizeof(T)))) {
@@ -69,12 +70,12 @@ class KernelUtils {
   }
   /**
-  * Calculate dimension
-  * @param [in] dims save the tensor of the dimension
-  * @param [in] vec_dim results of each dimension
-  * @param [out] data_num total size of data
-  * @author
-  */
+   * Calculate dimension
+   * @param [in] dims tensor that holds the dimension values
+   * @param [out] vec_dim value of each dimension
+   * @param [out] data_num total number of data elements
+   * @author
+   */
   template <typename T>
   static Status CalcDims(const ConstGeTensorPtr dims, std::vector<int64_t> &vec_dim, int64_t &data_num) {
     data_num = 1;
@@ -67,8 +67,8 @@ Status PackKernel::ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr,
     return PARAM_INVALID;
   }
   if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) {
-    GELOGE(PARAM_INVALID, "Attr %s is not exist.", PACK_ATTR_NAME_NUM.c_str());
-    return PARAM_INVALID;
+    n_ = 0;
+    GELOGD("Attr %s is not set, default value %ld is used.", PACK_ATTR_NAME_NUM.c_str(), n_);
   }
   if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) {
     GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str());
@@ -105,11 +105,7 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v
       GELOGW("Input %ld of pack kernel %s is null.", i, op_desc_ptr->GetName().c_str());
       return PARAM_INVALID;
     }
-    // check if tensor contains data
-    if (input[i]->GetData().size() == 0) {
-      GELOGW("Inputs %ld do not have value.", i);
-      return NOT_CHANGED;
-    }
     if (i == 0) {
       // get first input shape
       shape = input[0]->GetTensorDesc().GetShape();
@@ -127,8 +123,8 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v
     auto dst_shape = tensor_desc.GetShape();
     int64_t num = 1;
     for (auto dim : dst_shape.GetDims()) {
-      if (dim < 1) {
-        GELOGW("Invalid zero dim in the shape %s", formats::ShapeToString(shape).c_str());
+      if (dim < 0) {
+        GELOGW("Invalid dim %ld in the shape %s", dim, formats::ShapeToString(shape).c_str());
         return NOT_CHANGED;
       }
       num *= dim;
@@ -141,6 +137,12 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v
       GELOGW("Shape of input %ld is not equal with input 0.", i);
       return NOT_CHANGED;
     }
+    // check whether the tensor data size is zero
+    if (input[i]->GetData().size() == 0 && num != 0) {
+      GELOGW("Input %ld does not have value.", i);
+      return NOT_CHANGED;
+    }
   }
   return SUCCESS;
 }
@@ -167,6 +169,13 @@ void PackKernel::ExpandDims(const int64_t axis, const std::vector<ge::ConstGeTen
 Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector<ge::ConstGeTensorPtr> &input,
                                   ge::GeTensorPtr &output_ptr) {
+  output_ptr->MutableTensorDesc().SetShape(final_shape);
+  output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_));
+  if (final_shape.GetShapeSize() == 0 && final_shape.GetDims().size() != 0) {
+    // the shape contains a zero dim, so the output tensor data is []
+    return SUCCESS;
+  }
   int64_t times = 1;
   int64_t unit = 1;
   // calculate data unit
@@ -210,8 +219,6 @@ Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector<
   if (output_ptr->SetData(buf.get(), static_cast<size_t>(output_size * data_size)) != GRAPH_SUCCESS) {
     GELOGW("CopyOutputData: SetData failed");
   }
-  output_ptr->MutableTensorDesc().SetShape(final_shape);
-  output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_));
   return SUCCESS;
 }
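Moving `SetShape`/`SetDataType` to the top means a zero-element output is still fully described when the function returns early, before any data is copied. A sketch of that order of operations with hypothetical stand-in types:

```cpp
#include <cstdint>
#include <vector>

// Minimal stand-in for GeShape; the size semantics here are simplified.
struct ShapeInfo {
  std::vector<int64_t> dims;
  int64_t NumElements() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;  // 1 for a scalar (no dims), 0 if any dim is zero
  }
};

struct TensorStub {
  ShapeInfo shape;
  std::vector<uint8_t> data;
};

// Descriptor fields are recorded up front so a zero-element output is still
// fully described even though the data-copy step below is skipped.
void WritePackedOutput(const ShapeInfo &final_shape, TensorStub &out) {
  out.shape = final_shape;  // set shape (and, in GE, dtype) first
  if (final_shape.NumElements() == 0 && !final_shape.dims.empty()) {
    return;  // zero dim in the shape: output data stays []
  }
  // ... copy the packed payload into out.data ...
}
```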
@@ -63,10 +63,7 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr,
     GELOGE(PARAM_INVALID, "Axis must be at most rank 1, node: %s", op_desc_ptr->GetName().c_str());
     return PARAM_INVALID;
   }
-  if (data_tensor->GetData().size() == 0 || axis_tensor->GetData().size() == 0) {
-    GELOGE(PARAM_INVALID, "ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str());
-    return PARAM_INVALID;
-  }
   DataType data_type = data_tensor->GetTensorDesc().GetDataType();
   if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) {
     GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s",
@@ -151,7 +148,6 @@ Status ReduceProdKernel::DataCal(const std::vector<ge::ConstGeTensorPtr> &input,
                                     static_cast<size_t>(head_dim_ * end_dim_ * sizeof(int32_t))) != GRAPH_SUCCESS,
                     GELOGW("set data failed");
                     return INTERNAL_ERROR);
-    output_ptr->MutableTensorDesc().SetDataType(data_dtype);
   }
   return SUCCESS;
 }
@@ -260,19 +256,32 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec
     if (ret != SUCCESS) {
       return NOT_CHANGED;
     }
+  } else if (input.at(kReduceProdAxisIndex)->GetData().size() == 0) {
+    // an empty axis tensor [] means the input passes through unprocessed
+    output_ptr->MutableTensorDesc().SetShape(input.at(kReduceProdDataIndex)->GetTensorDesc().GetShape());
+    output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType());
+    if (output_ptr->SetData(input.at(kReduceProdDataIndex)->GetData()) != GRAPH_SUCCESS) {
+      GELOGW("Compute: SetData failed");
+    }
   } else {
     // calculate axis to reduce
     ret = AxisCal(input);
     if (ret != SUCCESS) {
       return NOT_CHANGED;
     }
-    // calculate data and data type
-    ret = DataCal(input, output_ptr);
-    if (ret != SUCCESS) {
-      return NOT_CHANGED;
-    }
-    // calculate shape
+    // calculate and set shape
     ShapeCal(op_desc_ptr, input, output_ptr);
+    // set data type
+    output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType());
+    // data size == 0 means the input shape contains a zero dim and the tensor value is []
+    if (input.at(kReduceProdDataIndex)->GetData().size() != 0) {
+      // calculate data
+      ret = DataCal(input, output_ptr);
+      if (ret != SUCCESS) {
+        return NOT_CHANGED;
+      }
+    }
   }
   // print output tensor information; to be deleted later
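The new `else if` branch follows the usual framework convention that reducing over an empty axis list is an identity: shape, data type, and data are taken from the input unchanged, while the later zero-size-data guard skips `DataCal` for tensors with a zero dim in their shape. A sketch of that dispatch (stand-in types, not the GE kernel API):

```cpp
#include <cstdint>
#include <vector>

struct TensorStub {
  std::vector<int64_t> dims;
  std::vector<float> data;
};

// Dispatch in the spirit of the new branch: an empty axis tensor makes the
// fold an identity; otherwise shape/dtype are set before the data step,
// which is skipped for zero-element inputs.
bool FoldReduceProd(const TensorStub &input, const std::vector<int64_t> &axes, TensorStub &output) {
  if (axes.empty()) {
    output = input;  // pass shape, dtype and data through unchanged
    return true;
  }
  output.dims = input.dims;  // stand-in for the real ShapeCal(...)
  if (!input.data.empty()) {
    // stand-in for DataCal(...): multiply elements along `axes` into output.data
  }
  return true;
}
```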
@@ -48,8 +48,9 @@ Status TransdataKernel::ValidateInput(const OpDescPtr &op_desc_ptr, const std::v
     GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr.");
     return PARAM_INVALID;
   }
-  const uint8_t *src_data = const_weight_ptr->GetData().data();
-  if (op_desc_ptr == nullptr || src_data == nullptr) {
+  // src_data == nullptr is supported
+  if (op_desc_ptr == nullptr) {
     GELOGE(PARAM_INVALID, "Input opDescPtr is nullptr.");
     return PARAM_INVALID;
   }
@@ -26,6 +26,7 @@ namespace ge {
 class PassUtils {
  public:
   PassUtils() = delete;
+  ~PassUtils() = delete;
   static NodePtr GetInDataNode(const ConstNodePtr &node, int index);
@@ -137,7 +137,7 @@ Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_n
     NodePtr out_node = peer_in_anchor->GetOwnerNode();
     GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail.");
     if ((type == MERGE) || (type == REFMERGE)) {
-      NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor);
+      NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false);
       GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail.");
       GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)),
                     "MemcpyAsync node add edge fail.");
@@ -234,16 +234,18 @@ Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_nod
     need_label_nodes_.emplace_back(stream_merge);
   }
+  bool multi_batch_flag = false;
   if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) {
     if (!ge::AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) {
       GELOGE(FAILED, "Set attr ATTR_INSERT_BY_MBATCH fail, StreamMerge:%s.", node_name.c_str());
       return FAILED;
     }
+    multi_batch_flag = true;
   }
   (void)bypass_nodes_.insert(merge_node);
-  GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge), "StreamMerge add memcpy node fail.");
+  GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge, multi_batch_flag), "StreamMerge add memcpy node fail.");
   return SUCCESS;
 }
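`multi_batch_flag` threads the ATTR_INSERT_BY_MBATCH information down to node creation, where it selects MEMCPYADDRASYNC instead of MEMCPYASYNC as both the op type and the node-name suffix. The selection itself reduces to a few lines; the constants' string values below are assumptions for illustration:

```cpp
#include <string>

// Assumed string values; stand-ins for the real GE op-type constants.
const std::string MEMCPYASYNC = "MemcpyAsync";
const std::string MEMCPYADDRASYNC = "MemcpyAddrAsync";

// Mirrors the new naming/type logic: the chosen op type doubles as the node
// name suffix, so multi-batch graphs get distinguishable MemcpyAddrAsync nodes.
std::string BuildMemcpyNodeName(const std::string &pre_node_name, bool multi_batch_flag) {
  const std::string &memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC;
  return pre_node_name + "_" + memcpy_type;
}
```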
@@ -302,17 +304,20 @@ NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodeP
 /// @brief Add MemcpyAsync Node
 /// @param [in] graph
 /// @param [in] in_node
+/// @param [in] multi_batch_flag
 /// @return ge::NodePtr
 ///
-NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) {
+NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor,
+                                            bool multi_batch_flag) {
   GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null.");
   OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc();
   GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid.");
-  std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC;
+  std::string memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC;
+  std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type;
   node_name = CheckDuplicateName(node_name);
   GELOGI("Create MemcpyAsync op:%s.", node_name.c_str());
-  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC);
+  OpDescPtr op_desc = MakeShared<OpDesc>(node_name, memcpy_type);
   if (op_desc == nullptr) {
     GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str());
     return nullptr;
@@ -432,9 +437,10 @@ NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) {
 /// @brief Add MemcpyAsync Op as StreamMerge in_node
 /// @param [in] graph
 /// @param [in] node
+/// @param [in] multi_batch_flag
 /// @return Status
 ///
-Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node) {
+Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) {
   GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null.");
   for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) {
     OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor();
@@ -447,7 +453,7 @@ Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node)
                     continue);
     GE_IF_BOOL_EXEC(type != MEMCPYASYNC, {
-      in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor);
+      in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag);
       GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail.");
       GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail.");
       GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)),
@@ -103,13 +103,13 @@ class SwitchOpPass : public GraphPass {
   NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix,
                                  OutDataAnchorPtr &peer_cond_anchor);
-  NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor);
+  NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag);
   Status CombineSwitchNode(ComputeGraphPtr &graph);
   NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node);
-  Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node);
+  Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag);
   Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor);
@@ -22,11 +22,14 @@
 #include "common/ge/ge_util.h"
 #include "external/graph/graph.h"
 #include "framework/common/debug/ge_log.h"
+#include "graph/common/omg_util.h"
 #include "graph/debug/ge_attr_define.h"
 #include "graph/node.h"
 #include "graph/utils/tensor_utils.h"
 namespace ge {
+std::map<std::string, std::map<int, int>> VariablePrepareOpPass::ref_node_without_prototype_map_{
+  {REFSWITCH, {{0, 0}, {0, 1}}}};
 Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) {
   GE_CHECK_NOTNULL(graph);
   for (const auto &node : graph->GetDirectNode()) {
@@ -43,9 +46,7 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) {
   for (auto &node : graph->GetDirectNode()) {
     GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue);
-    bool is_variable = node->GetOpDesc()->GetType() == VARIABLE;
-    bool is_deal = has_dealed_variable_.find(node->GetName()) == has_dealed_variable_.end();
-    if (is_variable && is_deal) {
+    if (node->GetOpDesc()->GetType() == VARIABLE) {
       Status ret = DealVariableNode(node);
       if (ret != SUCCESS) {
         GELOGE(ret, "variable add back edge failed");
@@ -149,7 +150,7 @@ NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node,
     }
   }
   if (!found_writeable_node) {
-    GELOGI("final writable node is %s", current_node->GetName().c_str());
+    GELOGD("final writable node is %s", current_node->GetName().c_str());
     return current_node;
   }
 }
@@ -159,53 +160,54 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g
   GE_CHECK_NOTNULL(final_writable_node);
   GE_CHECK_NOTNULL(var_node);
-  NodePtr var_ref_node = CreatVariableRef(final_writable_node, var_node);
-  GE_CHECK_NOTNULL(var_ref_node);
-  // add control anchor between var_ref_node and final peer node
-  // var_ref_node need to execute before other nodes
+  if (final_writable_node->GetType() == FRAMEWORKOP) {
+    GELOGD("No need to add variable_ref for frameworkop");
+    return SUCCESS;
+  }
+  std::stringstream variable_ref_name;
+  variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index;
+  ge::NodePtr find_node = var_node->GetOwnerComputeGraph()->FindNode(var_node->GetName() + variable_ref_name.str());
+  if (find_node != nullptr) {
+    GELOGD("The corresponding variable_ref [%s] has been added to this connection.", find_node->GetName().c_str());
+    return SUCCESS;
+  }
+  NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), var_node);
+  GE_CHECK_NOTNULL(variable_ref_node);
+  GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str());
+  // add a control anchor between variable_ref_node and the final peer node:
+  // variable_ref_node needs to execute before other nodes
   auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors();
   for (auto &final_writable_outAnchor : final_writable_outAnchors) {
     GE_CHECK_NOTNULL(final_writable_outAnchor);
     for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) {
       GE_CHECK_NOTNULL(final_writable_peerAnchor);
       NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode();
-      graphStatus ret = ge::GraphUtils::AddEdge(var_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor());
+      graphStatus ret =
+        ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor());
       if (ret != GRAPH_SUCCESS) {
-        GELOGE(FAILED, "add control anchor between var_ref_node and final_writable peer_node failed");
+        GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed");
         return FAILED;
       }
     }
   }
+  // add edge: final_writable_node:index ---> variable_ref_node:0
   graphStatus ret =
-    ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), var_ref_node->GetInDataAnchor(0));
+    ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), variable_ref_node->GetInDataAnchor(0));
   if (ret != GRAPH_SUCCESS) {
-    GELOGE(FAILED, "add data anchor between var_ref_node and final_writable peer_node failed");
+    GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed");
     return FAILED;
   }
   return SUCCESS;
 }
-ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node) {
-  if ((final_writable_node == nullptr) || (var_node == nullptr) || (var_node->GetOwnerComputeGraph() == nullptr)) {
-    GELOGE(FAILED, "parameter ptr is null.");
-    return nullptr;
-  }
-  GELOGD("Create VarRef Op: final_writable_node: [%s] var_node: [%s]>>>>", final_writable_node->GetName().c_str(),
-         var_node->GetName().c_str());
-  static uint32_t var_ref_count = 0;
-  std::stringstream var_ref_name;
-  var_ref_name << "_to_" << final_writable_node->GetName() << "_REF_" << var_ref_count++;
+ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) {
   OpDescPtr var_op_desc = var_node->GetOpDesc();
   if (var_op_desc == nullptr) {
     GELOGE(FAILED, "get var opdesc is nullptr");
     return nullptr;
   }
-  OpDescPtr var_ref_op_desc =
-    MakeShared<OpDesc>(var_node->GetName() + var_ref_name.str().c_str(), var_op_desc->GetType());
+  OpDescPtr var_ref_op_desc = MakeShared<OpDesc>(variable_ref_name.c_str(), var_op_desc->GetType());
   if (var_ref_op_desc == nullptr) {
     GELOGE(FAILED, "var_ref opdesc is nullptr");
     return nullptr;
   }
@@ -217,15 +219,15 @@ ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_
   GE_IF_BOOL_EXEC(var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS,
                   GELOGW("add input desc edge failed");
                   return nullptr);
-  NodePtr var_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc);
-  GE_IF_BOOL_EXEC(var_ref_node == nullptr, GELOGW("var_ref_node is null"); return nullptr);
-  has_dealed_variable_.insert(var_node->GetName());
+  NodePtr variable_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc);
+  GE_IF_BOOL_EXEC(variable_ref_node == nullptr, GELOGW("variable_ref_node is null"); return nullptr);
   bool is_set_str = ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName());
   if (is_set_str) {
-    GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", var_ref_node->GetName().c_str(), var_op_desc->GetName().c_str());
+    GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", variable_ref_node->GetName().c_str(),
+           var_op_desc->GetName().c_str());
   }
-  return var_ref_node;
+  return variable_ref_node;
 }
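The reworked pair of functions replaces the old static `var_ref_count` with a name derived from the variable, the writable node, and the output index, and probes the graph via FindNode before creating anything; that appears to be what makes repeated runs of the pass idempotent. The look-up-before-create pattern in isolation, with a hypothetical registry standing in for `ComputeGraph`:

```cpp
#include <map>
#include <memory>
#include <sstream>
#include <string>

struct NodeStub { std::string name; };
using NodeStubPtr = std::shared_ptr<NodeStub>;

// Hypothetical registry standing in for ComputeGraph::FindNode/AddNode.
struct GraphStub {
  std::map<std::string, NodeStubPtr> nodes;
  NodeStubPtr Find(const std::string &name) const {
    auto it = nodes.find(name);
    return it == nodes.end() ? nullptr : it->second;
  }
  NodeStubPtr Add(const std::string &name) {
    return nodes[name] = std::make_shared<NodeStub>(NodeStub{name});
  }
};

// Deterministic name plus find-before-create keeps repeated runs idempotent.
NodeStubPtr GetOrCreateVariableRef(GraphStub &graph, const std::string &var_name,
                                   const std::string &writable_name, int index) {
  std::ostringstream ref_name;
  ref_name << var_name << "_TO_" << writable_name << "_REF_" << index;
  if (NodeStubPtr existing = graph.Find(ref_name.str())) {
    return existing;  // already created and wired on a previous run
  }
  return graph.Add(ref_name.str());
}
```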
 int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index) {
@@ -240,16 +242,13 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu
     }
   }
-  auto node_iter = ref_input_output_map_.find(node_type);
-  if (node_iter == ref_input_output_map_.end()) {
-    return -1;
-  }
-  auto index_iter = node_iter->second.find(input_index);
-  if (index_iter == node_iter->second.end()) {
-    return -1;
+  if (node_type == FRAMEWORKOP) {
+    std::string original_type;
+    GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail"));
+    GELOGI("find frameworkop: [%s], original type is %s", node->GetName().c_str(), original_type.c_str());
+    return FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_);
   }
-  return index_iter->second;
+  return FindRefOutIndex(node_type, input_index, ref_input_output_map_);
 }
 void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) {
@@ -301,4 +300,18 @@ Status VariablePrepareOpPass::UpdateAssignOpDesc(const ge::NodePtr &node) {
   }
   return SUCCESS;
 }
+
+int VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index,
+                                           const std::map<std::string, std::map<int, int>> &ref_map) {
+  auto node_iter = ref_map.find(node_type);
+  if (node_iter == ref_map.end()) {
+    return -1;
+  }
+  auto index_iter = node_iter->second.find(input_index);
+  if (index_iter == node_iter->second.end()) {
+    return -1;
+  }
+  return index_iter->second;
+}
 }  // namespace ge
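`FindRefOutIndex` is a plain two-level map probe, factored out so the same code serves both `ref_input_output_map_` and the hard-coded `ref_node_without_prototype_map_`. An equivalent free-standing version (hypothetical, not part of the patch):

```cpp
#include <map>
#include <string>

// Free-standing restatement of FindRefOutIndex: -1 when either level misses.
int LookupRefOutIndex(const std::map<std::string, std::map<int, int>> &ref_map,
                      const std::string &node_type, int input_index) {
  auto node_iter = ref_map.find(node_type);
  if (node_iter == ref_map.end()) {
    return -1;
  }
  auto index_iter = node_iter->second.find(input_index);
  return index_iter == node_iter->second.end() ? -1 : index_iter->second;
}
```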