@@ -176,7 +176,7 @@ cd ${BASEPATH} | |||||
mkdir -p output/plugin/nnengine/ge_config/ | mkdir -p output/plugin/nnengine/ge_config/ | ||||
find output/ -name graphengine_lib.tar -exec rm {} \; | find output/ -name graphengine_lib.tar -exec rm {} \; | ||||
cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/ | cp src/ge/engine_manager/engine_conf.json output/plugin/nnengine/ge_config/ | ||||
find output/ -maxdepth 1 -name libengine.so -exec mv {} output/plugin/nnengine/ \; | |||||
find output/ -maxdepth 1 -name libengine.so -exec mv -f {} output/plugin/nnengine/ \; | |||||
tar -cf graphengine_lib.tar output/* | tar -cf graphengine_lib.tar output/* | ||||
mv -f graphengine_lib.tar output | mv -f graphengine_lib.tar output | ||||
echo "---------------- GraphEngine package archive generated ----------------" | echo "---------------- GraphEngine package archive generated ----------------" |
@@ -0,0 +1,36 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef COMPRESS_H | |||||
#define COMPRESS_H | |||||
#include <uchar.h> | |||||
enum CmpStatus { RET_SUCCESS = 0, RET_ERROR = -1 }; | |||||
struct CompressConfig { | |||||
size_t inputSize; // length of data to compress | |||||
size_t engineNum; // how many decompress engines | |||||
size_t maxRatio; // how much size of a basic compression block, only 64 supported now (8x: 64 4x: 32) | |||||
size_t channel; // channels of L2 or DDR. For load balance | |||||
size_t fractalSize; // size of compressing block | |||||
bool isTight; // whether compose compressed data tightly | |||||
}; | |||||
CmpStatus CompressWeights(char* input, const CompressConfig& compressConfig, char* indexs, char* output, | |||||
size_t& compressedLength); | |||||
#endif // COMPRESS_H |
@@ -0,0 +1,97 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef PLATFORM_INFO_H | |||||
#define PLATFORM_INFO_H | |||||
#include <map> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "platform_info_def.h" | |||||
using std::map; | |||||
using std::string; | |||||
using std::vector; | |||||
namespace fe { | |||||
class PlatformInfoManager { | |||||
public: | |||||
PlatformInfoManager(const PlatformInfoManager &) = delete; | |||||
PlatformInfoManager &operator=(const PlatformInfoManager &) = delete; | |||||
static PlatformInfoManager &Instance(); | |||||
uint32_t InitializePlatformInfo(); | |||||
uint32_t Finalize(); | |||||
uint32_t GetPlatformInfo(const string SoCVersion, PlatformInfo &platformInfo, OptionalInfo &optiCompilationInfo); | |||||
void SetOptionalCompilationInfo(OptionalInfo &optiCompilationInfo); | |||||
private: | |||||
PlatformInfoManager(); | |||||
~PlatformInfoManager(); | |||||
uint32_t LoadIniFile(string iniFileRealPath); | |||||
void Trim(string &str); | |||||
uint32_t LoadConfigFile(string realPath); | |||||
string RealPath(const std::string &path); | |||||
string GetSoFilePath(); | |||||
void ParseVersion(map<string, string> &versionMap, string &socVersion, PlatformInfo &platformInfoTemp); | |||||
void ParseSocInfo(map<string, string> &socInfoMap, PlatformInfo &platformInfoTemp); | |||||
void ParseCubeOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
void ParseBufferOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
void ParseUBOfAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
void ParseAICoreSpec(map<string, string> &aiCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
void ParseBufferOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
void ParseAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
void ParseUBOfAICoreMemoryRates(map<string, string> &aiCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
void ParseAICoreintrinsicDtypeMap(map<string, string> &aiCoreintrinsicDtypeMap, PlatformInfo &platformInfoTemp); | |||||
void ParseVectorCoreSpec(map<string, string> &vectorCoreSpecMap, PlatformInfo &platformInfoTemp); | |||||
void ParseVectorCoreMemoryRates(map<string, string> &vectorCoreMemoryRatesMap, PlatformInfo &platformInfoTemp); | |||||
void ParseVectorCoreintrinsicDtypeMap(map<string, string> &vectorCoreintrinsicDtypeMap, | |||||
PlatformInfo &platformInfoTemp); | |||||
uint32_t ParsePlatformInfoFromStrToStruct(map<string, map<string, string>> &contentInfoMap, string &socVersion, | |||||
PlatformInfo &platformInfoTemp); | |||||
uint32_t AssemblePlatformInfoVector(map<string, map<string, string>> &contentInfoMap); | |||||
private: | |||||
bool initFlag_; | |||||
map<string, PlatformInfo> platformInfoMap_; | |||||
OptionalInfo optiCompilationInfo_; | |||||
}; | |||||
} // namespace fe | |||||
#endif |
@@ -0,0 +1,122 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef PLATFORM_INFO_DEF_H | |||||
#define PLATFORM_INFO_DEF_H | |||||
#include <map> | |||||
#include <string> | |||||
#include <vector> | |||||
using std::map; | |||||
using std::string; | |||||
using std::vector; | |||||
namespace fe { | |||||
enum MemoryType { DDR = 0, HBM }; | |||||
enum L2Type { Cache = 0, Buff }; | |||||
typedef struct tagStrInfo { | |||||
string aicVersion; | |||||
string ccecAICVersion; | |||||
string ccecAIVVersion; | |||||
string isSupportAIcpuCompiler; | |||||
} StrInfo; | |||||
typedef struct tagSoCInfo { | |||||
uint32_t aiCoreCnt; | |||||
uint32_t vectorCoreCnt; | |||||
uint32_t aiCpuCnt; | |||||
MemoryType memoryType; | |||||
uint64_t memorySize; | |||||
L2Type l2Type; | |||||
uint64_t l2Size; | |||||
uint32_t l2PageNum; | |||||
} SoCInfo; | |||||
typedef struct tagAiCoreSpec { | |||||
double cubeFreq; | |||||
uint64_t cubeMSize; | |||||
uint64_t cubeNSize; | |||||
uint64_t cubeKSize; | |||||
uint64_t vecCalcSize; | |||||
uint64_t l0ASize; | |||||
uint64_t l0BSize; | |||||
uint64_t l0CSize; | |||||
uint64_t l1Size; | |||||
uint64_t smaskBuffer; | |||||
uint64_t ubSize; | |||||
uint64_t ubblockSize; | |||||
uint64_t ubbankSize; | |||||
uint64_t ubbankNum; | |||||
uint64_t ubburstInOneBlock; | |||||
uint64_t ubbankGroupNum; | |||||
} AiCoreSpec; | |||||
typedef struct tagAiCoreMemoryRates { | |||||
double ddrRate; | |||||
double l2Rate; | |||||
double l2ReadRate; | |||||
double l2WriteRate; | |||||
double l1ToL0ARate; | |||||
double l1ToL0BRate; | |||||
double l1ToUBRate; | |||||
double l0CToUBRate; | |||||
double ubToL2Rate; | |||||
double ubToDdrRate; | |||||
double ubToL1Rate; | |||||
} AiCoreMemoryRates; | |||||
typedef struct tagVectorCoreSpec { | |||||
uint64_t vecCalcSize; | |||||
uint64_t smaskBuffer; | |||||
uint64_t ubSize; | |||||
uint64_t ubblockSize; | |||||
uint64_t ubbankSize; | |||||
uint64_t ubbankNum; | |||||
uint64_t ubburstInOneBlock; | |||||
uint64_t ubbankGroupNum; | |||||
} VectorCoreSpec; | |||||
typedef struct tagVectorCoreMemoryRates { | |||||
double ddrRate; | |||||
double l2Rate; | |||||
double l2ReadRate; | |||||
double l2WriteRate; | |||||
double ubToL2Rate; | |||||
double ubToDdrRate; | |||||
} VectorCoreMemoryRates; | |||||
typedef struct tagPlatformInfo { | |||||
StrInfo strInfo; | |||||
SoCInfo socInfo; | |||||
AiCoreSpec aiCoreSpec; | |||||
AiCoreMemoryRates aiCoreMemoryRates; | |||||
map<string, vector<string>> aiCoreIntrinsicDtypeMap; | |||||
VectorCoreSpec vectorCoreSpec; | |||||
VectorCoreMemoryRates vectorCoreMemoryRates; | |||||
map<string, vector<string>> vectorCoreIntrinsicDtypeMap; | |||||
} PlatformInfo; | |||||
typedef struct tagOptionalInfo { | |||||
string socVersion; | |||||
string coreType; | |||||
uint32_t aiCoreNum; | |||||
string l1FusionFlag; | |||||
} OptionalInfo; | |||||
} // namespace fe | |||||
#endif |
@@ -40,6 +40,8 @@ const char *const OPTION_EXEC_EXTERN_PLUGIN_PATH = "ge.soLoadPath"; | |||||
const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | const char *const OPTION_EXEC_ENABLE_DUMP = "ge.exec.enableDump"; | ||||
const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | const char *const OPTION_EXEC_DUMP_PATH = "ge.exec.dumpPath"; | ||||
const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | const char *const OPTION_EXEC_DUMP_STEP = "ge.exec.dumpStep"; | ||||
const char *const OPTION_EXEC_ENABLE_INCRE_BUILD = "ge.exec.enableIncreBuild"; | |||||
const char *const OPTION_EXEC_INCRE_BUILD_CACHE_PATH = "ge.exec.increBuildCachePath"; | |||||
// Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | // Hccl flag, if ge.exec.hcclFlag =1, it means load plugin for opskernel, else:ge.exec.hcclFlag =0 | ||||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | ||||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | ||||
@@ -116,27 +116,5 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY OpReceiver { | |||||
namespace ge { | namespace ge { | ||||
using OpRegistrationData = domi::OpRegistrationData; | using OpRegistrationData = domi::OpRegistrationData; | ||||
using OpReceiver = domi::OpReceiver; | using OpReceiver = domi::OpReceiver; | ||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOp { | |||||
public: | |||||
HostCpuOp() = default; | |||||
virtual ~HostCpuOp() = default; | |||||
virtual graphStatus Compute(Operator &op, const std::map<std::string, const Tensor> &inputs, | |||||
std::map<std::string, Tensor> &outputs) = 0; | |||||
}; | |||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY HostCpuOpRegistrar { | |||||
public: | |||||
HostCpuOpRegistrar(const char *op_type, HostCpuOp *(*create_fn)()); | |||||
}; | |||||
#define REGISTER_HOST_CPU_OP_BUILDER(name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(__COUNTER__, name, op) | |||||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ_HELPER(ctr, name, op) REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) | |||||
#define REGISTER_HOST_CPU_OP_BUILDER_UNIQ(ctr, name, op) \ | |||||
static ::ge::HostCpuOpRegistrar register_host_cpu_op##ctr __attribute__((unused)) = \ | |||||
::ge::HostCpuOpRegistrar(name, []() -> ::ge::HostCpuOp * { return new (std::nothrow) op(); }) | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_EXTERNAL_REGISTER_REGISTER_H_ | #endif // INC_EXTERNAL_REGISTER_REGISTER_H_ |
@@ -434,6 +434,7 @@ REGISTER_OPTYPE_DECLARE(STREAMSWITCH, "StreamSwitch"); | |||||
REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | REGISTER_OPTYPE_DECLARE(STREAMSWITCHN, "StreamSwitchN"); | ||||
REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | REGISTER_OPTYPE_DECLARE(STREAMACTIVE, "StreamActive"); | ||||
REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | REGISTER_OPTYPE_DECLARE(MEMCPYASYNC, "MemcpyAsync"); | ||||
REGISTER_OPTYPE_DECLARE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||||
REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
@@ -441,6 +442,7 @@ REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTOEX, "LabelGotoEx"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | REGISTER_OPTYPE_DECLARE(LABELSWITCH, "LabelSwitch"); | ||||
REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | REGISTER_OPTYPE_DECLARE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | ||||
@@ -979,9 +979,14 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_N_BATCH_SPILT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NO_TASK_AND_DUMP_NEEDED; | ||||
// functional ops attr | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | |||||
// used for label switch | // used for label switch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | ||||
// Varible | // Varible | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | ||||
@@ -22,7 +22,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/anchor.h" | #include "graph/anchor.h" | ||||
#include "detail/attributes_holder.h" | |||||
#include "graph/detail/attributes_holder.h" | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -262,6 +262,8 @@ class GraphUtils { | |||||
static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | static graphStatus MoveOutCtrlEdges(NodePtr &src_node, NodePtr &dst_node); | ||||
static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | static ComputeGraphPtr FindRootGraph(ComputeGraphPtr graph); | ||||
static graphStatus TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec); | |||||
}; | }; | ||||
class ComputeGraphBuilder { | class ComputeGraphBuilder { | ||||
@@ -54,17 +54,34 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY size_t ComputeGraph::GetAllNodesS | |||||
return s; | return s; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::Vistor<NodePtr> ComputeGraph::GetAllNodes() const { | ||||
vector<NodePtr> all_nodes(nodes_.size()); | |||||
(void)std::copy(nodes_.begin(), nodes_.end(), all_nodes.begin()); | |||||
for (const auto &sub_graph : sub_graph_) { | |||||
if (sub_graph == nullptr) { | |||||
GELOGW("sub graph is nullptr"); | |||||
if (sub_graph_.empty()) { | |||||
return Vistor<NodePtr>(shared_from_this(), nodes_); | |||||
} | |||||
std::vector<NodePtr> all_nodes; | |||||
std::deque<NodePtr> candidates; | |||||
candidates.insert(candidates.begin(), nodes_.begin(), nodes_.end()); | |||||
while (!candidates.empty()) { | |||||
NodePtr node = candidates.front(); | |||||
all_nodes.emplace_back(node); | |||||
candidates.pop_front(); | |||||
OpDescPtr op_desc = node->GetOpDesc(); | |||||
if (op_desc == nullptr) { | |||||
continue; | continue; | ||||
} | } | ||||
for (const auto &node : sub_graph->GetAllNodes()) { | |||||
all_nodes.push_back(node); | |||||
const auto &subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) { | |||||
auto subgraph = GetSubgraph(*name_iter); | |||||
if (subgraph != nullptr) { | |||||
candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end()); | |||||
} | |||||
} | } | ||||
} | } | ||||
return Vistor<NodePtr>(shared_from_this(), all_nodes); | return Vistor<NodePtr>(shared_from_this(), all_nodes); | ||||
} | } | ||||
size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | size_t ComputeGraph::GetDirectNodesSize() const { return nodes_.size(); } | ||||
@@ -602,7 +619,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::InsertE | |||||
graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | ||||
std::map<NodePtr, uint32_t> &map_in_edge_num, | std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::vector<NodePtr> &stack) { | std::vector<NodePtr> &stack) { | ||||
GELOGI("Runing_Dfs_Sort"); | |||||
GELOGI("Runing_Dfs_Sort: %s", name_.c_str()); | |||||
// Record the number of non data nodes but no input nodes | // Record the number of non data nodes but no input nodes | ||||
GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | GE_CHK_BOOL_EXEC(SortNodes(stack, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed"); | ||||
@@ -647,7 +664,7 @@ graphStatus ComputeGraph::DFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | ||||
std::map<NodePtr, uint32_t> &map_in_edge_num, | std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::deque<NodePtr> &stack) { | std::deque<NodePtr> &stack) { | ||||
GELOGI("Runing_Bfs_Sort"); | |||||
GELOGI("Runing_Bfs_Sort: %s", name_.c_str()); | |||||
std::vector<NodePtr> stack_input; | std::vector<NodePtr> stack_input; | ||||
std::map<string, NodePtr> breadth_node_map; | std::map<string, NodePtr> breadth_node_map; | ||||
// Record the number of non data nodes but no input nodes | // Record the number of non data nodes but no input nodes | ||||
@@ -735,7 +752,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||||
use_BFS = true; | use_BFS = true; | ||||
} | } | ||||
} else { | } else { | ||||
GELOGW("Get OPTION_GRAPH_RUN_MODE failed, use BFSTopologicalSorting by default."); | |||||
GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||||
} | } | ||||
if (use_BFS) { | if (use_BFS) { | ||||
@@ -955,11 +955,8 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | |||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | ||||
// functional ops attr | // functional ops attr | ||||
const std::string ATTR_NAME_TCOND = "Tcond"; | |||||
const std::string ATTR_NAME_TIN = "Tin"; | |||||
const std::string ATTR_NAME_TOUT = "Tout"; | |||||
const std::string ATTR_NAME_THEN_BRANCH = "then_branch"; | |||||
const std::string ATTR_NAME_ELSE_BRANCH = "else_branch"; | |||||
const std::string ATTR_NAME_WHILE_COND = "cond"; | |||||
const std::string ATTR_NAME_WHILE_BODY = "body"; | |||||
// used for label switch | // used for label switch | ||||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | ||||
@@ -28,6 +28,7 @@ | |||||
#include <cstring> | #include <cstring> | ||||
#include <fstream> | #include <fstream> | ||||
#include <iomanip> | #include <iomanip> | ||||
#include <queue> | |||||
#include "./ge_context.h" | #include "./ge_context.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
@@ -1999,4 +2000,60 @@ void PartialGraphBuilder::BuildExistNodes(graphStatus &error_code, std::string & | |||||
GELOGD("Build exist nodes succ."); | GELOGD("Build exist nodes succ."); | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | |||||
GraphUtils::TopologicalSortingByName(const ge::ComputeGraphPtr &compute_graph, vector<NodePtr> &node_vec) { | |||||
std::vector<NodePtr> stack_input; | |||||
std::map<NodePtr, uint32_t> map_in_edge_num; | |||||
graphStatus ret = compute_graph->SortNodes(stack_input, map_in_edge_num); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Sort nodes failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
const size_t non_user_input_index = stack_input.size() - compute_graph->inputs_order_.size() - 1; | |||||
std::sort(stack_input.begin(), stack_input.begin() + non_user_input_index, | |||||
[](const NodePtr &a, const NodePtr &b) -> bool { return (a->GetName() > b->GetName()); }); | |||||
std::queue<NodePtr> stack; | |||||
NodePtr cur_node = nullptr; | |||||
std::map<string, NodePtr> name_node_map; | |||||
vector<string> nodes_name; | |||||
while (!stack_input.empty() || !stack.empty()) { | |||||
if (!stack.empty()) { | |||||
cur_node = stack.front(); | |||||
stack.pop(); | |||||
} else { | |||||
cur_node = stack_input.back(); | |||||
stack_input.pop_back(); | |||||
} | |||||
node_vec.emplace_back(cur_node); | |||||
compute_graph->CollectBreadthOutNode(cur_node, map_in_edge_num, name_node_map); | |||||
for (const auto &iter : name_node_map) { | |||||
nodes_name.emplace_back(iter.first); | |||||
} | |||||
std::sort(nodes_name.begin(), nodes_name.end()); | |||||
for (const auto &iter : nodes_name) { | |||||
stack.push(name_node_map[iter]); | |||||
} | |||||
name_node_map.clear(); | |||||
nodes_name.clear(); | |||||
} | |||||
// If they are not equal, there is a closed loop | |||||
if (node_vec.size() != compute_graph->nodes_.size()) { | |||||
std::set<Node *> itered_nodes_set; | |||||
for (auto &node : node_vec) { | |||||
itered_nodes_set.insert(node.get()); | |||||
} | |||||
GE_LOGE("Failed to do topo sorting total %zu, itered %zu, exist closed loop in graph.", | |||||
compute_graph->nodes_.size(), node_vec.size()); | |||||
for (auto &node : compute_graph->nodes_) { | |||||
if (itered_nodes_set.count(node.get()) == 0) { | |||||
GE_LOGE("The node %s does not itered when topological sorting", node->GetName().c_str()); | |||||
} | |||||
} | |||||
return GRAPH_FAILED; | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -41,6 +41,7 @@ include_directories(${GE_SOURCE_DIR}/inc/external/graph) | |||||
include_directories(${GE_SOURCE_DIR}/inc/framework) | include_directories(${GE_SOURCE_DIR}/inc/framework) | ||||
include_directories(${GE_SOURCE_DIR}/inc/framework/common) | include_directories(${GE_SOURCE_DIR}/inc/framework/common) | ||||
include_directories(${GE_SOURCE_DIR}/inc/runtime) | include_directories(${GE_SOURCE_DIR}/inc/runtime) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib) | |||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_SOURCE_DIR}/third_party/fwkacllib/inc/cce) | ||||
include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | include_directories(${GE_SOURCE_DIR}/third_party/securec/include) | ||||
@@ -55,6 +56,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
"common/fp16_t.cc" | "common/fp16_t.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
@@ -92,6 +94,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | "graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -269,6 +272,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/formats/utils/formats_trans_utils.cc" | "common/formats/utils/formats_trans_utils.cc" | ||||
"common/fp16_t.cc" | "common/fp16_t.cc" | ||||
"common/ge/plugin_manager.cc" | "common/ge/plugin_manager.cc" | ||||
"common/helper/model_cache_helper.cc" | |||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
@@ -305,6 +309,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/load/new_model_manager/task_info/kernel_task_info.cc" | "graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/label_set_task_info.cc" | "graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -470,7 +475,7 @@ target_link_libraries(ge_compiler | |||||
${slog} | ${slog} | ||||
${mmpa} | ${mmpa} | ||||
${msprof} | ${msprof} | ||||
${runtime} | |||||
${runtime_compiler} | |||||
${resouce} | ${resouce} | ||||
rt | rt | ||||
dl) | dl) |
@@ -134,10 +134,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
} | } | ||||
auto trans_mode = iter->second; | auto trans_mode = iter->second; | ||||
if (args.src_data_size == 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid src data size %zu", args.src_data_size); | |||||
return PARAM_INVALID; | |||||
} | |||||
int size = GetSizeByDataType(args.dst_data_type); | int size = GetSizeByDataType(args.dst_data_type); | ||||
if (size <= 0) { | if (size <= 0) { | ||||
GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | GELOGE(PARAM_INVALID, "Failed to calc size from data type %s", | ||||
@@ -149,6 +145,12 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
size_t total_size = static_cast<size_t>(args.src_data_size * size); | size_t total_size = static_cast<size_t>(args.src_data_size * size); | ||||
result.length = total_size; | |||||
if (total_size == 0) { | |||||
GELOGI("In TransDataType, total_size is zero, has no data."); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | GELOGE(OUT_OF_MEMORY, "Failed to alloc the memory for dst buf %zu, data size %zu", total_size, args.src_data_size); | ||||
@@ -162,7 +164,6 @@ Status DataTypeTransfer::TransDataType(const CastArgs &args, TransResult &result | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
result.data = dst; | result.data = dst; | ||||
result.length = total_size; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -134,6 +134,11 @@ Status FormatTransferC1hwncoc0Hwcn::TransFormat(const TransArgs &args, TransResu | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | int64_t total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -88,6 +88,11 @@ Status TransFormatDhwckToFz3D(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -89,6 +89,11 @@ Status TransFormatDhwncToFz3DTranspose(const TransArgs &args, TransResult &resul | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -116,6 +116,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -184,6 +189,11 @@ Status TransFormatFromNdToFracNz(const TransArgs &args, TransResult &result, con | |||||
Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | Status TransFormatFromFracNzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -119,6 +119,11 @@ Status TransFormatFromNchwToFz(const TransArgs &args, TransResult &result) { | |||||
int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | int64_t total_ele_cnt = hf_cnt * vf_cnt * fractal_ele_cnt; | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -194,6 +199,11 @@ Status TransFormatHwcnToFz(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -259,6 +269,11 @@ Status TransFormatNhwcToFz(const TransArgs &args, TransResult &result) { | |||||
dst_size *= dim; | dst_size *= dim; | ||||
} | } | ||||
dst_size *= data_size; | dst_size *= data_size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -117,6 +117,11 @@ Status CheckShapeRelation(const TransArgs &args, ShapeVector &hw_shape) { | |||||
Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, const ShapeVector &hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -153,8 +158,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
auto src_offset = (src_h_head + w1_idx * w0) * size; | auto src_offset = (src_h_head + w1_idx * w0) * size; | ||||
auto dst_offset = (h0_head + w1_idx * h0w0) * size; | auto dst_offset = (h0_head + w1_idx * h0w0) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -169,8 +174,8 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
auto src_offset = (src_h_head + src_w_idx) * size; | auto src_offset = (src_h_head + src_w_idx) * size; | ||||
auto dst_offset = (w0_head + w0_idx) * size; | auto dst_offset = (w0_head + w0_idx) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -189,6 +194,11 @@ Status TransFormatFromNdToFracZz(const TransArgs &args, TransResult &result, con | |||||
Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, const ShapeVector &dst_hw_shape) { | ||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | int64_t dst_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (dst_size == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size](), std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -226,8 +236,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
auto src_offset = (h0_head + w1_idx * h0w0) * size; | auto src_offset = (h0_head + w1_idx * h0w0) * size; | ||||
auto dst_offset = (dst_h_head + w1_idx * w0) * size; | auto dst_offset = (dst_h_head + w1_idx * w0) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size * w0)); | static_cast<size_t>(size * w0)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -242,8 +252,8 @@ Status TransFormatFromFracZzToNd(const TransArgs &args, TransResult &result, con | |||||
auto dst_w_idx = w1_head + w0_idx; | auto dst_w_idx = w1_head + w0_idx; | ||||
auto dst_offset = (dst_h_head + dst_w_idx) * size; | auto dst_offset = (dst_h_head + dst_w_idx) * size; | ||||
auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -133,6 +133,12 @@ Status FormatTransferFracZHwcn::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -133,6 +133,12 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -140,6 +146,7 @@ Status FormatTransferFracZNchw::TransFormat(const TransArgs &args, TransResult & | |||||
GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGD("Begin to trans format from FracZ to NCHW, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
ShapeToString(args.dst_shape).c_str(), total_size); | ShapeToString(args.dst_shape).c_str(), total_size); | ||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | ||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | ||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | ||||
@@ -132,6 +132,12 @@ Status FormatTransferFracZNhwc::TransFormat(const TransArgs &args, TransResult & | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -35,7 +35,7 @@ Status TransShapeHwcnToC1hwncoc0(const DataType &data_type, const std::vector<in | |||||
std::vector<int64_t> &dst_shape) { | std::vector<int64_t> &dst_shape) { | ||||
auto cube_size = GetCubeSizeByDataType(data_type); | auto cube_size = GetCubeSizeByDataType(data_type); | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back((src_shape.at(kHwcnC) - 1) / cube_size + 1); | |||||
dst_shape.push_back(Ceil(src_shape.at(kHwcnC), static_cast<int64_t>(cube_size))); | |||||
dst_shape.push_back(src_shape.at(kHwcnH)); | dst_shape.push_back(src_shape.at(kHwcnH)); | ||||
dst_shape.push_back(src_shape.at(kHwcnW)); | dst_shape.push_back(src_shape.at(kHwcnW)); | ||||
dst_shape.push_back(src_shape.at(kHwcnN)); | dst_shape.push_back(src_shape.at(kHwcnN)); | ||||
@@ -169,6 +169,12 @@ Status FormatTransferHwcnC1hwncoc0::TransFormat(const TransArgs &args, TransResu | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNchw(const TransArgs &args) { | |||||
} | } | ||||
if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) || | if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNchwH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNchwW) || | ||||
src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || | src_shape.at(kNc1hwc0N) != dst_shape.at(kNchwN) || src_shape.at(kNc1hwc0C0) != c0 || | ||||
src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNchwC) - 1) / c0 + 1) { | |||||
src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNchwC), c0))) { | |||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -102,8 +102,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
auto dst_offset = dst_idx * size; | auto dst_offset = dst_idx * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -130,6 +130,12 @@ Status FormatTransferNc1hwc0Nchw::TransFormat(const TransArgs &args, TransResult | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -58,7 +58,7 @@ Status CheckArgsForNc1hwc0ToNhwc(const TransArgs &args) { | |||||
} | } | ||||
if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | if (src_shape.at(kNc1hwc0H) != dst_shape.at(kNhwcH) || src_shape.at(kNc1hwc0W) != dst_shape.at(kNhwcW) || | ||||
src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | src_shape.at(kNc1hwc0N) != dst_shape.at(kNhwcN) || src_shape.at(kNc1hwc0C0) != c0 || | ||||
src_shape.at(kNc1hwc0C1) != (dst_shape.at(kNhwcC) - 1) / c0 + 1) { | |||||
src_shape.at(kNc1hwc0C1) != (Ceil(dst_shape.at(kNhwcC), c0))) { | |||||
GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | GELOGE(PARAM_INVALID, "Failed to check relationship between src and dst shape, src shape %s, dst shape %s", | ||||
ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ShapeToString(src_shape).c_str(), ShapeToString(dst_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -102,8 +102,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
auto dst_offset = dst_idx * size; | auto dst_offset = dst_idx * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), args.data + src_offset, | ||||
static_cast<size_t>(size)); | static_cast<size_t>(size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -130,6 +130,12 @@ Status FormatTransferNc1hwc0Nhwc::TransFormat(const TransArgs &args, TransResult | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -134,6 +134,10 @@ Status TransFormatFromNchwToFzC04(const TransArgs &args, TransResult &result) { | |||||
GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); | GELOGE(INTERNAL_ERROR, "int64 mul overflow.A[%lld], B[%lld]", total_ele_cnt, size); | ||||
return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | |||||
result.length = 0; | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
@@ -219,6 +223,10 @@ Status PaddingNC(const TransArgs &args, TransArgs &args_tmp, std::shared_ptr<uin | |||||
return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
int64_t dst_size = total_ele_cnt * size; | int64_t dst_size = total_ele_cnt * size; | ||||
if (dst_size == 0) { | |||||
return SUCCESS; | |||||
} | |||||
dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | dst.reset(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | GELOGE(OUT_OF_MEMORY, "Failed to trans format from %s to %s, can not alloc the memory for dst buf %ld", | ||||
@@ -40,7 +40,7 @@ Status TransShapeNchwToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
} | } | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(src_shape.at(kNchwN)); | dst_shape.push_back(src_shape.at(kNchwN)); | ||||
dst_shape.push_back((src_shape.at(kNchwC) - 1) / c0 + 1); | |||||
dst_shape.push_back(Ceil(src_shape.at(kNchwC), c0)); | |||||
dst_shape.push_back(src_shape.at(kNchwH)); | dst_shape.push_back(src_shape.at(kNchwH)); | ||||
dst_shape.push_back(src_shape.at(kNchwW)); | dst_shape.push_back(src_shape.at(kNchwW)); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
@@ -74,25 +74,8 @@ Status CheckArgsForNchwToNc1hwc0(const TransArgs &args) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | |||||
Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
} | |||||
// Guarantee the validity of parameters in check function | |||||
int size = GetSizeByDataType(args.src_data_type); | |||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
if (total_size <= 0) { | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGD( | |||||
"Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
"%s, dst shape %s memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), total_size); | |||||
Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const int size, const int64_t total_size) { | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[total_size], std::default_delete<uint8_t[]>()); | ||||
if (dst == nullptr) { | if (dst == nullptr) { | ||||
GELOGE(OUT_OF_MEMORY, | GELOGE(OUT_OF_MEMORY, | ||||
@@ -132,8 +115,8 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
int64_t dst_index = c0_idx + w_head_addr; | int64_t dst_index = c0_idx + w_head_addr; | ||||
int64_t dst_offset = dst_index * size; | int64_t dst_offset = dst_index * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
int64_t cIdx = c0_idx + c1_idx * c0; | int64_t cIdx = c0_idx + c1_idx * c0; | ||||
int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; | int64_t srcIdx = n_idx * chw + cIdx * hw + h_idx * w + w_idx; | ||||
auto src_offset = srcIdx * size; | auto src_offset = srcIdx * size; | ||||
@@ -150,7 +133,7 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
} | } | ||||
} else { | } else { | ||||
auto ret = | auto ret = | ||||
memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
memset_s(dst.get() + dst_offset, static_cast<size_t>(protected_size), 0, static_cast<size_t>(size)); | |||||
if (ret != EOK) { | if (ret != EOK) { | ||||
GELOGE(INTERNAL_ERROR, | GELOGE(INTERNAL_ERROR, | ||||
"Failed to set to 0 to " | "Failed to set to 0 to " | ||||
@@ -169,6 +152,39 @@ Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
result.length = static_cast<size_t>(total_size); | result.length = static_cast<size_t>(total_size); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace | |||||
Status FormatTransferNchwNc1hwc0::TransFormat(const TransArgs &args, TransResult &result) { | |||||
if (CheckArgsForNchwToNc1hwc0(args) != SUCCESS) { | |||||
return PARAM_INVALID; | |||||
} | |||||
// Guarantee the validity of parameters in check function | |||||
int size = GetSizeByDataType(args.src_data_type); | |||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | |||||
if (total_size <= 0) { | |||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | |||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGD( | |||||
"Begin to trans format from NCHW to NC1HWC0, src shape %s, data type " | |||||
"%s, dst shape %s memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), total_size); | |||||
if (GetDstDataAfterTrans(args, result, size, total_size) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get data after trans, src shape %s, data type %s, dst shape %s, memory size %ld", | |||||
ShapeToString(args.src_shape).c_str(), TypeUtils::DataTypeToSerialString(args.src_data_type).c_str(), | |||||
ShapeToString(args.dst_shape).c_str(), total_size); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | Status FormatTransferNchwNc1hwc0::TransShape(Format src_format, const std::vector<int64_t> &src_shape, | ||||
DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | DataType data_type, Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
@@ -38,7 +38,7 @@ Status TransShapeNhwcToNc1hwc0(const std::vector<int64_t> &src_shape, DataType d | |||||
} | } | ||||
dst_shape.clear(); | dst_shape.clear(); | ||||
dst_shape.push_back(src_shape.at(kNhwcN)); | dst_shape.push_back(src_shape.at(kNhwcN)); | ||||
dst_shape.push_back((src_shape.at(kNhwcC) - 1) / c0 + 1); | |||||
dst_shape.push_back(Ceil(src_shape.at(kNhwcC), c0)); | |||||
dst_shape.push_back(src_shape.at(kNhwcH)); | dst_shape.push_back(src_shape.at(kNhwcH)); | ||||
dst_shape.push_back(src_shape.at(kNhwcW)); | dst_shape.push_back(src_shape.at(kNhwcW)); | ||||
dst_shape.push_back(c0); | dst_shape.push_back(c0); | ||||
@@ -119,8 +119,8 @@ Status GetDstDataAfterTrans(const TransArgs &args, TransResult &result, const in | |||||
int64_t dst_idx = c0_idx + w_head_addr; | int64_t dst_idx = c0_idx + w_head_addr; | ||||
int64_t dst_offset = dst_idx * size; | int64_t dst_offset = dst_idx * size; | ||||
auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = total_size - dst_offset < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? total_size - dst_offset | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
int64_t c_idx = c0_idx + c1_idx * c0; | int64_t c_idx = c0_idx + c1_idx * c0; | ||||
int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; | int64_t src_idx = n_idx * hwc + h_idx * wc + w_idx * c + c_idx; | ||||
auto src_offset = src_idx * size; | auto src_offset = src_idx * size; | ||||
@@ -161,6 +161,12 @@ Status FormatTransferNhwcNc1hwc0::TransFormat(const TransArgs &args, TransResult | |||||
int size = GetSizeByDataType(args.src_data_type); | int size = GetSizeByDataType(args.src_data_type); | ||||
auto total_size = GetItemNumByShape(args.dst_shape) * size; | auto total_size = GetItemNumByShape(args.dst_shape) * size; | ||||
if (total_size <= 0) { | if (total_size <= 0) { | ||||
int64_t src_size = GetItemNumByShape(args.src_shape); | |||||
if (total_size == 0 && src_size == 0) { | |||||
result.length = static_cast<size_t>(total_size); | |||||
return SUCCESS; | |||||
} | |||||
GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | GELOGE(INTERNAL_ERROR, "Get %ld total size from dst shape %s, src shape %s", total_size, | ||||
ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ShapeToString(args.dst_shape).c_str(), ShapeToString(args.src_shape).c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
@@ -27,22 +27,22 @@ namespace ge { | |||||
namespace formats { | namespace formats { | ||||
namespace { | namespace { | ||||
std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | std::map<Format, std::map<Format, std::vector<int64_t>>> perm_args{ | ||||
{FORMAT_NCHW, | |||||
{{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})}, | |||||
{FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})}, | |||||
{FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
{FORMAT_NHWC, | |||||
{{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})}, | |||||
{FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
{FORMAT_HWCN, | |||||
{{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})}, | |||||
{FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
{FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}}, | |||||
{FORMAT_CHWN, | |||||
{{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
{FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}}, | |||||
{FORMAT_NCHW, | |||||
{{FORMAT_NHWC, std::vector<int64_t>({0, 2, 3, 1})}, | |||||
{FORMAT_HWCN, std::vector<int64_t>({2, 3, 1, 0})}, | |||||
{FORMAT_CHWN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
{FORMAT_NHWC, | |||||
{{FORMAT_NCHW, std::vector<int64_t>({0, 3, 1, 2})}, | |||||
{FORMAT_CHWN, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 3, 0})}}}, | |||||
{FORMAT_HWCN, | |||||
{{FORMAT_NCHW, std::vector<int64_t>({3, 2, 0, 1})}, | |||||
{FORMAT_NHWC, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
{FORMAT_CHWN, std::vector<int64_t>({2, 0, 1, 3})}}}, | |||||
{FORMAT_CHWN, | |||||
{{FORMAT_NCHW, std::vector<int64_t>({3, 0, 1, 2})}, | |||||
{FORMAT_NHWC, std::vector<int64_t>({3, 1, 2, 0})}, | |||||
{FORMAT_HWCN, std::vector<int64_t>({1, 2, 0, 3})}}}, | |||||
}; | }; | ||||
bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<int64_t> &perm_arg) { | ||||
@@ -51,8 +51,8 @@ bool IsShapeArgValid(const std::vector<int64_t> &src_shape, const std::vector<in | |||||
return false; | return false; | ||||
} | } | ||||
for (auto dim : src_shape) { | for (auto dim : src_shape) { | ||||
if (dim <= 0) { | |||||
GELOGE(PARAM_INVALID, "Failed to transpose, zero dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
if (dim < 0) { | |||||
GELOGE(PARAM_INVALID, "Failed to transpose, negative dim in src shape %s", ShapeToString(src_shape).c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -146,20 +146,24 @@ Status Transpose(const uint8_t *src, const std::vector<int64_t> &src_shape, Data | |||||
int64_t dst_ele_num = GetItemNumByShape(dst_shape); | int64_t dst_ele_num = GetItemNumByShape(dst_shape); | ||||
int64_t data_size = GetSizeByDataType(src_data_type); | int64_t data_size = GetSizeByDataType(src_data_type); | ||||
int64_t dst_size = data_size * dst_ele_num; | int64_t dst_size = data_size * dst_ele_num; | ||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | GELOGD("Begin to transpose, src shape %s, perm arg %s, dst shape %s, data type %s", JoinToString(src_shape).c_str(), | ||||
JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | JoinToString(perm_arg).c_str(), JoinToString(dst_shape).c_str(), | ||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | ||||
if (dst_ele_num == 0) { | |||||
result.length = static_cast<size_t>(dst_size); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<uint8_t> dst(new (std::nothrow) uint8_t[dst_size], std::default_delete<uint8_t[]>()); | |||||
int64_t dst_index = 0; | int64_t dst_index = 0; | ||||
std::vector<int64_t> dst_indexes(dst_shape.size()); | std::vector<int64_t> dst_indexes(dst_shape.size()); | ||||
while (dst_index < dst_ele_num) { | while (dst_index < dst_ele_num) { | ||||
auto src_offset = GenOffset(src_heads, dst_indexes) * data_size; | auto src_offset = GenOffset(src_heads, dst_indexes) * data_size; | ||||
auto dst_offset_bytes = dst_index * data_size; | auto dst_offset_bytes = dst_index * data_size; | ||||
auto protected_size = dst_size - dst_offset_bytes < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | auto protected_size = dst_size - dst_offset_bytes < static_cast<int64_t>(SECUREC_MEM_MAX_LEN) | ||||
? dst_size - dst_offset_bytes | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
? dst_size - dst_offset_bytes | |||||
: static_cast<int64_t>(SECUREC_MEM_MAX_LEN); | |||||
auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | auto ret = memcpy_s(dst.get() + dst_offset_bytes, static_cast<size_t>(protected_size), src + src_offset, | ||||
static_cast<size_t>(data_size)); | static_cast<size_t>(data_size)); | ||||
if (ret != EOK) { | if (ret != EOK) { | ||||
@@ -24,6 +24,7 @@ | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include "common/formats/utils/formats_trans_utils.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
@@ -38,10 +39,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransFormat(const TransArg | |||||
TypeUtils::FormatToSerialString(args.dst_format).c_str()); | TypeUtils::FormatToSerialString(args.dst_format).c_str()); | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (args.data == nullptr) { | |||||
auto src_shape_size = GetItemNumByShape(args.src_shape); | |||||
if (args.data == nullptr && src_shape_size != 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid input null data"); | GELOGE(PARAM_INVALID, "Invalid input null data"); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
return transfer->TransFormat(args, result); | return transfer->TransFormat(args, result); | ||||
} | } | ||||
@@ -71,6 +75,12 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Status TransDataType(const CastAr | |||||
TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | TypeUtils::DataTypeToSerialString(args.dst_data_type).c_str()); | ||||
return UNSUPPORTED; | return UNSUPPORTED; | ||||
} | } | ||||
if (args.data == nullptr && args.src_data_size != 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid input null data"); | |||||
return PARAM_INVALID; | |||||
} | |||||
return transfer->TransDataType(args, result); | return transfer->TransDataType(args, result); | ||||
} | } | ||||
@@ -69,11 +69,11 @@ bool IsShapeValid(const std::vector<int64_t> &shape) { | |||||
} | } | ||||
int64_t num = 1; | int64_t num = 1; | ||||
for (auto dim : shape) { | for (auto dim : shape) { | ||||
if (dim < 1) { | |||||
GELOGE(PARAM_INVALID, "Invalid zero dim in the shape %s", ShapeToString(shape).c_str()); | |||||
if (dim < 0) { | |||||
GELOGE(PARAM_INVALID, "Invalid negative dim in the shape %s", ShapeToString(shape).c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
if (kShapeItemNumMAX / dim < num) { | |||||
if (dim != 0 && kShapeItemNumMAX / dim < num) { | |||||
GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | GELOGE(PARAM_INVALID, "Shape overflow, the total count should be less than %ld!", kShapeItemNumMAX); | ||||
return false; | return false; | ||||
} | } | ||||
@@ -64,6 +64,9 @@ bool IsShapeEqual(const GeShape &src, const GeShape &dst); | |||||
template <typename T> | template <typename T> | ||||
T Ceil(T n1, T n2) { | T Ceil(T n1, T n2) { | ||||
if (n1 == 0) { | |||||
return 0; | |||||
} | |||||
return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | return (n2 != 0) ? (n1 - 1) / n2 + 1 : 0; | ||||
} | } | ||||
@@ -0,0 +1,121 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
#define GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ | |||||
#include <nlohmann/json.hpp> | |||||
#include <set> | |||||
#include <string> | |||||
#include "ge/ge_api_error_codes.h" | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/manager/graph_var_manager.h" | |||||
#include "model/ge_model.h" | |||||
namespace ge { | |||||
using Json = nlohmann::json; | |||||
struct CacheInfo { | |||||
size_t node_num; | |||||
size_t edge_num; | |||||
size_t graph_hash; | |||||
map<std::string, size_t> nodes_hash; | |||||
CacheInfo() : node_num(0), edge_num(0), graph_hash(0) {} | |||||
}; | |||||
class ModelCacheHelper { | |||||
public: | |||||
ModelCacheHelper(uint64_t session_id, uint32_t graph_id, ComputeGraphPtr &compute_graph); | |||||
Status SaveCacheInfoToCache() const; | |||||
Status SaveVarManagerToCache(bool before_build) const; | |||||
Status SaveOmModelToCache(const GeModelPtr &ge_model) const; | |||||
bool IsModelCacheHit() const; | |||||
Status RecoverVarManagerFromCache() const; | |||||
Status LoadOmModelFromCache(GeModelPtr &ge_model) const; | |||||
Status RefreshComputeGraph(const ComputeGraphPtr &compute_graph); | |||||
Status ClearCache(uint32_t graph_id) const; | |||||
private: | |||||
Status GetComputeGraphHash(size_t &hash) const; | |||||
Status GetNodesHash(map<std::string, size_t> &hash_map) const; | |||||
Status GetCacheInfo(CacheInfo &cache_info) const; | |||||
Status RecoverMemResource(const Json &json) const; | |||||
Status RecoverAllocatedGraphId(const Json &json) const; | |||||
Status RecoverChangedGraphId(const Json &json) const; | |||||
Status RecoverVarAddrAndTensorDesc(const Json &json) const; | |||||
Status RecoverBroadcastInfo(const Json &json) const; | |||||
Status RecoverTransRoads(const Json &json) const; | |||||
static Status RecompileNodes(GeModelPtr &ge_model); | |||||
bool IsNodeHashSameAsCache(const map<std::string, size_t> &hash_map) const; | |||||
bool IsMemResourceSameAsCache(Json &json) const; | |||||
bool IsChangedGraphIdSameAsCache(Json &json) const; | |||||
bool IsAllocatedGraphIdSameAsCache(Json &json) const; | |||||
bool IsCurVarTensorDescSameAsCache(Json &json) const; | |||||
bool IsVarAddrMgrMapSameAsCache(Json &json) const; | |||||
bool IsBroadcastInfoSameAsCache(Json &json) const; | |||||
bool IsTransRoadsSameAsCache(Json &json) const; | |||||
bool IsVarManagerSameAsCache(Json &json) const; | |||||
bool IsVarManagerParamSameAsCache(Json &json) const; | |||||
Status SaveJsonToFile(const string &file_name, const Json &json) const; | |||||
Status LoadJsonFromFile(const string &file_name, Json &json) const; | |||||
Status GetNodesHashMapJson(Json &json) const; | |||||
Status GetMemResourceMap(Json &json) const; | |||||
Status GetVarAddrMgrMapJson(Json &json) const; | |||||
Status GetCurVarTensorDescMapJson(Json &json) const; | |||||
Status GetTransRoadsJson(Json &json) const; | |||||
Status GetChangedGraphIdJson(Json &json) const; | |||||
Status GetAllocatedGraphIdJson(Json &json) const; | |||||
Status GetBroadcastInfoJson(Json &json) const; | |||||
Status GetVarResourceJson(Json &json) const; | |||||
Status GetVarManagerJson(Json &json) const; | |||||
static Status TensorDescToJson(const GeTensorDesc &ge_tensor_desc, Json &json); | |||||
static Status JsonToTensorDesc(const Json &json, GeTensorDesc &ge_tensor_desc); | |||||
static Status ParseMemResourceFromJson(const Json &json, map<rtMemType_t, int64_t> &mem_resource); | |||||
static Status ParseVarAddrMgrMapFromJson(const Json &json, | |||||
std::vector<std::pair<std::string, VarAddrMgr>> &var_addr_mgr_vector, | |||||
std::unordered_set<uint64_t> &var_offset_set); | |||||
static Status ParseCurVarTensorDescMapFromJson( | |||||
const Json &json, std::unordered_map<std::string, ge::GeTensorDesc> &cur_var_tensor_desc_map); | |||||
static Status ParseTransRoadsFromJson(const Json &json, | |||||
std::unordered_map<std::string, std::vector<TransNodeInfo>> &trans_roads); | |||||
static Status ParseChangedGraphIdFromJson(const Json &json, | |||||
std::unordered_map<std::string, uint32_t> &changed_graph_id); | |||||
static Status ParseAllocatedGraphIdFromJson(const Json &json, | |||||
std::unordered_map<std::string, uint32_t> &allocated_graph_id); | |||||
static Status ParseBroadcastInfoFromJson(const Json &json, | |||||
std::unordered_map<std::string, VarBroadCastInfo> &var_broadcast_info); | |||||
static Status GetVarNameFromVarKey(const string &var_key, const GeTensorDesc &tensor_desc, string &var_name); | |||||
uint64_t session_id_; | |||||
uint32_t graph_id_; | |||||
string cache_path_; | |||||
ComputeGraphPtr compute_graph_; | |||||
std::set<string> var_names_; | |||||
bool is_cache_path_valid_for_output; | |||||
static map<uint32_t, uint32_t> graph_id_run_times_; | |||||
}; | |||||
using ModelCacheHelperPtr = std::shared_ptr<ModelCacheHelper>; | |||||
} // namespace ge | |||||
#endif // GE_COMMON_HELPER_MODEL_CACHE_HELPER_H_ |
@@ -385,6 +385,7 @@ REGISTER_OPTYPE_DEFINE(STREAMSWITCH, "StreamSwitch"); | |||||
REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); | REGISTER_OPTYPE_DEFINE(STREAMSWITCHN, "StreamSwitchN"); | ||||
REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); | REGISTER_OPTYPE_DEFINE(STREAMACTIVE, "StreamActive"); | ||||
REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | REGISTER_OPTYPE_DEFINE(MEMCPYASYNC, "MemcpyAsync"); | ||||
REGISTER_OPTYPE_DEFINE(MEMCPYADDRASYNC, "MemcpyAddrAsync"); | |||||
REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | ||||
REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DEFINE(SEND, "Send"); | REGISTER_OPTYPE_DEFINE(SEND, "Send"); | ||||
@@ -392,6 +393,7 @@ REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | |||||
REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | ||||
REGISTER_OPTYPE_DEFINE(LABELGOTOEX, "LabelGotoEx"); | |||||
REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); | REGISTER_OPTYPE_DEFINE(LABELSWITCH, "LabelSwitch"); | ||||
REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | REGISTER_OPTYPE_DEFINE(LABELSWITCHBYINDEX, "LabelSwitchByIndex"); | ||||
@@ -196,7 +196,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | ||||
auto dir_path_len = directory_path.length(); | auto dir_path_len = directory_path.length(); | ||||
if (dir_path_len >= PATH_MAX) { | if (dir_path_len >= PATH_MAX) { | ||||
GELOGE(ge::FAILED, "Directory path is too long."); | |||||
GELOGW("Directory path is too long."); | |||||
return -1; | return -1; | ||||
} | } | ||||
char tmp_dir_path[PATH_MAX] = {0}; | char tmp_dir_path[PATH_MAX] = {0}; | ||||
@@ -207,7 +207,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
if (ret != 0) { | if (ret != 0) { | ||||
if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
directory_path.c_str()); | directory_path.c_str()); | ||||
return ret; | return ret; | ||||
} | } | ||||
@@ -218,8 +218,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
if (ret != 0) { | if (ret != 0) { | ||||
if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
GELOGE(ge::FAILED, "Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
directory_path.c_str()); | |||||
GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
@@ -339,7 +338,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const std::string &file_path) { | ||||
// The specified path is empty | // The specified path is empty | ||||
if (file_path.empty()) { | if (file_path.empty()) { | ||||
GELOGE(ge::FAILED, "Path is empty."); | |||||
GELOGW("Path is empty."); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -358,23 +357,23 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
std::string real_path = RealPath(file_path.c_str()); | std::string real_path = RealPath(file_path.c_str()); | ||||
// Unable to get absolute path (does not exist or does not have permission to access) | // Unable to get absolute path (does not exist or does not have permission to access) | ||||
if (real_path.empty()) { | if (real_path.empty()) { | ||||
GELOGE(ge::FAILED, "Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
// The absolute path points to a file that is not readable | // The absolute path points to a file that is not readable | ||||
if (access(real_path.c_str(), R_OK) != 0) { | if (access(real_path.c_str(), R_OK) != 0) { | ||||
GELOGE(ge::FAILED, "Can not read file in %s, %s", file_path.c_str(), strerror(errno)); | |||||
GELOGW("Can not read file in %s, %s", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) { | |||||
// The specified path is empty | // The specified path is empty | ||||
if (file_path.empty()) { | if (file_path.empty()) { | ||||
GELOGE(ge::FAILED, "Path is empty."); | |||||
GELOGW("Path is empty."); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -394,8 +393,8 @@ FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) | |||||
// Can get absolute path (file exists) | // Can get absolute path (file exists) | ||||
if (!real_path.empty()) { | if (!real_path.empty()) { | ||||
// File is not readable or writable | // File is not readable or writable | ||||
if (access(real_path.c_str(), R_OK | W_OK | F_OK) != 0) { | |||||
GELOGE(ge::FAILED, "Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); | |||||
if (access(real_path.c_str(), W_OK | F_OK) != 0) { | |||||
GELOGW("Path[ %s ] exists, but can not be write, %s", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -413,7 +412,7 @@ FMK_FUNC_HOST_VISIBILITY bool CheckOutputPathValid(const std::string &file_path) | |||||
std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | std::string prefix_path = std::string(file_path).substr(0, static_cast<size_t>(path_split_pos)); | ||||
// Determine whether the specified path is valid by creating the path | // Determine whether the specified path is valid by creating the path | ||||
if (CreateDirectory(prefix_path) != 0) { | if (CreateDirectory(prefix_path) != 0) { | ||||
GELOGE(ge::FAILED, "Can not create prefix path for path[ %s ].", file_path.c_str()); | |||||
GELOGW("Can not create prefix path for path[ %s ].", file_path.c_str()); | |||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
@@ -47,6 +47,7 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../graph/load/new_model_manager/task_info/kernel_task_info.cc" | "../graph/load/new_model_manager/task_info/kernel_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/label_goto_task_info.cc" | "../graph/load/new_model_manager/task_info/label_goto_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/label_set_task_info.cc" | "../graph/load/new_model_manager/task_info/label_set_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.cc" | |||||
"../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | "../graph/load/new_model_manager/task_info/memcpy_async_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | "../graph/load/new_model_manager/task_info/profiler_trace_task_info.cc" | ||||
"../graph/load/new_model_manager/task_info/stream_active_task_info.cc" | "../graph/load/new_model_manager/task_info/stream_active_task_info.cc" | ||||
@@ -21,7 +21,7 @@ | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/operator.h" | #include "graph/operator.h" | ||||
#include "register/register.h" | |||||
#include "inc/register/register.h" | |||||
namespace ge { | namespace ge { | ||||
class HostCpuEngine { | class HostCpuEngine { | ||||
@@ -76,7 +76,7 @@ bool Output::CopyRslt(OutputData *rslt, uint32_t data_begin, uint32_t &data_inde | |||||
DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | DataBuffer data_buf = rslt->blobs[data_begin + data_count]; | ||||
bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | bool ret = SetDataBuf(data_buf, data_begin, data_count, i, support_mem_share); | ||||
if (!ret) { | if (!ret) { | ||||
GELOGE(FAILED, "Copy data to host error. index: %lu", i); | |||||
GELOGE(FAILED, "Copy data to host error. index: %lu, addr: %p", i, v_input_data_addr_[i]); | |||||
return ret; | return ret; | ||||
} | } | ||||
data_index = data_begin + data_count; | data_index = data_begin + data_count; | ||||
@@ -96,6 +96,7 @@ bool RuntimeModel::InitStream(std::shared_ptr<DavinciModel> &davinci_model) { | |||||
GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api rtModelBindStream failed, ret: 0x%X", rt_ret); | ||||
return false; | return false; | ||||
} | } | ||||
GELOGI("stream index:%u, stream:%p.", i, stream); | |||||
} | } | ||||
return true; | return true; | ||||
@@ -446,8 +447,11 @@ bool RuntimeModel::InitConstantInfo(std::shared_ptr<DavinciModel> &davinci_model | |||||
/// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | /// The logic of GetShapeSize is wrong, the scaler tensor's GetShapeSize is zero | ||||
/// and that of unknown shape is zero too. | /// and that of unknown shape is zero too. | ||||
/// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | /// Unknown shape will not appear here, so we can use zero judge a tensor is scaler or not. | ||||
int64_t elem_num = | |||||
(constant->weight_tensors[0].GetShapeSize() == 0) ? 1 : constant->weight_tensors[0].GetShapeSize(); | |||||
int64_t elem_num = constant->weight_tensors[0].GetShapeSize(); | |||||
if (elem_num == 0 && constant->weight_tensors[0].size == 0) { | |||||
elem_num = 1; | |||||
} | |||||
if (constant->weight_data.size() < sizeof(uint64_t)) { | if (constant->weight_data.size() < sizeof(uint64_t)) { | ||||
GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | GELOGE(FAILED, "weight_data size is smaller than sizeof(uint64_t)"); | ||||
return false; | return false; | ||||
@@ -82,6 +82,7 @@ bool CceTask::Distribute() { | |||||
stub_func_ = nullptr; | stub_func_ = nullptr; | ||||
return false; | return false; | ||||
} | } | ||||
GELOGI("CCETask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||||
// Flowtable | // Flowtable | ||||
if (is_flowtable_) { | if (is_flowtable_) { | ||||
@@ -43,6 +43,8 @@ EventRecordTask::EventRecordTask(const ModelContext &model_context, | |||||
EventRecordTask::~EventRecordTask() {} | EventRecordTask::~EventRecordTask() {} | ||||
bool EventRecordTask::Distribute() { | bool EventRecordTask::Distribute() { | ||||
GELOGI("EventRecordTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||||
task_info_->stream_id(), task_info_->event_id()); | |||||
rtError_t rt_ret = rtEventRecord(event_, stream_); | rtError_t rt_ret = rtEventRecord(event_, stream_); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
@@ -42,6 +42,9 @@ EventWaitTask::EventWaitTask(const ModelContext &model_context, const std::share | |||||
EventWaitTask::~EventWaitTask() {} | EventWaitTask::~EventWaitTask() {} | ||||
bool EventWaitTask::Distribute() { | bool EventWaitTask::Distribute() { | ||||
GELOGI("EventWaitTask Distribute start, stream: %p, event: %p, stream_id: %u, event_id: %u.", stream_, event_, | |||||
task_info_->stream_id(), task_info_->event_id()); | |||||
rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | rtError_t rt_ret = rtStreamWaitEvent(stream_, event_); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api rtStreamWaitEvent failed, ret: 0x%X", rt_ret); | ||||
@@ -101,6 +101,7 @@ bool HcclTask::Distribute() { | |||||
char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data())); | ||||
auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size()); | ||||
GELOGI("the first address of the custom info, privateDef=%p", private_def); | |||||
GELOGI("hcclStreamNum =%ld", task_info_->hccl_stream_num()); | GELOGI("hcclStreamNum =%ld", task_info_->hccl_stream_num()); | ||||
for (int64_t i = 0; i < task_info_->hccl_stream_num(); ++i) { | for (int64_t i = 0; i < task_info_->hccl_stream_num(); ++i) { | ||||
@@ -117,6 +118,7 @@ bool HcclTask::Distribute() { | |||||
return false; | return false; | ||||
} | } | ||||
GELOGI("hccl_stream addr is=%p", stream); | |||||
slave_stream_list_.push_back(stream); | slave_stream_list_.push_back(stream); | ||||
} | } | ||||
@@ -62,6 +62,9 @@ bool StreamSwitchTask::Distribute() { | |||||
rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | rtStream_t true_stream = stream_list_[task_info_->true_stream_id()]; | ||||
rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | rtSwitchDataType_t data_type = static_cast<rtSwitchDataType_t>(task_info_->data_type()); | ||||
GELOGI("InitStreamSwitchTask, cond:%d, trueStream:%p, trueStreamID:%ld, datatype:%ld.", cond, true_stream, | |||||
task_info_->true_stream_id(), task_info_->data_type()); | |||||
GELOGI("StreamSwitchTask Distribute Start."); | GELOGI("StreamSwitchTask Distribute Start."); | ||||
rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | rtError_t rt_ret = rtStreamSwitchEx(input, cond, value, true_stream, stream_, data_type); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
@@ -69,6 +72,7 @@ bool StreamSwitchTask::Distribute() { | |||||
return false; | return false; | ||||
} | } | ||||
GELOGI("Distribute StreamSwitch, cond:%d, trueStream:%p, datatype:%ld.", cond, true_stream, task_info_->data_type()); | |||||
return true; | return true; | ||||
} | } | ||||
@@ -69,6 +69,7 @@ bool TbeTask::Distribute() { | |||||
stub_func_ = nullptr; | stub_func_ = nullptr; | ||||
return false; | return false; | ||||
} | } | ||||
GELOGI("TbeTask: stub_func = %s [%p].", task_info_->stub_func().c_str(), stub_func_); | |||||
// Get args | // Get args | ||||
std::vector<void *> tensor_device_addrs; | std::vector<void *> tensor_device_addrs; | ||||
@@ -18,8 +18,8 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/helper/model_helper.h" | #include "common/helper/model_helper.h" | ||||
#include "common/opskernel/ops_kernel_info_types.h" | #include "common/opskernel/ops_kernel_info_types.h" | ||||
#include "graph/build/stream_graph_optimizer.h" | |||||
#include "graph/build/run_context.h" | #include "graph/build/run_context.h" | ||||
#include "graph/build/stream_graph_optimizer.h" | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
@@ -98,8 +98,10 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | ||||
GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | ||||
auto subgraph_map = graph_partitioner_.GetSubGraphMap(); | |||||
GE_TIMESTAMP_START(BuildSubgraph); | GE_TIMESTAMP_START(BuildSubgraph); | ||||
ge::ModelBuilder builder(comp_graph, subgraph_ptr_list, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
GELOGI("[Build] invoke the other opskernel to generate task."); | GELOGI("[Build] invoke the other opskernel to generate task."); | ||||
@@ -135,7 +137,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
} | } | ||||
GE_TIMESTAMP_START(GetTaskInfo); | GE_TIMESTAMP_START(GetTaskInfo); | ||||
ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_ptr_list, session_id); | |||||
ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); | |||||
GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | ||||
@@ -155,7 +157,7 @@ Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfo | |||||
} | } | ||||
Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, | Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, | ||||
ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, | |||||
uint64_t session_id) { | uint64_t session_id) { | ||||
GE_CHECK_NOTNULL(model_ptr); | GE_CHECK_NOTNULL(model_ptr); | ||||
GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
@@ -190,7 +192,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr | |||||
} | } | ||||
StreamGraphOptimizer stream_optimizer; | StreamGraphOptimizer stream_optimizer; | ||||
ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_ptr_list, run_context.GetRunContext()); | |||||
ret = stream_optimizer.OptimizeStreamedSubGraph(comp_graph, subgraph_map, run_context.GetRunContext()); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Optimize streamed subGraph fail."); | GELOGE(ret, "Optimize streamed subGraph fail."); | ||||
return ret; | return ret; | ||||
@@ -53,7 +53,7 @@ class GraphBuilder { | |||||
private: | private: | ||||
Status CalcOpParam(const ge::ComputeGraphPtr &graph); | Status CalcOpParam(const ge::ComputeGraphPtr &graph); | ||||
Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, | Status GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr &model_ptr, ComputeGraphPtr &comp_graph, | ||||
std::vector<SubGraphInfoPtr> &subgraph_ptr_list, uint64_t session_id = INVALID_SESSION_ID); | |||||
Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); | |||||
Status SetInputSize(const ge::NodePtr &node_ptr); | Status SetInputSize(const ge::NodePtr &node_ptr); | ||||
Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | ||||
Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | ||||
@@ -70,7 +70,7 @@ bool LogicalStreamPass::HasNonConstInputNode(const Subgraph &subgraph) const { | |||||
return false; | return false; | ||||
} | } | ||||
Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status AssignByLabelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
bool changed = false; | bool changed = false; | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
map<string, int64_t> label_streams; | map<string, int64_t> label_streams; | ||||
@@ -97,7 +97,7 @@ Status AssignByLabelPass::Run(ComputeGraphPtr whole_graph, const vector<Subgraph | |||||
return changed ? SUCCESS : NOT_CHANGED; | return changed ? SUCCESS : NOT_CHANGED; | ||||
} | } | ||||
Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status IndependentStreamPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
bool changed = false; | bool changed = false; | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
@@ -129,8 +129,7 @@ Status IndependentStreamPass::Run(ComputeGraphPtr whole_graph, const vector<Subg | |||||
return changed ? SUCCESS : NOT_CHANGED; | return changed ? SUCCESS : NOT_CHANGED; | ||||
} | } | ||||
Status AssignByDependencyPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, | |||||
Context &context) { | |||||
Status AssignByDependencyPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
bool changed = false; | bool changed = false; | ||||
if (IsHeadNodeExceeded(subgraphs)) { | if (IsHeadNodeExceeded(subgraphs)) { | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
@@ -298,7 +297,7 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { | |||||
subgraph->stream_id = stream_id; | subgraph->stream_id = stream_id; | ||||
engine_next_streams_[engine_name] = stream_id + 1; | engine_next_streams_[engine_name] = stream_id + 1; | ||||
assigned_subgraphs_.emplace(subgraph); | |||||
assigned_subgraphs_.emplace_back(subgraph); | |||||
if ((stream_id + 1) > engine_stream_num_[engine_name]) { | if ((stream_id + 1) > engine_stream_num_[engine_name]) { | ||||
engine_stream_num_[engine_name] = stream_id + 1; | engine_stream_num_[engine_name] = stream_id + 1; | ||||
@@ -311,6 +310,15 @@ int64_t AssignByDependencyPass::AssignNewStream(SubgraphPtr subgraph) { | |||||
} | } | ||||
void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | ||||
// If the parent stream is valid, the first assigned stream will reuse the parent stream id | |||||
// and other streams use new id. To ensure that the id of the new stream is continuous, | |||||
// we first subtract one from next_stream. | |||||
int64_t to_be_updated_stream = kInvalidStream; | |||||
if (context.parent_stream != kInvalidStream) { | |||||
context.next_stream--; | |||||
to_be_updated_stream = context.next_stream; | |||||
} | |||||
// Update the starting stream id for each engine. | // Update the starting stream id for each engine. | ||||
int64_t &next_stream = context.next_stream; | int64_t &next_stream = context.next_stream; | ||||
map<string, int64_t> engine_start_streams; | map<string, int64_t> engine_start_streams; | ||||
@@ -320,10 +328,16 @@ void AssignByDependencyPass::UpdateAssignedSubgraphs(Context &context) { | |||||
next_stream += stream_count; | next_stream += stream_count; | ||||
} | } | ||||
// Update the subgraphs assigned by the engine. | |||||
// Update the subgraph streams assigned by engine. | |||||
for (auto &subgraph : assigned_subgraphs_) { | for (auto &subgraph : assigned_subgraphs_) { | ||||
subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; | subgraph->stream_id += engine_start_streams[subgraph->engine_conf.id]; | ||||
GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); | |||||
if (subgraph->stream_id == to_be_updated_stream) { | |||||
subgraph->stream_id = context.parent_stream; | |||||
GELOGI("Subgraph %s of engine %s reuses parent stream %ld.", subgraph->name.c_str(), | |||||
subgraph->engine_conf.id.c_str(), context.parent_stream); | |||||
} else { | |||||
GELOGI("Stream of subgraph %s has been updated to %ld.", subgraph->name.c_str(), subgraph->stream_id); | |||||
} | |||||
} | } | ||||
} | } | ||||
@@ -337,7 +351,7 @@ void AssignByDependencyPass::UpdateReusedSubgraphs() { | |||||
} | } | ||||
} | } | ||||
Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
Status NodeStreamUpdatePass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
// Check if all subgraphs have been assigned a stream. | // Check if all subgraphs have been assigned a stream. | ||||
for (const SubgraphPtr &subgraph : subgraphs) { | for (const SubgraphPtr &subgraph : subgraphs) { | ||||
const string &engine_name = subgraph->engine_conf.id; | const string &engine_name = subgraph->engine_conf.id; | ||||
@@ -353,7 +367,7 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<Subgr | |||||
} | } | ||||
// Init the stream id of node. | // Init the stream id of node. | ||||
for (NodePtr &node : whole_graph->GetDirectNode()) { | |||||
for (NodePtr &node : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
node->GetOpDesc()->SetStreamId(kInvalidStream); | node->GetOpDesc()->SetStreamId(kInvalidStream); | ||||
} | } | ||||
@@ -375,76 +389,11 @@ Status NodeStreamUpdatePass::Run(ComputeGraphPtr whole_graph, const vector<Subgr | |||||
} | } | ||||
// Update stream id for nodes belong to skipped engine subgraph | // Update stream id for nodes belong to skipped engine subgraph | ||||
GE_CHK_STATUS_RET(UpdateForSkippedEngine(whole_graph, subgraphs)); | |||||
RefreshContinuousStreams(whole_graph, context); | |||||
GE_CHK_STATUS_RET(UpdateForSkippedEngine(graph, subgraphs)); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status AllReduceParallelPass::Run(ComputeGraphPtr whole_graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
if (!context.hcom_parallel) { | |||||
return NOT_CHANGED; | |||||
} | |||||
GELOGI("AllReduceParallelPass is enabled."); | |||||
GraphUtils::DumpGEGraph(whole_graph, "BeforeAllReduceParallel"); | |||||
// All successors of HcomAllReduce. | |||||
set<NodePtr> all_reduce_succs; | |||||
for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
continue; | |||||
} | |||||
string reduce_stream_label; | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
// ATTR_NAME_STREAM_LABEL is optional. | |||||
(void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); | |||||
set<NodePtr> cur_nodes = {node}; | |||||
while (!cur_nodes.empty()) { | |||||
set<NodePtr> all_out_data_nodes; | |||||
for (auto &curr_node : cur_nodes) { | |||||
for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { | |||||
string out_stream_label; | |||||
GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
// ATTR_NAME_STREAM_LABEL is optional. | |||||
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | |||||
if (out_stream_label == reduce_stream_label) { | |||||
all_reduce_succs.emplace(out_node); | |||||
all_out_data_nodes.emplace(out_node); | |||||
} | |||||
} | |||||
} | |||||
cur_nodes = all_out_data_nodes; | |||||
} | |||||
} | |||||
map<int64_t, int64_t> old_stream_to_new; | |||||
for (const NodePtr &node : all_reduce_succs) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
auto old_stream = node->GetOpDesc()->GetStreamId(); | |||||
if (old_stream != kInvalidStream) { | |||||
int64_t new_stream = kInvalidStream; | |||||
auto iter = old_stream_to_new.find(old_stream); | |||||
if (iter != old_stream_to_new.end()) { | |||||
new_stream = iter->second; | |||||
} else { | |||||
new_stream = context.next_stream; | |||||
context.next_stream++; | |||||
old_stream_to_new.emplace(old_stream, new_stream); | |||||
} | |||||
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
node->GetOpDesc()->SetStreamId(new_stream); | |||||
} | |||||
} | |||||
return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | |||||
} | |||||
int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | ||||
set<int64_t> stream_ids; | set<int64_t> stream_ids; | ||||
@@ -472,11 +421,11 @@ int64_t NodeStreamUpdatePass::GetSingleInoutStream(const NodePtr &node) const { | |||||
return kInvalidStream; | return kInvalidStream; | ||||
} | } | ||||
Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, | |||||
Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &graph, | |||||
const vector<SubgraphPtr> &subgraphs) { | const vector<SubgraphPtr> &subgraphs) { | ||||
set<OpDescPtr> nodes_to_be_updated; | set<OpDescPtr> nodes_to_be_updated; | ||||
// Check if sub graph is engine skipped and without stream label or not | |||||
// Check if subgraph is engine skipped and without stream label or not | |||||
for (const SubgraphPtr &subgraph : subgraphs) { | for (const SubgraphPtr &subgraph : subgraphs) { | ||||
if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { | if (IsEngineSkip(*subgraph) && !HasStreamLabel(*subgraph)) { | ||||
auto graph = subgraph->subgraph_info.GetSubGraph(); | auto graph = subgraph->subgraph_info.GetSubGraph(); | ||||
@@ -492,7 +441,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole | |||||
} | } | ||||
// Try reassign the stream id | // Try reassign the stream id | ||||
for (ge::NodePtr &node : whole_graph->GetDirectNode()) { | |||||
for (ge::NodePtr &node : graph->GetDirectNode()) { | |||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
int64_t stream_id = op_desc->GetStreamId(); | int64_t stream_id = op_desc->GetStreamId(); | ||||
@@ -509,6 +458,7 @@ Status NodeStreamUpdatePass::UpdateForSkippedEngine(const ComputeGraphPtr &whole | |||||
} | } | ||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -525,40 +475,65 @@ bool NodeStreamUpdatePass::AreAllPredStreamsInvalid(const NodePtr &node) const { | |||||
return true; | return true; | ||||
} | } | ||||
void NodeStreamUpdatePass::RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const { | |||||
int64_t stream_num = context.next_stream; | |||||
vector<bool> stream_has_node(stream_num); | |||||
Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPtr> &subgraphs, Context &context) { | |||||
if (!context.hcom_parallel) { | |||||
return NOT_CHANGED; | |||||
} | |||||
for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
if (node != nullptr) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
stream_has_node[stream_id] = true; | |||||
} | |||||
} | |||||
GELOGI("AllReduceParallelPass is enabled."); | |||||
GraphUtils::DumpGEGraph(graph, "BeforeAllReduceParallel"); | |||||
// All successors of HcomAllReduce. | |||||
set<NodePtr> all_reduce_succs; | |||||
for (const NodePtr &node : graph->GetDirectNode()) { | |||||
if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
continue; | |||||
} | } | ||||
} | |||||
context.next_stream = 0; | |||||
vector<int64_t> old_to_new_streams(stream_num, kInvalidStream); | |||||
for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { | |||||
if (stream_has_node[old_stream]) { | |||||
old_to_new_streams[old_stream] = context.next_stream; | |||||
++context.next_stream; | |||||
string reduce_stream_label; | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
(void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, reduce_stream_label); | |||||
set<NodePtr> cur_nodes = {node}; | |||||
while (!cur_nodes.empty()) { | |||||
set<NodePtr> all_out_data_nodes; | |||||
for (auto &curr_node : cur_nodes) { | |||||
for (const NodePtr &out_node : curr_node->GetOutDataNodes()) { | |||||
string out_stream_label; | |||||
GE_CHECK_NOTNULL(out_node->GetOpDesc()); | |||||
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | |||||
if (out_stream_label == reduce_stream_label) { | |||||
all_reduce_succs.emplace(out_node); | |||||
all_out_data_nodes.emplace(out_node); | |||||
} | |||||
} | |||||
} | |||||
cur_nodes = all_out_data_nodes; | |||||
} | } | ||||
} | } | ||||
for (const NodePtr &node : whole_graph->GetDirectNode()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
op_desc->SetStreamId(old_to_new_streams[stream_id]); | |||||
map<int64_t, int64_t> old_stream_to_new; | |||||
for (const NodePtr &node : all_reduce_succs) { | |||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
auto old_stream = node->GetOpDesc()->GetStreamId(); | |||||
if (old_stream != kInvalidStream) { | |||||
int64_t new_stream = kInvalidStream; | |||||
auto iter = old_stream_to_new.find(old_stream); | |||||
if (iter != old_stream_to_new.end()) { | |||||
new_stream = iter->second; | |||||
} else { | |||||
new_stream = context.next_stream; | |||||
context.next_stream++; | |||||
old_stream_to_new.emplace(old_stream, new_stream); | |||||
} | } | ||||
GELOGI("Stream of node %s has been updated from %ld to %ld.", node->GetName().c_str(), old_stream, new_stream); | |||||
node->GetOpDesc()->SetStreamId(new_stream); | |||||
} | } | ||||
} | } | ||||
return !all_reduce_succs.empty() ? SUCCESS : NOT_CHANGED; | |||||
} | } | ||||
LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> &scheduler_confs, | ||||
@@ -567,9 +542,10 @@ LogicalStreamAllocator::LogicalStreamAllocator(const map<string, SchedulerConf> | |||||
context_.hcom_parallel = hcom_parallel; | context_.hcom_parallel = hcom_parallel; | ||||
} | } | ||||
Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const vector<SubGraphInfoPtr> &subgraph_infos, | |||||
Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
int64_t &stream_num) { | int64_t &stream_num) { | ||||
GE_CHECK_NOTNULL(whole_graph); | GE_CHECK_NOTNULL(whole_graph); | ||||
map<string, EngineConfPtr> engine_confs; | map<string, EngineConfPtr> engine_confs; | ||||
GE_TIMESTAMP_START(InitEngineConfs); | GE_TIMESTAMP_START(InitEngineConfs); | ||||
for (const auto &item : scheduler_confs_) { | for (const auto &item : scheduler_confs_) { | ||||
@@ -583,16 +559,64 @@ Status LogicalStreamAllocator::Assign(const ComputeGraphPtr &whole_graph, const | |||||
} | } | ||||
GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); | GE_TIMESTAMP_END(InitEngineConfs, "GraphBuilder::AssignStreamInitEngineConfs"); | ||||
Status status = DoAssign(whole_graph, subgraph_map, engine_confs); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Assign streams failed."); | |||||
return status; | |||||
} | |||||
vector<ComputeGraphPtr> subgraphs = whole_graph->GetAllSubgraphs(); | |||||
for (const ComputeGraphPtr &subgraph : subgraphs) { | |||||
Status status = DoAssign(subgraph, subgraph_map, engine_confs); | |||||
if (status != SUCCESS) { | |||||
GELOGE(status, "Assign streams failed."); | |||||
return status; | |||||
} | |||||
} | |||||
RefreshContinuousStreams(whole_graph); | |||||
stream_num = context_.next_stream; | |||||
GELOGI("Assigned logical stream num: %ld.", stream_num); | |||||
return SUCCESS; | |||||
} | |||||
Status LogicalStreamAllocator::DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
const map<string, EngineConfPtr> &engine_confs) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
NodePtr parent_node = graph->GetParentNode(); | |||||
if (parent_node == nullptr || parent_node->GetOpDesc() == nullptr) { | |||||
context_.parent_stream = kInvalidStream; | |||||
} else { | |||||
context_.parent_stream = parent_node->GetOpDesc()->GetStreamId(); | |||||
} | |||||
auto iter = subgraph_map.find(graph); | |||||
if (iter == subgraph_map.end()) { | |||||
GELOGE(FAILED, "Graph %s not found.", graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
const vector<SubGraphInfoPtr> &subgraph_info_list = iter->second; | |||||
vector<SubgraphPtr> subgraphs; | vector<SubgraphPtr> subgraphs; | ||||
GE_TIMESTAMP_START(ConvertSubgraphs); | GE_TIMESTAMP_START(ConvertSubgraphs); | ||||
Status status = ConvertSubgraphs(subgraph_infos, engine_confs, subgraphs); | |||||
Status status = ConvertSubgraphs(subgraph_info_list, engine_confs, subgraphs); | |||||
GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); | GE_TIMESTAMP_END(ConvertSubgraphs, "GraphBuilder::AssignStreamConvertSubgraphs"); | ||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "Create subgraphs failed."); | GELOGE(status, "Create subgraphs failed."); | ||||
return status; | return status; | ||||
} | } | ||||
return RunPasses(whole_graph, subgraphs, stream_num); | |||||
GELOGI("Subgraphs of graph %s:", graph->GetName().c_str()); | |||||
for (const auto &subgraph : subgraphs) { | |||||
if (subgraph != nullptr) { | |||||
GELOGI("subgraph: %s", subgraph->name.c_str()); | |||||
} | |||||
} | |||||
return RunPasses(graph, subgraphs); | |||||
} | } | ||||
Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &subgraph_infos, | Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &subgraph_infos, | ||||
@@ -631,8 +655,7 @@ Status LogicalStreamAllocator::ConvertSubgraphs(const vector<SubGraphInfoPtr> &s | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, const vector<SubgraphPtr> &subgraphs, | |||||
int64_t &stream_num) { | |||||
Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &graph, const vector<SubgraphPtr> &subgraphs) { | |||||
vector<LogicalStreamPassPtr> passes; | vector<LogicalStreamPassPtr> passes; | ||||
passes.emplace_back(MakeShared<AssignByLabelPass>()); | passes.emplace_back(MakeShared<AssignByLabelPass>()); | ||||
passes.emplace_back(MakeShared<IndependentStreamPass>()); | passes.emplace_back(MakeShared<IndependentStreamPass>()); | ||||
@@ -643,7 +666,7 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, con | |||||
for (auto &pass : passes) { | for (auto &pass : passes) { | ||||
GE_CHECK_NOTNULL(pass); | GE_CHECK_NOTNULL(pass); | ||||
Status status = pass->Run(whole_graph, subgraphs, context_); | |||||
Status status = pass->Run(graph, subgraphs, context_); | |||||
if (status == SUCCESS) { | if (status == SUCCESS) { | ||||
GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); | GELOGI("Stream pass %s return SUCCESS.", pass->GetName().c_str()); | ||||
} else if (status == NOT_CHANGED) { | } else if (status == NOT_CHANGED) { | ||||
@@ -654,9 +677,42 @@ Status LogicalStreamAllocator::RunPasses(const ComputeGraphPtr &whole_graph, con | |||||
} | } | ||||
} | } | ||||
stream_num = context_.next_stream; | |||||
GELOGI("Assigned logical stream num: %ld.", stream_num); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void LogicalStreamAllocator::RefreshContinuousStreams(const ComputeGraphPtr &graph) { | |||||
int64_t stream_num = context_.next_stream; | |||||
vector<bool> stream_has_node(stream_num); | |||||
for (const NodePtr &node : graph->GetAllNodes()) { | |||||
if (node != nullptr) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
stream_has_node[stream_id] = true; | |||||
} | |||||
} | |||||
} | |||||
} | |||||
context_.next_stream = 0; | |||||
vector<int64_t> old_to_new_streams(stream_num, kInvalidStream); | |||||
for (size_t old_stream = 0; old_stream < stream_has_node.size(); ++old_stream) { | |||||
if (stream_has_node[old_stream]) { | |||||
old_to_new_streams[old_stream] = context_.next_stream; | |||||
++context_.next_stream; | |||||
} | |||||
} | |||||
for (const NodePtr &node : graph->GetAllNodes()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
if (op_desc != nullptr) { | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (stream_id != kInvalidStream && stream_id < stream_num) { | |||||
op_desc->SetStreamId(old_to_new_streams[stream_id]); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -60,7 +60,7 @@ class LogicalStreamPass { | |||||
}; | }; | ||||
struct Context { | struct Context { | ||||
// Next stream id. | |||||
int64_t parent_stream = kInvalidStream; | |||||
int64_t next_stream = 0; | int64_t next_stream = 0; | ||||
bool hcom_parallel = false; | bool hcom_parallel = false; | ||||
}; | }; | ||||
@@ -71,7 +71,7 @@ class LogicalStreamPass { | |||||
virtual ~LogicalStreamPass() = default; | virtual ~LogicalStreamPass() = default; | ||||
const std::string &GetName() const; | const std::string &GetName() const; | ||||
virtual Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0; | |||||
virtual Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) = 0; | |||||
protected: | protected: | ||||
bool IsEngineSkip(const Subgraph &subgraph) const; | bool IsEngineSkip(const Subgraph &subgraph) const; | ||||
@@ -93,21 +93,21 @@ using LogicalStreamPassPtr = std::shared_ptr<LogicalStreamPass>; | |||||
class AssignByLabelPass : public LogicalStreamPass { | class AssignByLabelPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); | STREAM_PASS_DEFAULT_FUNC(AssignByLabelPass); | ||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | }; | ||||
// Engines such as hccl require independent Stream. | // Engines such as hccl require independent Stream. | ||||
class IndependentStreamPass : public LogicalStreamPass { | class IndependentStreamPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); | STREAM_PASS_DEFAULT_FUNC(IndependentStreamPass); | ||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | }; | ||||
// Reuse streams or assign new streams based on dependencies. | // Reuse streams or assign new streams based on dependencies. | ||||
class AssignByDependencyPass : public LogicalStreamPass { | class AssignByDependencyPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); | STREAM_PASS_DEFAULT_FUNC(AssignByDependencyPass); | ||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
private: | private: | ||||
void InitEndSubgraphMap(const std::vector<SubgraphPtr> &subgraphs, std::map<NodePtr, SubgraphPtr> &end_subgraph_map); | void InitEndSubgraphMap(const std::vector<SubgraphPtr> &subgraphs, std::map<NodePtr, SubgraphPtr> &end_subgraph_map); | ||||
@@ -132,7 +132,7 @@ class AssignByDependencyPass : public LogicalStreamPass { | |||||
std::map<std::string, int64_t> engine_stream_num_; | std::map<std::string, int64_t> engine_stream_num_; | ||||
// Subgraphs of assign stream by engine | // Subgraphs of assign stream by engine | ||||
std::set<SubgraphPtr> assigned_subgraphs_; | |||||
std::vector<SubgraphPtr> assigned_subgraphs_; | |||||
// <current subgraph, reused subgraph> | // <current subgraph, reused subgraph> | ||||
std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_; | std::vector<std::pair<SubgraphPtr, SubgraphPtr>> reused_subgraphs_; | ||||
@@ -142,7 +142,7 @@ class AssignByDependencyPass : public LogicalStreamPass { | |||||
class NodeStreamUpdatePass : public LogicalStreamPass { | class NodeStreamUpdatePass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); | STREAM_PASS_DEFAULT_FUNC(NodeStreamUpdatePass); | ||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
private: | private: | ||||
/// Optimize for case like: | /// Optimize for case like: | ||||
@@ -150,19 +150,18 @@ class NodeStreamUpdatePass : public LogicalStreamPass { | |||||
/// To case: | /// To case: | ||||
/// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) | /// NodeA(stream1) -> Const(stream1) -> NodeB(stream1) | ||||
/// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) | /// Which could reduce event number (Const could be other type which belong to skipped engine subgraph) | ||||
Status UpdateForSkippedEngine(const ComputeGraphPtr &whole_graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
Status UpdateForSkippedEngine(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
int64_t GetSingleInoutStream(const NodePtr &node) const; | int64_t GetSingleInoutStream(const NodePtr &node) const; | ||||
// Judge if all predecessors' streams of node are INVALID_STREAM | // Judge if all predecessors' streams of node are INVALID_STREAM | ||||
bool AreAllPredStreamsInvalid(const NodePtr &node) const; | bool AreAllPredStreamsInvalid(const NodePtr &node) const; | ||||
void RefreshContinuousStreams(ComputeGraphPtr whole_graph, Context &context) const; | |||||
}; | }; | ||||
// AllReduce and backward operators execute in parallel. | // AllReduce and backward operators execute in parallel. | ||||
class AllReduceParallelPass : public LogicalStreamPass { | class AllReduceParallelPass : public LogicalStreamPass { | ||||
public: | public: | ||||
STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | STREAM_PASS_DEFAULT_FUNC(AllReduceParallelPass); | ||||
Status Run(ComputeGraphPtr whole_graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
Status Run(ComputeGraphPtr graph, const std::vector<SubgraphPtr> &subgraphs, Context &context) override; | |||||
}; | }; | ||||
// Assign logical streams which is not limited by the number of tasks. | // Assign logical streams which is not limited by the number of tasks. | ||||
@@ -178,13 +177,16 @@ class LogicalStreamAllocator { | |||||
LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; | LogicalStreamAllocator &operator=(const LogicalStreamAllocator &) = delete; | ||||
~LogicalStreamAllocator() = default; | ~LogicalStreamAllocator() = default; | ||||
Status Assign(const ComputeGraphPtr &whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs, int64_t &stream_num); | |||||
Status Assign(const ComputeGraphPtr &whole_graph, const Graph2SubGraphInfoList &subgraph_map, int64_t &stream_num); | |||||
private: | private: | ||||
Status DoAssign(const ComputeGraphPtr &graph, const Graph2SubGraphInfoList &subgraph_map, | |||||
const map<string, EngineConfPtr> &engine_confs); | |||||
Status ConvertSubgraphs(const std::vector<SubGraphInfoPtr> &subgraph_infos, | Status ConvertSubgraphs(const std::vector<SubGraphInfoPtr> &subgraph_infos, | ||||
const std::map<std::string, EngineConfPtr> &engine_confs, | const std::map<std::string, EngineConfPtr> &engine_confs, | ||||
std::vector<SubgraphPtr> &subgraphs); | std::vector<SubgraphPtr> &subgraphs); | ||||
Status RunPasses(const ComputeGraphPtr &whole_graph, const std::vector<SubgraphPtr> &subgraphs, int64_t &stream_num); | |||||
Status RunPasses(const ComputeGraphPtr &graph, const std::vector<SubgraphPtr> &subgraphs); | |||||
void RefreshContinuousStreams(const ComputeGraphPtr &graph); | |||||
const std::map<std::string, SchedulerConf> &scheduler_confs_; | const std::map<std::string, SchedulerConf> &scheduler_confs_; | ||||
const std::map<std::string, int> &max_parallel_num_; | const std::map<std::string, int> &max_parallel_num_; | ||||
@@ -805,6 +805,9 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t | |||||
} | } | ||||
} | } | ||||
op_desc->SetOutputOffset(output_list); | op_desc->SetOutputOffset(output_list); | ||||
GELOGI("[IMAS]Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", | |||||
graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, | |||||
real_size); | |||||
} else if (node_type_index.mem_type == kWorkspace) { | } else if (node_type_index.mem_type == kWorkspace) { | ||||
vector<int64_t> workspace_list; | vector<int64_t> workspace_list; | ||||
workspace_list = op_desc->GetWorkspace(); | workspace_list = op_desc->GetWorkspace(); | ||||
@@ -821,6 +824,9 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t | |||||
workspace_list.at(node_type_index.index) = offset; | workspace_list.at(node_type_index.index) = offset; | ||||
} | } | ||||
op_desc->SetWorkspace(workspace_list); | op_desc->SetWorkspace(workspace_list); | ||||
GELOGI("[IMAS]Set %s name[%s] workspace[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", | |||||
graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, | |||||
real_size); | |||||
} | } | ||||
} | } | ||||
@@ -310,6 +310,11 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node) | |||||
if (is_tensor_actual_size == 0) { | if (is_tensor_actual_size == 0) { | ||||
AlignMemOffset(MEM_ALIGN_SIZE); | AlignMemOffset(MEM_ALIGN_SIZE); | ||||
} | } | ||||
GELOGI( | |||||
"[IMAS]Continuous input : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%zu] " | |||||
"real_size[%ld].", | |||||
node->GetOwnerComputeGraph()->GetName().c_str(), peer_op_desc->GetName().c_str(), peer_out_data_anchor->GetIdx(), | |||||
pre_mem_offset, peer_op_desc->GetStreamId(), (memory_offset_[0].mem_offset_ - pre_mem_offset), tensor_desc_size); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -340,6 +345,11 @@ Status GraphMemoryAssigner::AssignContinuousOutputMemory(const ge::NodePtr &node | |||||
memory_offset_[0].mem_offset_ += tensor_desc_size; | memory_offset_[0].mem_offset_ += tensor_desc_size; | ||||
AlignMemOffset(MEM_ALIGN_SIZE); | AlignMemOffset(MEM_ALIGN_SIZE); | ||||
GELOGI( | |||||
"[IMAS]Continuous output : Set %s name[%s] output[%d] offset to [%zu] stream_id[%ld] size[%zu] " | |||||
"real_size[%ld].", | |||||
node->GetOwnerComputeGraph()->GetName().c_str(), out_op_desc->GetName().c_str(), out_data_anchor->GetIdx(), | |||||
pre_mem_offset, out_op_desc->GetStreamId(), (memory_offset_[0].mem_offset_ - pre_mem_offset), tensor_desc_size); | |||||
} | } | ||||
out_op_desc->SetOutputOffset(output_list); | out_op_desc->SetOutputOffset(output_list); | ||||
@@ -413,8 +423,10 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousInputMemory() { | |||||
pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); | pre_mem_offset, peer_op_desc->GetStreamId(), out_size, output_mem_size); | ||||
} | } | ||||
memory_offset_[0].mem_offset_ += extra_memory_size; | memory_offset_[0].mem_offset_ += extra_memory_size; | ||||
GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); | |||||
size_t after_mem_offset = memory_offset_[0].mem_offset_; | |||||
AlignMemOffset(MEM_ALIGN_SIZE); | |||||
GELOGI("After reassign virtual input node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -499,8 +511,10 @@ Status GraphMemoryAssigner::ReAssignReuseAndNoPaddingContinuousOutputMemory() { | |||||
} | } | ||||
op_desc->SetOutputOffset(output_list); | op_desc->SetOutputOffset(output_list); | ||||
memory_offset_[0].mem_offset_ += extra_memory_size; | memory_offset_[0].mem_offset_ += extra_memory_size; | ||||
GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), memory_offset_[0].mem_offset_); | |||||
size_t after_mem_offset = memory_offset_[0].mem_offset_; | |||||
AlignMemOffset(MEM_ALIGN_SIZE); | |||||
GELOGI("After reassign virtual output node[name:%s, type:%s] memory, memory offset = %zu, align memory = %zu.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), after_mem_offset, memory_offset_[0].mem_offset_); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -567,6 +581,11 @@ Status GraphMemoryAssigner::ReAssignMergeMemory() { | |||||
output_list[index] = data_output_offset; | output_list[index] = data_output_offset; | ||||
src_node->GetOpDesc()->SetOutputOffset(output_list); | src_node->GetOpDesc()->SetOutputOffset(output_list); | ||||
GELOGI( | |||||
"[IMAS]ReAssignMergeMemory : Set %s name[%s] output[%d] offset to [%ld] stream_id[%ld] size[%ld] " | |||||
"real_size[%ld].", | |||||
n->GetOwnerComputeGraph()->GetName().c_str(), src_node->GetOpDesc()->GetName().c_str(), index, | |||||
data_output_offset, src_node->GetOpDesc()->GetStreamId(), max_output_size, max_output_size); | |||||
input_list.emplace_back(data_output_offset); | input_list.emplace_back(data_output_offset); | ||||
} | } | ||||
@@ -897,6 +916,9 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node) { | |||||
} | } | ||||
output_list[output_index] = memory_offset_[0].mem_offset_; | output_list[output_index] = memory_offset_[0].mem_offset_; | ||||
GELOGI("[IMAS]Atomic output : Set %s name[%s] output[%ld] offset to [%zu] stream_id[%ld] size[%ld] real_size[%ld].", | |||||
compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), output_index, memory_offset_[0].mem_offset_, | |||||
op_desc->GetStreamId(), size, size); | |||||
memory_offset_[0].mem_offset_ += size; | memory_offset_[0].mem_offset_ += size; | ||||
AlignMemOffset(MEM_ALIGN_SIZE); | AlignMemOffset(MEM_ALIGN_SIZE); | ||||
@@ -933,6 +955,11 @@ Status GraphMemoryAssigner::AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDesc | |||||
} | } | ||||
workspace_vector[workspace_index] = memory_offset_[0].mem_offset_; | workspace_vector[workspace_index] = memory_offset_[0].mem_offset_; | ||||
GELOGI( | |||||
"[IMAS]Atomic ordinary workspace : Set %s name[%s] workspace[%lu] offset to [%zu] stream_id[%ld] " | |||||
"size[%ld] real_size[%ld].", | |||||
compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), workspace_index, memory_offset_[0].mem_offset_, | |||||
op_desc->GetStreamId(), workspace_size, workspace_size); | |||||
memory_offset_[0].mem_offset_ += workspace_size; | memory_offset_[0].mem_offset_ += workspace_size; | ||||
} | } | ||||
@@ -958,6 +985,11 @@ Status GraphMemoryAssigner::AssignFusionAtomicWorkspaceMemory(const ge::OpDescPt | |||||
auto workspace_size = info_iter.second; | auto workspace_size = info_iter.second; | ||||
size_t workspace_offset = memory_offset_[0].mem_offset_; | size_t workspace_offset = memory_offset_[0].mem_offset_; | ||||
GELOGI( | |||||
"[IMAS]Atomic fusion workspace : Set %s name[%s] workspace[%lu] offset to [%zu] stream_id[%ld] size[%ld] " | |||||
"real_size[%ld].", | |||||
compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), workspace_index, memory_offset_[0].mem_offset_, | |||||
op_desc->GetStreamId(), workspace_size, workspace_size); | |||||
memory_offset_[0].mem_offset_ += workspace_size; | memory_offset_[0].mem_offset_ += workspace_size; | ||||
index_offset.insert(std::make_pair(workspace_index, workspace_offset)); | index_offset.insert(std::make_pair(workspace_index, workspace_offset)); | ||||
@@ -1005,7 +1037,8 @@ ge::Status GraphMemoryAssigner::SetInputOffset() { | |||||
GELOGE(FAILED, "memory_offset_ is empty."); | GELOGE(FAILED, "memory_offset_ is empty."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GEEVENT("[IMAS]AfterAssignMemory : %s", compute_graph_->GetName().c_str()); | |||||
GEEVENT("[IMAS]AfterAssignMemory : %s memoffset[%zu]", compute_graph_->GetName().c_str(), | |||||
memory_offset_[0].mem_offset_); | |||||
for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { | for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { | ||||
if (UpdateOpInputOffset(node) != ge::SUCCESS) { | if (UpdateOpInputOffset(node) != ge::SUCCESS) { | ||||
GELOGE(ge::FAILED, "Update op input offset failed"); | GELOGE(ge::FAILED, "Update op input offset failed"); | ||||
@@ -1166,6 +1199,12 @@ ge::Status GraphMemoryAssigner::SetAtomicCleanAttr(const NodePtr &n, int64_t ato | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_MEM_SIZE, mem_size_vector), | GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListInt(node_op_desc, ATTR_NAME_AUTOMIC_ADD_MEM_SIZE, mem_size_vector), | ||||
GELOGE(FAILED, "SetListInt failed."); | GELOGE(FAILED, "SetListInt failed."); | ||||
return FAILED); | return FAILED); | ||||
GELOGI( | |||||
"[IMAS]SetAtomicCleanAttr : Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%ld] " | |||||
"realsize[%ld].", | |||||
node->GetOwnerComputeGraph()->GetName().c_str(), node_op_desc->GetName().c_str(), 0, atomic_mem_start, | |||||
node->GetOpDesc()->GetStreamId(), atomic_mem_size, atomic_mem_size); | |||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -28,6 +28,7 @@ | |||||
#include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/ge_context.h" | |||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
@@ -39,7 +40,6 @@ | |||||
#include "graph/utils/op_desc_utils.h" | #include "graph/utils/op_desc_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "graph/ge_context.h" | |||||
#include "init/gelib.h" | #include "init/gelib.h" | ||||
#include "memory/memory_assigner.h" | #include "memory/memory_assigner.h" | ||||
#include "omg/version.h" | #include "omg/version.h" | ||||
@@ -78,15 +78,16 @@ bool IsGeLocalOp(const ge::ConstOpDescPtr &op_desc) { | |||||
ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); | ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); | ||||
return !(output_desc.GetDataType() == ge::DT_STRING); | return !(output_desc.GetDataType() == ge::DT_STRING); | ||||
} | } | ||||
const set<string> ge_local_set = { | |||||
ge::STREAMMERGE, ge::MEMCPYASYNC, ge::STREAMACTIVE, ge::STREAMSWITCH, ge::VARIABLE, ge::NOOP, ge::CONSTANT, | |||||
ge::ENTER, ge::REFENTER, ge::LOOPCOND, ge::NEXTITERATION, ge::REFNEXTITERATION, ge::EXIT, ge::REFEXIT}; | |||||
const set<string> ge_local_set = {ge::STREAMMERGE, ge::MEMCPYASYNC, ge::STREAMACTIVE, ge::STREAMSWITCH, | |||||
ge::VARIABLE, ge::NOOP, ge::CONSTANT, ge::ENTER, | |||||
ge::REFENTER, ge::LOOPCOND, ge::NEXTITERATION, ge::REFNEXTITERATION, | |||||
ge::EXIT, ge::REFEXIT, ge::MEMCPYADDRASYNC}; | |||||
return (ge_local_set.find(type) != ge_local_set.end()); | return (ge_local_set.find(type) != ge_local_set.end()); | ||||
} | } | ||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const vector<SubGraphInfoPtr> &subgraphs, | |||||
ModelBuilder::ModelBuilder(ge::ComputeGraphPtr compute_graph, const Graph2SubGraphInfoList &subgraphs, | |||||
const map<string, int> &stream_max_parallel_num, bool hcom_parallel, int mode) | const map<string, int> &stream_max_parallel_num, bool hcom_parallel, int mode) | ||||
: mem_offset_(0), | : mem_offset_(0), | ||||
weight_offset_(kWeightsStartOffset), | weight_offset_(kWeightsStartOffset), | ||||
@@ -225,6 +226,25 @@ Status ModelBuilder::SetInputOutputDesc() { | |||||
if (!is_loop_graph_ && node_op_desc->GetType() == LOOPCOND) { | if (!is_loop_graph_ && node_op_desc->GetType() == LOOPCOND) { | ||||
is_loop_graph_ = true; | is_loop_graph_ = true; | ||||
} | } | ||||
// if user set input node format ND, the expected node for data and netoutput format is ND in | |||||
// final graph. | |||||
if ((domi::GetContext().format == domi::DOMI_TENSOR_ND) && | |||||
((node_op_desc->GetType() == DATA_TYPE) || (node_op_desc->GetType() == NETOUTPUT))) { | |||||
GELOGI("The node [%s] format should be set ND.", node_op_desc->GetName().c_str()); | |||||
auto inputDescsPtr = node_op_desc->GetAllInputsDescPtr(); | |||||
auto outputDescsPtr = node_op_desc->GetAllOutputsDescPtr(); | |||||
ge::Format format = ge::FORMAT_ND; | |||||
for (auto &inputDescPtr : inputDescsPtr) { | |||||
GE_CHECK_NOTNULL(inputDescPtr); | |||||
inputDescPtr->SetFormat(format); | |||||
inputDescPtr->SetOriginFormat(format); | |||||
} | |||||
for (auto &outputDescPtr : outputDescsPtr) { | |||||
GE_CHECK_NOTNULL(outputDescPtr); | |||||
outputDescPtr->SetFormat(format); | |||||
outputDescPtr->SetOriginFormat(format); | |||||
} | |||||
} | |||||
if (node_op_desc->GetType() == DATA_TYPE || node_op_desc->GetType() == AIPP_DATA_TYPE) { | if (node_op_desc->GetType() == DATA_TYPE || node_op_desc->GetType() == AIPP_DATA_TYPE) { | ||||
GELOGD("Data node: %s.", n->GetName().c_str()); | GELOGD("Data node: %s.", n->GetName().c_str()); | ||||
@@ -37,7 +37,7 @@ | |||||
namespace ge { | namespace ge { | ||||
class ModelBuilder { | class ModelBuilder { | ||||
public: | public: | ||||
ModelBuilder(ge::ComputeGraphPtr whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs, | |||||
ModelBuilder(ge::ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs, | |||||
const std::map<std::string, int> &stream_max_parallel_num, bool hcom_parallel, | const std::map<std::string, int> &stream_max_parallel_num, bool hcom_parallel, | ||||
int mode = static_cast<int>(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); | int mode = static_cast<int>(domi::BuildMode::GEN_TASK_WITHOUT_FUSION)); | ||||
@@ -85,7 +85,7 @@ class ModelBuilder { | |||||
ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
const std::vector<SubGraphInfoPtr> &subgraphs_; | |||||
const Graph2SubGraphInfoList &subgraphs_; | |||||
int64_t stream_num_; | int64_t stream_num_; | ||||
@@ -164,6 +164,9 @@ Status RunContextUtil::CreateRunContext(Model &model, const ComputeGraphPtr &gra | |||||
return ret; | return ret; | ||||
} | } | ||||
GELOGI("CreateRunContext: data_mem_base_ = %p, weight_mem_base_ = %p, memory_size = %lu, weight_size = %lu", | |||||
data_mem_base_, weight_mem_base_, data_mem_size_, weight_mem_size_); | |||||
run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_, | run_context_ = {rt_model_, nullptr, session_id, data_mem_size_, data_mem_base_, weight_mem_size_, | ||||
weight_mem_base_, buffer, stream_list_, event_list_, label_list_}; | weight_mem_base_, buffer, stream_list_, event_list_, label_list_}; | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -40,7 +40,7 @@ const uint32_t kMaxSwitchStreamNum = 1; | |||||
namespace ge { | namespace ge { | ||||
Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &max_parallel_num, bool hcom_parallel) { | Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &max_parallel_num, bool hcom_parallel) { | ||||
GELOGI("AssignLogicalStreams start."); | |||||
GELOGI("Assign logical streams start."); | |||||
GE_CHECK_NOTNULL(whole_graph_); | GE_CHECK_NOTNULL(whole_graph_); | ||||
GraphUtils::DumpGEGraph(whole_graph_, "BeforeAssignedLogicalStreams"); | GraphUtils::DumpGEGraph(whole_graph_, "BeforeAssignedLogicalStreams"); | ||||
GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); | GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "BeforeAssignedLogicalStreams"); | ||||
@@ -52,7 +52,6 @@ Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &m | |||||
} | } | ||||
const map<string, SchedulerConf> &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); | const map<string, SchedulerConf> &scheduler_confs = gelib->DNNEngineManagerObj().GetSchedulers(); | ||||
LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num, hcom_parallel); | LogicalStreamAllocator logical_allocator(scheduler_confs, max_parallel_num, hcom_parallel); | ||||
Status status = logical_allocator.Assign(whole_graph_, subgraphs_, stream_num_); | Status status = logical_allocator.Assign(whole_graph_, subgraphs_, stream_num_); | ||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
@@ -62,7 +61,7 @@ Status StreamAllocator::AssignLogicalStreams(const std::map<std::string, int> &m | |||||
GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); | GraphUtils::DumpGEGraph(whole_graph_, "AfterAssignedLogicalStreams"); | ||||
GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); | GraphUtils::DumpGEGraphToOnnx(*whole_graph_, "AfterAssignedLogicalStreams"); | ||||
GELOGI("AssignLogicalStreams success."); | |||||
GELOGI("Assign logical streams success."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -136,7 +135,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu | |||||
GELOGI("None of nodes need to assign stream, stream num is 0, it will cause error, so change it to 1"); | GELOGI("None of nodes need to assign stream, stream num is 0, it will cause error, so change it to 1"); | ||||
stream_num_ = 1; | stream_num_ = 1; | ||||
} | } | ||||
GELOGI("stream_num_: %ld, event_num_: %u.", stream_num_, event_num_); | |||||
GELOGI("stream num: %ld, event num: %u.", stream_num_, event_num_); | |||||
GELOGI("RefreshRealStream successfully."); | GELOGI("RefreshRealStream successfully."); | ||||
stream_num = stream_num_; | stream_num = stream_num_; | ||||
@@ -148,7 +147,7 @@ Status StreamAllocator::RefreshRealStream(int64_t &stream_num, int64_t &event_nu | |||||
// Split the stream according to the maximum number of nodes in the stream. | // Split the stream according to the maximum number of nodes in the stream. | ||||
Status StreamAllocator::SplitStreams() { | Status StreamAllocator::SplitStreams() { | ||||
if (stream_num_ == 0) { | if (stream_num_ == 0) { | ||||
GELOGI("stream_num_ is 0"); | |||||
GELOGI("The number of streams is 0 and no need to split."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -30,7 +30,7 @@ | |||||
namespace ge { | namespace ge { | ||||
class StreamAllocator { | class StreamAllocator { | ||||
public: | public: | ||||
StreamAllocator(ComputeGraphPtr whole_graph, const std::vector<SubGraphInfoPtr> &subgraphs) | |||||
StreamAllocator(ComputeGraphPtr whole_graph, const Graph2SubGraphInfoList &subgraphs) | |||||
: whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) {} | : whole_graph_(std::move(whole_graph)), subgraphs_(subgraphs) {} | ||||
StreamAllocator(const StreamAllocator &) = delete; | StreamAllocator(const StreamAllocator &) = delete; | ||||
StreamAllocator &operator=(const StreamAllocator &) = delete; | StreamAllocator &operator=(const StreamAllocator &) = delete; | ||||
@@ -75,7 +75,7 @@ class StreamAllocator { | |||||
bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; | bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; | ||||
ComputeGraphPtr whole_graph_; | ComputeGraphPtr whole_graph_; | ||||
const std::vector<SubGraphInfoPtr> &subgraphs_; | |||||
const Graph2SubGraphInfoList &subgraphs_; | |||||
int64_t stream_num_{0}; | int64_t stream_num_{0}; | ||||
uint32_t event_num_{0}; | uint32_t event_num_{0}; | ||||
@@ -29,19 +29,21 @@ static const int64_t kInvalidStream = -1; | |||||
namespace ge { | namespace ge { | ||||
StreamGraphOptimizer::~StreamGraphOptimizer() {} | StreamGraphOptimizer::~StreamGraphOptimizer() {} | ||||
void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, vector<SubGraphInfoPtr> &subgraph_infos) { | |||||
void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { | |||||
size_t node_size = comp_graph->GetDirectNodesSize(); | size_t node_size = comp_graph->GetDirectNodesSize(); | ||||
GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); | GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); | ||||
for (const auto &sub_graph_info : subgraph_infos) { | |||||
ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); | |||||
if (sub_graph == nullptr) { | |||||
continue; | |||||
} | |||||
for (ge::NodePtr &node : sub_graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); | |||||
if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { | |||||
node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); | |||||
node_size++; | |||||
for (const auto &subgraph_pair : subgraph_map) { | |||||
for (const auto &subgraph_info : subgraph_pair.second) { | |||||
ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); | |||||
if (subgraph == nullptr) { | |||||
continue; | |||||
} | |||||
for (ge::NodePtr &node : subgraph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL_EXEC(node->GetOpDesc(), return ); | |||||
if ((node->GetType() == END) || (node->GetType() == PLACEHOLDER)) { | |||||
node->GetOpDesc()->SetId(static_cast<int64_t>(node_size)); | |||||
node_size++; | |||||
} | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -71,67 +73,71 @@ bool StreamGraphOptimizer::IsSameStreamId(const ComputeGraphPtr &comp_graph) { | |||||
} | } | ||||
Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, | Status StreamGraphOptimizer::OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, | ||||
vector<SubGraphInfoPtr> &subgraph_infos, | |||||
Graph2SubGraphInfoList &subgraph_map, | |||||
struct RunContext &run_context) { | struct RunContext &run_context) { | ||||
Status ret = SUCCESS; | |||||
GELOGI("Begin to Get optimize streamed subgraph."); | |||||
GELOGI("Optimize streamed subgraph start."); | |||||
RefreshNodeId(comp_graph, subgraph_infos); | |||||
RefreshNodeId(comp_graph, subgraph_map); | |||||
std::shared_ptr<GELib> instance = ge::GELib::GetInstance(); | std::shared_ptr<GELib> instance = ge::GELib::GetInstance(); | ||||
GE_CHECK_NOTNULL(instance); | GE_CHECK_NOTNULL(instance); | ||||
for (auto &sub_graph_info : subgraph_infos) { | |||||
ComputeGraphPtr sub_graph = sub_graph_info->GetSubGraph(); | |||||
if (sub_graph == nullptr) { | |||||
continue; | |||||
} | |||||
for (const auto &subgraph_pair : subgraph_map) { | |||||
for (const auto &subgraph_info : subgraph_pair.second) { | |||||
ComputeGraphPtr subgraph = subgraph_info->GetSubGraph(); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
std::string engine_name = sub_graph_info->GetEngineName(); | |||||
GELOGI("Optimize subgraph %s", subgraph->GetName().c_str()); | |||||
vector<GraphOptimizerPtr> graph_optimizers; | |||||
if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { | |||||
instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); | |||||
GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, subgraph num: %zu, graph Optimizer num: %zu.", | |||||
sub_graph->GetName().c_str(), engine_name.c_str(), subgraph_infos.size(), graph_optimizers.size()); | |||||
std::string engine_name = subgraph_info->GetEngineName(); | |||||
auto nodes = sub_graph->GetDirectNode(); | |||||
if (nodes.empty()) { | |||||
continue; | |||||
} | |||||
if (!IsSameStreamId(sub_graph)) { | |||||
GELOGI("There are more than one stream in subgraph %s", sub_graph->GetName().c_str()); | |||||
continue; | |||||
} | |||||
OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (static_cast<size_t>(stream_id) >= run_context.graphStreamList.size()) { | |||||
GELOGE(FAILED, "stream_id is bigger than run_context.graphStreamList.size()"); | |||||
return FAILED; | |||||
} | |||||
run_context.stream = run_context.graphStreamList[stream_id]; | |||||
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", | |||||
sub_graph->GetName().c_str(), engine_name.c_str(), stream_id, | |||||
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream))); | |||||
for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { | |||||
GE_CHECK_NOTNULL(*iter); | |||||
ret = (*iter)->OptimizeStreamGraph(*sub_graph, run_context); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, | |||||
"[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " | |||||
"Optimizer num: %zu, ret: %u", | |||||
sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); | |||||
return ret; | |||||
vector<GraphOptimizerPtr> graph_optimizers; | |||||
if (instance->DNNEngineManagerObj().IsEngineRegistered(engine_name)) { | |||||
instance->OpsKernelManagerObj().GetGraphOptimizerByEngine(engine_name, graph_optimizers); | |||||
GELOGI("Subgraph: %s start optimize streamed graph. engineName: %s, graph Optimizer num: %zu.", | |||||
subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); | |||||
auto nodes = subgraph->GetDirectNode(); | |||||
if (nodes.empty()) { | |||||
continue; | |||||
} | |||||
if (!IsSameStreamId(subgraph)) { | |||||
GELOGI("There are more than one stream in subgraph %s", subgraph->GetName().c_str()); | |||||
continue; | |||||
} | |||||
OpDescPtr op_desc = nodes.at(0)->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
int64_t stream_id = op_desc->GetStreamId(); | |||||
if (static_cast<size_t>(stream_id) >= run_context.graphStreamList.size()) { | |||||
GELOGE(FAILED, "stream_id %ld is bigger than run_context.graphStreamList.size() %zu", stream_id, | |||||
run_context.graphStreamList.size()); | |||||
return FAILED; | |||||
} | |||||
run_context.stream = run_context.graphStreamList[stream_id]; | |||||
GELOGD("Subgraph has same stream id, subgraph: %s, engine_name: %s, stream_id: %ld, rtstream: %lu.", | |||||
subgraph->GetName().c_str(), engine_name.c_str(), stream_id, | |||||
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(run_context.stream))); | |||||
for (auto iter = graph_optimizers.begin(); iter != graph_optimizers.end(); ++iter) { | |||||
GE_CHECK_NOTNULL(*iter); | |||||
Status ret = (*iter)->OptimizeStreamGraph(*subgraph, run_context); | |||||
if (ret != SUCCESS) { | |||||
GELOGE( | |||||
ret, | |||||
"[optimizeStreamedSubGraph]: optimize streamed subgraph failed, subgraph: %s, engine_name: %s, graph " | |||||
"Optimizer num: %zu, ret: %u", | |||||
subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size(), ret); | |||||
return ret; | |||||
} | |||||
GELOGI( | |||||
"[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " | |||||
"Optimizer num: %zu!", | |||||
subgraph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); | |||||
} | } | ||||
GELOGI( | |||||
"[optimizeStreamedSubGraph]: optimize streamed subgraph success, subgraph: %s, engine_name: %s, graph " | |||||
"Optimizer num: %zu!", | |||||
sub_graph->GetName().c_str(), engine_name.c_str(), graph_optimizers.size()); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
return ret; | |||||
GELOGI("Optimize streamed subgraph success."); | |||||
return SUCCESS; | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -35,11 +35,11 @@ class StreamGraphOptimizer { | |||||
virtual ~StreamGraphOptimizer(); | virtual ~StreamGraphOptimizer(); | ||||
Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
Status OptimizeStreamedSubGraph(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map, | |||||
struct RunContext &run_context); | struct RunContext &run_context); | ||||
private: | private: | ||||
void RefreshNodeId(const ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list); | |||||
void RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map); | |||||
bool IsSameStreamId(const ComputeGraphPtr &comp_graph); | bool IsSameStreamId(const ComputeGraphPtr &comp_graph); | ||||
}; | }; | ||||
@@ -221,10 +221,8 @@ Status TaskGenerator::SaveL1fusionNodes(map<int64_t, std::vector<NodePtr>> &l1_f | |||||
if (call_check) { | if (call_check) { | ||||
auto input_group_id = *input_group_ids.begin(); | auto input_group_id = *input_group_ids.begin(); | ||||
if (group_id != input_group_id) { | if (group_id != input_group_id) { | ||||
GELOGE(INTERNAL_ERROR, | |||||
"L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", | |||||
GELOGW("L1Fusion: node[name:%s(%s) with group id:%ld and diff from it's input nodes's group id:%ld ", | |||||
name.c_str(), type.c_str(), group_id, input_group_id); | name.c_str(), type.c_str(), group_id, input_group_id); | ||||
return INTERNAL_ERROR; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -172,7 +172,7 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st | |||||
GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
NodePtr label_set = graph->AddNodeFront(op_desc); | |||||
NodePtr label_set = graph->AddNode(op_desc); | |||||
GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | ||||
// Link control edge to graph tail. | // Link control edge to graph tail. | ||||
@@ -202,7 +202,7 @@ NodePtr LabelMaker::AddLabelGotoEnter(const ComputeGraphPtr &graph, const std::s | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTO); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdEnter(graph, op_desc); | SetStreamIdEnter(graph, op_desc); | ||||
@@ -238,7 +238,7 @@ NodePtr LabelMaker::AddLabelGotoLeave(const ComputeGraphPtr &graph, const std::s | |||||
const NodePtr &node = *it; | const NodePtr &node = *it; | ||||
GE_CHECK_NOTNULL_EXEC(node, return nullptr); | GE_CHECK_NOTNULL_EXEC(node, return nullptr); | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTO); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, LABELGOTOEX); | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
SetStreamIdLeave(graph, op_desc); | SetStreamIdLeave(graph, op_desc); | ||||
@@ -366,6 +366,7 @@ NodePtr LabelMaker::AddLabelSwitchIndex(const ComputeGraphPtr &graph, const std: | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | OpDescPtr op_desc = MakeShared<OpDesc>(name, DATA); | ||||
GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | GE_CHECK_NOTNULL_EXEC(op_desc, return nullptr); | ||||
op_desc->SetStreamId(kInvalidStreamId); | |||||
GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | GELOGI("Data: Create node %s.", op_desc->GetName().c_str()); | ||||
if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) { | if (op_desc->AddOutputDesc(desc) != GRAPH_SUCCESS) { | ||||
@@ -20,11 +20,11 @@ | |||||
namespace { | namespace { | ||||
const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch | const uint32_t kCoreDim = 1; // for rtCpuKernelLaunch | ||||
const char *const kCpuTaskModelEnqueue = "modelEnqueue"; | const char *const kCpuTaskModelEnqueue = "modelEnqueue"; | ||||
const char *const kCpuTaskPrepareInput = "modelPrepareInput"; | |||||
const char *const kCpuTaskWaitEndGraph = "modelWaitEndGraph"; | const char *const kCpuTaskWaitEndGraph = "modelWaitEndGraph"; | ||||
const char *const kCpuTaskPrepareOutput = "modelPrepareOutput"; | |||||
const char *const kCpuTaskPrepareOutput = "bufferPrepareOutput"; | |||||
const char *const kCpuTaskModelDequeue = "modelDequeue"; | const char *const kCpuTaskModelDequeue = "modelDequeue"; | ||||
const char *const kCpuTaskModelRepeat = "modelRepeat"; | const char *const kCpuTaskModelRepeat = "modelRepeat"; | ||||
const char *const kCpuTaskZeroCopy = "zeroCpy"; | |||||
} // namespace | } // namespace | ||||
namespace ge { | namespace ge { | ||||
@@ -93,19 +93,19 @@ Status CpuTaskModelDequeue::Distribute() { | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief definiteness queue schedule, bind output queue to task. | |||||
/// @param [in] addr: NetOutput Op input tensor address. | |||||
/// @param [in] size: NetOutput Op input tensor size. | |||||
/// @param [in] in_mbuf: input mbuf addr for input data. | |||||
/// @brief definiteness queue schedule, zero copy. | |||||
/// @param [in] mbuf_list: input/output mbuf addr list for input/output data. | |||||
/// @param [in] outside_addrs: model input/output memory addr | |||||
/// @return: 0 for success / others for failed | /// @return: 0 for success / others for failed | ||||
/// | /// | ||||
Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf) { | |||||
Status CpuTaskZeroCopy::Init(std::vector<uintptr_t> &mbuf_list, | |||||
std::map<const void *, std::vector<void *>> &outside_addrs) { | |||||
if ((args_ != nullptr) || (args_size_ > 0)) { | if ((args_ != nullptr) || (args_size_ > 0)) { | ||||
GELOGE(FAILED, "Task already initialized, size: %u", args_size_); | GELOGE(FAILED, "Task already initialized, size: %u", args_size_); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
args_size_ = sizeof(PrepareInputInfo); | |||||
args_size_ = sizeof(AddrMapInfo); | |||||
rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | rtError_t status = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | ||||
if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | ||||
@@ -113,36 +113,99 @@ Status CpuTaskPrepareInput::Init(uintptr_t addr, uint32_t size, uintptr_t in_mbu | |||||
} | } | ||||
GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) | GE_PRINT_DYNAMIC_MEMORY(rtMalloc, "args data.", args_size_) | ||||
PrepareInputInfo prepare; | |||||
prepare.in_mbuf = in_mbuf; | |||||
prepare.mbuf_offset = 0; | |||||
prepare.data_size = size; | |||||
prepare.data_addr = addr; | |||||
status = rtMemcpy(args_, args_size_, &prepare, args_size_, RT_MEMCPY_HOST_TO_DEVICE); | |||||
AddrMapInfo addr_map_info; | |||||
for (const auto &addrs : outside_addrs) { | |||||
addr_map_info.addr_num += addrs.second.size(); | |||||
} | |||||
GELOGI("addr_map_info.addr_num is %zu", addr_map_info.addr_num); | |||||
// init src_addrs/dst_addrs | |||||
size_t index = 0; | |||||
vector<uint64_t> src_addrs; | |||||
vector<uint64_t> dst_addrs; | |||||
for (const auto &addrs : outside_addrs) { | |||||
for (size_t i = 0; i < addrs.second.size(); ++i) { | |||||
src_addrs.push_back(mbuf_list.at(index)); | |||||
dst_addrs.push_back(reinterpret_cast<uint64_t>(addrs.second.at(i))); | |||||
} | |||||
index++; | |||||
} | |||||
// malloc mem for src_addrs/dst_addrs, and copy data of src_addrs/dst_addrs | |||||
status = rtMalloc(&src_addr_, src_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); | |||||
if (status != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
return RT_FAILED; | |||||
} | |||||
status = rtMemcpy(src_addr_, src_addrs.size() * sizeof(uint64_t), src_addrs.data(), | |||||
src_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
status = rtMalloc(&dst_addr_, dst_addrs.size() * sizeof(uint64_t), RT_MEMORY_HBM); | |||||
if (status != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt malloc failed, status: 0x%x", status); | |||||
return RT_FAILED; | |||||
} | |||||
status = rtMemcpy(dst_addr_, dst_addrs.size() * sizeof(uint64_t), dst_addrs.data(), | |||||
dst_addrs.size() * sizeof(uint64_t), RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (status != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
return RT_FAILED; | |||||
} | |||||
// src_addr_list is init to src_addr, which is the point to src_addrs | |||||
if (!src_addrs.empty() && !dst_addrs.empty()) { | |||||
addr_map_info.src_addr_list = reinterpret_cast<uint64_t>(src_addr_); | |||||
addr_map_info.dst_addr_list = reinterpret_cast<uint64_t>(dst_addr_); | |||||
GELOGI("src_addr_list is %lu, dst_addr_list is %lu", addr_map_info.src_addr_list, addr_map_info.dst_addr_list); | |||||
} | |||||
status = rtMemcpy(args_, args_size_, &addr_map_info, sizeof(AddrMapInfo), RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (status != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt memcpy failed, status: 0x%x", status); | |||||
return RT_FAILED; | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status CpuTaskPrepareInput::Distribute() { | |||||
Status CpuTaskZeroCopy::Distribute() { | |||||
if ((args_ == nullptr) || (args_size_ == 0) || (stream_ == nullptr)) { | if ((args_ == nullptr) || (args_size_ == 0) || (stream_ == nullptr)) { | ||||
GELOGE(FAILED, "Task not initialized, distribute failed, size: %u", args_size_); | GELOGE(FAILED, "Task not initialized, distribute failed, size: %u", args_size_); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskPrepareInput, kCoreDim, args_, args_size_, nullptr, stream_); | |||||
rtError_t status = rtCpuKernelLaunch(nullptr, kCpuTaskZeroCopy, kCoreDim, args_, args_size_, nullptr, stream_); | |||||
if (status != RT_ERROR_NONE) { | if (status != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt CpuKernelLaunch PrepareInput failed, status: 0x%X", status); | |||||
GELOGE(RT_FAILED, "Call rt CpuKernelLaunch ZeroCopy failed, status: 0x%X", status); | |||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
GELOGI("Cpu kernel launch prepare input task success."); | |||||
GELOGI("Cpu kernel launch zero copy task success."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
CpuTaskZeroCopy::~CpuTaskZeroCopy() { | |||||
if (src_addr_ == nullptr && dst_addr_ == nullptr) { | |||||
return; | |||||
} | |||||
if (src_addr_ != nullptr) { | |||||
rtError_t status = rtFree(src_addr_); | |||||
if (status != RT_ERROR_NONE) { | |||||
GELOGW("Call rt free failed, status: 0x%x", status); | |||||
} | |||||
} | |||||
if (dst_addr_ != nullptr) { | |||||
rtError_t status = rtFree(dst_addr_); | |||||
if (status != RT_ERROR_NONE) { | |||||
GELOGW("Call rt free failed, status: 0x%x", status); | |||||
} | |||||
} | |||||
src_addr_ = nullptr; | |||||
dst_addr_ = nullptr; | |||||
} | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief definiteness queue schedule, bind output queue to task. | /// @brief definiteness queue schedule, bind output queue to task. | ||||
@@ -47,6 +47,13 @@ struct PrepareOutputInfo { | |||||
uintptr_t out_mbuf; // output mbuf addr | uintptr_t out_mbuf; // output mbuf addr | ||||
}; | }; | ||||
// For AICPU task "modelZeroCopy" | |||||
struct AddrMapInfo { | |||||
uint32_t addr_num = 0; | |||||
uint64_t src_addr_list; | |||||
uint64_t dst_addr_list; | |||||
}; | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief CpuTask base, inherit from TaskInfo used for manage. | /// @brief CpuTask base, inherit from TaskInfo used for manage. | ||||
@@ -78,17 +85,21 @@ class CpuTaskModelDequeue : public CpuTaskInfo { | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief definiteness queue schedule, bind output queue to task. | |||||
/// @brief definiteness queue schedule, zero copy. | |||||
/// | /// | ||||
class CpuTaskPrepareInput : public CpuTaskInfo { | |||||
class CpuTaskZeroCopy : public CpuTaskInfo { | |||||
public: | public: | ||||
explicit CpuTaskPrepareInput(rtStream_t stream) : CpuTaskInfo(stream) {} | |||||
~CpuTaskPrepareInput() override {} | |||||
explicit CpuTaskZeroCopy(rtStream_t stream) : CpuTaskInfo(stream) {} | |||||
~CpuTaskZeroCopy() override; | |||||
Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } | Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override { return SUCCESS; } | ||||
Status Init(uintptr_t addr, uint32_t size, uintptr_t in_mbuf); | |||||
Status Init(std::vector<uintptr_t> &mbuf_list, std::map<const void *, std::vector<void *>> &outside_addrs); | |||||
Status Distribute() override; | Status Distribute() override; | ||||
private: | |||||
void *src_addr_ = nullptr; | |||||
void *dst_addr_ = nullptr; | |||||
}; | }; | ||||
/// | /// | ||||
@@ -340,13 +340,6 @@ class DavinciModel { | |||||
vector<InputOutputDescInfo> &output_desc, | vector<InputOutputDescInfo> &output_desc, | ||||
std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats); | std::vector<uint32_t> &inputFormats, std::vector<uint32_t> &output_formats); | ||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief copy input data to model | |||||
/// @return Status | |||||
/// | |||||
Status CopyInputDataToModel(const std::vector<DataBuffer> &data, uint32_t data_op_index, bool device_data); | |||||
Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data); | Status ReturnResult(uint32_t data_id, const bool rslt_flg, const bool seq_end_flg, OutputData *output_data); | ||||
Status ReturnNoOutput(uint32_t data_id); | Status ReturnNoOutput(uint32_t data_id); | ||||
@@ -413,20 +406,6 @@ class DavinciModel { | |||||
/// | /// | ||||
uint32_t GetDeviceId() const { return device_id_; } | uint32_t GetDeviceId() const { return device_id_; } | ||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief Set Train Mode | |||||
/// @return void | |||||
/// | |||||
void SetTrainMode(bool mode) { is_train_mode_ = mode; } | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief Get Train Mode | |||||
/// @return bool true | |||||
/// | |||||
bool GetTrainMode() { return is_train_mode_; } | |||||
GeModelPtr GetGeModel() { return ge_model_; } | GeModelPtr GetGeModel() { return ge_model_; } | ||||
const RuntimeParam &GetRuntimeParam() { return runtime_param_; } | const RuntimeParam &GetRuntimeParam() { return runtime_param_; } | ||||
@@ -519,15 +498,14 @@ class DavinciModel { | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief Copy Data addr to model for direct use. | /// @brief Copy Data addr to model for direct use. | ||||
/// @param [in] const vector<void *> &addrs: model input memory addr list. | |||||
/// @param [in] const vector<uint32_t> &sizes: model input memory size list. | |||||
/// @param [in] const std::map<uint32_t, std::pair<int64_t, void *>> &data_info: model memory addr/size list. | |||||
/// @param [in] const std::vector<DataBuffer> &blobs: user input data list. | /// @param [in] const std::vector<DataBuffer> &blobs: user input data list. | ||||
/// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input | /// @param [in] bool is_dynamic_input: whether is dynamic input, true: is dynamic input; false: not is dynamic input | ||||
/// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy | /// @param [in] ZeroCopyMode zero_copy_mode: input zero copy or output zero copy | ||||
/// @param [in] string batch_label: batch label for multi-batch scenes | /// @param [in] string batch_label: batch label for multi-batch scenes | ||||
/// @return SUCCESS handle successfully / others handle failed | /// @return SUCCESS handle successfully / others handle failed | ||||
/// | /// | ||||
Status ZeroCopyBlobs(const std::vector<void *> &addr_list, const std::vector<int64_t> &size_list, | |||||
Status ZeroCopyBlobs(const std::map<uint32_t, std::pair<int64_t, void *>> &data_info, | |||||
const std::vector<DataBuffer> &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, | const std::vector<DataBuffer> &blobs, bool is_dynamic_input, ZeroCopyMode zero_copy_mode, | ||||
string batch_label); | string batch_label); | ||||
@@ -610,11 +588,9 @@ class DavinciModel { | |||||
/// @brief Data Op Initialize. | /// @brief Data Op Initialize. | ||||
/// @param [in] NodePtr: Data Op. | /// @param [in] NodePtr: Data Op. | ||||
/// @param [in/out] data_op_index: NetOutput addr size info. | /// @param [in/out] data_op_index: NetOutput addr size info. | ||||
/// @param [in/out] input_data_info: Data index and addr info {index, {size, addr}}. | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status InitDataOp(const NodePtr &node, uint32_t &data_op_index, | |||||
std::map<uint32_t, std::pair<int64_t, void *>> &input_data_info); | |||||
Status InitDataOp(const NodePtr &node, uint32_t &data_op_index); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -632,20 +608,28 @@ class DavinciModel { | |||||
/// | /// | ||||
Status InitNetOutput(const OpDescPtr &op_desc); | Status InitNetOutput(const OpDescPtr &op_desc); | ||||
/// | |||||
/// @ingroup ge | |||||
/// @brief Make Input and Output addr for feature use. | |||||
/// @param [in] input_data_info: Data index and addr info {index, {size, addr}}. | |||||
/// @return Status | |||||
/// | |||||
Status CombineDataInfo(const std::map<uint32_t, std::pair<int64_t, void *>> &input_data_info); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief Constant Op Init. | /// @brief Constant Op Init. | ||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status InitConstant(const ConstOpDescPtr &op_desc) const; | |||||
Status InitConstant(const OpDescPtr &op_desc); | |||||
Status InitVariable(const OpDescPtr &op_desc); | |||||
Status InitEndGraph(const OpDescPtr &op_desc); | |||||
/// @ingroup ge | |||||
/// @brief LabelSet Op Initialize. | |||||
/// @param [in] op_desc: LabelSet Op descriptor. | |||||
/// @return Status | |||||
Status InitLabelSet(const OpDescPtr &op_desc); | |||||
Status InitStreamSwitch(const OpDescPtr &op_desc); | |||||
Status InitStreamActive(const OpDescPtr &op_desc); | |||||
Status InitStreamSwitchN(const OpDescPtr &op_desc); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
@@ -662,7 +646,7 @@ class DavinciModel { | |||||
/// @brief Init model stream for NN model. | /// @brief Init model stream for NN model. | ||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status InitModelStream(rtStream_t stream, bool async_mode); | |||||
Status InitModelStream(rtStream_t stream); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -678,12 +662,16 @@ class DavinciModel { | |||||
/// | /// | ||||
Status BindInputQueue(); | Status BindInputQueue(); | ||||
Status CpuTaskModelZeroCopy(std::vector<uintptr_t> &mbuf_list, | |||||
std::map<const void *, std::vector<void *>> &outside_addrs); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief ACL, Bind NetOutput Op addr to output queue. | /// @brief ACL, Bind NetOutput Op addr to output queue. | ||||
/// @return: 0 for success / others for fail | /// @return: 0 for success / others for fail | ||||
/// | /// | ||||
Status BindOutputQueue(); | Status BindOutputQueue(); | ||||
Status CpuModelPrepareOutput(uintptr_t addr, uint32_t size); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -692,13 +680,6 @@ class DavinciModel { | |||||
/// | /// | ||||
Status BindActiveStream(); | Status BindActiveStream(); | ||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief insert active_stream_indication_ | |||||
/// @return Status | |||||
/// | |||||
Status MarkActiveStream(const OpDescPtr &op_desc); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief definiteness queue schedule, bind input queue to task. | /// @brief definiteness queue schedule, bind input queue to task. | ||||
@@ -707,7 +688,7 @@ class DavinciModel { | |||||
/// @param [in] size: Data Op output tensor size. | /// @param [in] size: Data Op output tensor size. | ||||
/// @return: 0 for success / others for fail | /// @return: 0 for success / others for fail | ||||
/// | /// | ||||
Status CpuModelDequeue(uint32_t queue_id, uintptr_t addr, uint32_t size); | |||||
Status CpuModelDequeue(uint32_t queue_id); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
@@ -734,6 +715,8 @@ class DavinciModel { | |||||
/// | /// | ||||
Status CpuWaitEndGraph(); | Status CpuWaitEndGraph(); | ||||
Status BindEnqueue(); | |||||
Status CpuModelEnqueue(uint32_t queue_id, uintptr_t out_mbuf); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief definiteness queue schedule, repeat run model. | /// @brief definiteness queue schedule, repeat run model. | ||||
@@ -783,10 +766,8 @@ class DavinciModel { | |||||
vector<OpDescPtr> variable_op_list_; | vector<OpDescPtr> variable_op_list_; | ||||
vector<int64_t> output_size_list_; // Init by NetOutput Input Tensor | |||||
vector<void *> output_addr_list_; // Init by NetOutput Input Tensor | |||||
vector<int64_t> input_size_list_; // Init by Data Output Tensor | |||||
vector<void *> input_addr_list_; // Init by Data Output Tensor | |||||
std::map<uint32_t, std::pair<int64_t, void *>> input_data_info_; // Init by Data Output Tensor | |||||
std::map<uint32_t, std::pair<int64_t, void *>> output_data_info_; // Init by NetOutput Input Tensor | |||||
// output op: save cce op actual needed memory size | // output op: save cce op actual needed memory size | ||||
vector<int64_t> output_memory_size_list_; | vector<int64_t> output_memory_size_list_; | ||||
@@ -813,6 +794,7 @@ class DavinciModel { | |||||
vector<rtEvent_t> event_list_; | vector<rtEvent_t> event_list_; | ||||
vector<rtLabel_t> label_list_; | vector<rtLabel_t> label_list_; | ||||
set<uint32_t> label_id_indication_; | |||||
std::mutex outside_addrs_mutex_; | std::mutex outside_addrs_mutex_; | ||||
std::map<const void *, std::vector<void *>> input_outside_addrs_; | std::map<const void *, std::vector<void *>> input_outside_addrs_; | ||||
@@ -830,6 +812,8 @@ class DavinciModel { | |||||
bool is_inner_model_stream_; | bool is_inner_model_stream_; | ||||
bool is_async_mode_; // For NN execute, Async mode use rtMemcpyAsync on rt_model_stream_. | |||||
// ACL queue schedule, save queue ids for Init. | // ACL queue schedule, save queue ids for Init. | ||||
std::vector<TaskInfoPtr> cpu_task_list_; | std::vector<TaskInfoPtr> cpu_task_list_; | ||||
std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller. | std::vector<uint32_t> input_queue_ids_; // input queue ids created by caller. | ||||
@@ -847,8 +831,6 @@ class DavinciModel { | |||||
uint32_t device_id_; | uint32_t device_id_; | ||||
bool is_train_mode_; | |||||
std::mutex flowctrl_op_index_internal_map_mutex_; | std::mutex flowctrl_op_index_internal_map_mutex_; | ||||
std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_; | std::map<uint32_t, uint32_t> flowctrl_op_index_internal_map_; | ||||
std::set<uint32_t> active_stream_indication_; | std::set<uint32_t> active_stream_indication_; | ||||
@@ -358,26 +358,17 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<Tensor | |||||
input_data.timestamp = 0; | input_data.timestamp = 0; | ||||
input_data.index = 0; | input_data.index = 0; | ||||
std::size_t index = 0; | |||||
for (const auto &op : model->GetDataList()) { | |||||
GE_CHECK_NOTNULL(op); | |||||
GE_CHECK_GE(inputs.size(), 1); | |||||
GE_CHECK_GE(inputs.size() - 1, index); | |||||
for (size_t i = 0; i < inputs.size(); ++i) { | |||||
DataBuffer data; | DataBuffer data; | ||||
data.data = inputs[index].data.data; | |||||
data.length = inputs[index].data.length; | |||||
data.data = inputs[i].data.data; | |||||
data.length = inputs[i].data.length; | |||||
input_data.blobs.push_back(data); | input_data.blobs.push_back(data); | ||||
index++; | |||||
} | } | ||||
CHECK_FALSE_EXEC(input_data.blobs.size() >= inputs.size(), | |||||
GELOGW("cur_inputs size = %zu, inputs size = %zu.", input_data.blobs.size(), inputs.size());); | |||||
OutputData output_data; | OutputData output_data; | ||||
output_data.model_id = model_id; | output_data.model_id = model_id; | ||||
output_data.index = 0; | output_data.index = 0; | ||||
for (size_t i = 0; i < outputs.size(); i++) { | |||||
for (size_t i = 0; i < outputs.size(); ++i) { | |||||
DataBuffer data; | DataBuffer data; | ||||
data.data = outputs[i].data.data; | data.data = outputs[i].data.data; | ||||
data.length = outputs[i].data.length; | data.length = outputs[i].data.length; | ||||
@@ -675,6 +666,15 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
break; | break; | ||||
} | } | ||||
davinci_model->SetId(model_id); | davinci_model->SetId(model_id); | ||||
int32_t device_id = 0; | |||||
rtError_t rt_ret = rtGetDevice(&device_id); | |||||
if (rt_ret != RT_ERROR_NONE || device_id < 0) { | |||||
GELOGE(RT_FAILED, "Call rtGetDevice failed, ret = 0x%X, device_id = %d.", rt_ret, device_id); | |||||
return FAILED; | |||||
} | |||||
davinci_model->SetDeviceId(device_id); | |||||
ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); | ret = davinci_model->Init(dev_ptr, mem_size, weight_ptr, weight_size); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "DavinciInit failed."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(ret != SUCCESS, break, "DavinciInit failed."); | ||||
@@ -51,27 +51,6 @@ bool ModelUtils::IsOutput(ConstOpDescPtr op_desc) { | |||||
return false; | return false; | ||||
} | } | ||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief Check is the Input need trans code. | |||||
/// @return bool | |||||
/// | |||||
bool ModelUtils::IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index) { | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return false); | |||||
const auto &input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(tensor_index)); | |||||
const auto &output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(tensor_index)); | |||||
GE_CHECK_NOTNULL_EXEC(input_desc, return false); | |||||
GE_CHECK_NOTNULL_EXEC(output_desc, return false); | |||||
if ((output_desc->GetFormat() == FORMAT_NC1HWC0) && (output_desc->GetDataType() == DT_INT8)) { | |||||
// AIPP input, add attribute in data op to tag aipp | |||||
return false; | |||||
} | |||||
return (input_desc->GetFormat() != output_desc->GetFormat()) || | |||||
(input_desc->GetDataType() != output_desc->GetDataType()); | |||||
} | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief Get input size. | /// @brief Get input size. | ||||
@@ -398,6 +377,8 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); | GE_CHK_STATUS(TensorUtils::GetDataOffset(tensor_desc, data_offset)); | ||||
uint8_t *weight_addr = static_cast<uint8_t *>(weight_base + data_offset - logic_weight_base); | uint8_t *weight_addr = static_cast<uint8_t *>(weight_base + data_offset - logic_weight_base); | ||||
v_input_data_addr.push_back(weight_addr); | v_input_data_addr.push_back(weight_addr); | ||||
GELOGI("[IMAS]GetInputDataAddrs graph_%u type[C] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, | |||||
op_desc->GetName().c_str(), i, weight_addr); | |||||
}); | }); | ||||
non_const_index++; | non_const_index++; | ||||
continue; | continue; | ||||
@@ -411,7 +392,10 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
non_const_index++; | non_const_index++; | ||||
GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), | GE_IF_BOOL_EXEC(var_size != 0 && ge::VarManager::Instance(session_id)->IsVarAddr(input_offset), | ||||
uint8_t *variable_addr = var_base + input_offset - logic_var_base; | uint8_t *variable_addr = var_base + input_offset - logic_var_base; | ||||
v_input_data_addr.push_back(variable_addr); continue;); | |||||
v_input_data_addr.push_back(variable_addr); | |||||
GELOGI("[IMAS]GetInputDataAddrs graph_%u type[V] name[%s] input[%lu] memaddr[%p]", | |||||
model_param.graph_id, op_desc->GetName().c_str(), i, variable_addr); | |||||
continue;); | |||||
bool input_tensor = false; | bool input_tensor = false; | ||||
GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, | GE_IF_BOOL_EXEC(TensorUtils::GetInputTensor(op_desc->GetOutputDesc(i), input_tensor) != GRAPH_SUCCESS, | ||||
@@ -421,12 +405,14 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
uint8_t *mem_addr = nullptr; | uint8_t *mem_addr = nullptr; | ||||
// l1 fusion | // l1 fusion | ||||
if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | ||||
mem_addr = reinterpret_cast<uint8_t *>(input_offset); | |||||
mem_addr = reinterpret_cast<uint8_t *>(reinterpret_cast<intptr_t>(input_offset)); | |||||
v_input_data_addr.push_back(mem_addr); | v_input_data_addr.push_back(mem_addr); | ||||
} else { | } else { | ||||
mem_addr = static_cast<uint8_t *>(mem_base + input_offset - logic_mem_base); | mem_addr = static_cast<uint8_t *>(mem_base + input_offset - logic_mem_base); | ||||
v_input_data_addr.push_back(mem_addr); | v_input_data_addr.push_back(mem_addr); | ||||
} | } | ||||
GELOGI("[IMAS]GetInputDataAddrs graph_%u type[F] name[%s] input[%zu] memaddr[%p]", model_param.graph_id, | |||||
op_desc->GetName().c_str(), i, mem_addr); | |||||
} | } | ||||
return v_input_data_addr; | return v_input_data_addr; | ||||
@@ -487,12 +473,14 @@ vector<void *> ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C | |||||
uint8_t *mem_addr = nullptr; | uint8_t *mem_addr = nullptr; | ||||
// l1 fusion | // l1 fusion | ||||
if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | ||||
mem_addr = reinterpret_cast<uint8_t *>(v_output_offset[i]); | |||||
mem_addr = reinterpret_cast<uint8_t *>(reinterpret_cast<intptr_t>(v_output_offset[i])); | |||||
v_output_data_addr.push_back(mem_addr); | v_output_data_addr.push_back(mem_addr); | ||||
} else { | } else { | ||||
mem_addr = static_cast<uint8_t *>(mem_base + v_output_offset[i] - logic_mem_base); | mem_addr = static_cast<uint8_t *>(mem_base + v_output_offset[i] - logic_mem_base); | ||||
v_output_data_addr.push_back(mem_addr); | v_output_data_addr.push_back(mem_addr); | ||||
} | } | ||||
GELOGI("[IMAS]GetOutputDataAddrs graph_%u type[F] name[%s] output[%zu] memaddr[%p]", model_param.graph_id, | |||||
op_desc->GetName().c_str(), i, mem_addr); | |||||
} | } | ||||
return v_output_data_addr; | return v_output_data_addr; | ||||
} | } | ||||
@@ -530,7 +518,7 @@ vector<void *> ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param | |||||
if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | if (has_mem_type_attr && v_memory_type[i] != RT_MEMORY_HBM) { | ||||
v_workspace_data_addr.push_back(reinterpret_cast<uint8_t *>(v_workspace_offset[i])); | v_workspace_data_addr.push_back(reinterpret_cast<uint8_t *>(v_workspace_offset[i])); | ||||
GELOGI("L1Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, | GELOGI("L1Fusion: op: %s, GetWorkspaceDataAddrs mem_addr[workspace index %zu]:%p", op_desc->GetName().c_str(), i, | ||||
reinterpret_cast<uint8_t *>(v_workspace_offset[i])); | |||||
reinterpret_cast<uint8_t *>(reinterpret_cast<intptr_t>(v_workspace_offset[i]))); | |||||
} else { | } else { | ||||
int64_t workspace_offset = v_workspace_offset[i]; | int64_t workspace_offset = v_workspace_offset[i]; | ||||
int64_t workspace_bytes = v_workspace_bytes[i]; | int64_t workspace_bytes = v_workspace_bytes[i]; | ||||
@@ -558,6 +546,7 @@ Status ModelUtils::ConvertVirtualAddressToPhysical(uint8_t *virtual_address, uin | |||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
GELOGD("virtual_address=%p, physical_address=%p", virtual_address, physical_address); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -40,13 +40,6 @@ class ModelUtils { | |||||
/// | /// | ||||
static bool IsOutput(ConstOpDescPtr op_desc); | static bool IsOutput(ConstOpDescPtr op_desc); | ||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief Check is the Input need trans code. | |||||
/// @return bool | |||||
/// | |||||
static bool IsInputTensorNeedTrans(ConstOpDescPtr op_desc, size_t tensor_index); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief Get input size. | /// @brief Get input size. | ||||
@@ -38,6 +38,7 @@ Status EndGraphTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
} | } | ||||
model_ = davinci_model->GetRtModelHandle(); | model_ = davinci_model->GetRtModelHandle(); | ||||
GELOGI("InitEndGraphTaskInfo Init Success, model:%p, stream:%p", model_, stream_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -125,6 +125,7 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m | |||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
GELOGD("hccl_stream addr is=%p", stream); | |||||
hccl_stream_list_.push_back(stream); | hccl_stream_list_.push_back(stream); | ||||
davinci_model->PushHcclStream(stream); | davinci_model->PushHcclStream(stream); | ||||
} | } | ||||
@@ -245,6 +246,8 @@ void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { | |||||
GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); | GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); | ||||
return; | return; | ||||
} | } | ||||
GELOGI("The first address of the custom info, privateDef=%p.", private_def_); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -41,6 +41,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
} | } | ||||
auto kernel_ex_def = task_def.kernel_ex(); | auto kernel_ex_def = task_def.kernel_ex(); | ||||
const RuntimeParam &rts_param = davinci_model->GetRuntimeParam(); | |||||
// 1. Copy context from kernelExDef.private to workspace | // 1. Copy context from kernelExDef.private to workspace | ||||
uint32_t op_index = kernel_ex_def.op_index(); | uint32_t op_index = kernel_ex_def.op_index(); | ||||
@@ -50,12 +51,12 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (CopyTaskInfo(kernel_ex_def, davinci_model->GetRuntimeParam(), op_desc) != SUCCESS) { | |||||
if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { | |||||
GELOGE(FAILED, "copy task info to workspace failed."); | GELOGE(FAILED, "copy task info to workspace failed."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
const vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); | |||||
if (workspace_data_addrs.empty()) { | if (workspace_data_addrs.empty()) { | ||||
GELOGE(FAILED, "workspace_data_addrs is empty."); | GELOGE(FAILED, "workspace_data_addrs is empty."); | ||||
return FAILED; | return FAILED; | ||||
@@ -79,16 +80,16 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
uint64_t step_id_addr = 0; | uint64_t step_id_addr = 0; | ||||
OpDescPtr step_id_node = davinci_model->GetVariableOp(NODE_NAME_GLOBAL_STEP); | OpDescPtr step_id_node = davinci_model->GetVariableOp(NODE_NAME_GLOBAL_STEP); | ||||
if (step_id_node != nullptr) { | if (step_id_node != nullptr) { | ||||
vector<void *> v_step_id_addr = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), step_id_node); | |||||
vector<void *> v_step_id_addr = ModelUtils::GetOutputDataAddrs(rts_param, step_id_node); | |||||
if (!v_step_id_addr.empty()) { | if (!v_step_id_addr.empty()) { | ||||
step_id_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(v_step_id_addr[0])); | step_id_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(v_step_id_addr[0])); | ||||
} | } | ||||
} | } | ||||
// 3. Set workspaceaddr, inputOutputDataAddr | // 3. Set workspaceaddr, inputOutputDataAddr | ||||
uint64_t workspace_base_addr = reinterpret_cast<uint64_t>(workspace_data_addrs[0]); | |||||
vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
uint64_t workspace_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(workspace_data_addrs[0])); | |||||
const vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
const vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
vector<void *> io_addrs; | vector<void *> io_addrs; | ||||
io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | ||||
io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | ||||
@@ -132,7 +133,13 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast<void *>(&fwk_op_kernel), | rt_ret = rtMemcpy(kernel_buf_, sizeof(STR_FWK_OP_KERNEL), static_cast<void *>(&fwk_op_kernel), | ||||
sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); | sizeof(STR_FWK_OP_KERNEL), RT_MEMCPY_HOST_TO_DEVICE); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) | ||||
davinci_model->SetZeroCopyAddr(op_desc, io_addrs, input_output_addr_); | |||||
vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | |||||
const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | |||||
davinci_model->SetZeroCopyAddr(op_desc, virtual_io_addrs, input_output_addr_); | |||||
kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | ||||
davinci_model_ = davinci_model; | davinci_model_ = davinci_model; | ||||
@@ -25,6 +25,7 @@ class KernelExTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
KernelExTaskInfo() | KernelExTaskInfo() | ||||
: task_id_(0), | : task_id_(0), | ||||
stream_id_(0), | |||||
dump_flag_(RT_KERNEL_DEFAULT), | dump_flag_(RT_KERNEL_DEFAULT), | ||||
kernel_buf_size_(0), | kernel_buf_size_(0), | ||||
davinci_model_(nullptr), | davinci_model_(nullptr), | ||||
@@ -221,13 +221,13 @@ Status KernelTaskInfo::SuperKernelLaunch() { | |||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
// Call the fuse API | // Call the fuse API | ||||
skt::SuperKernel *superKernel; | |||||
skt::SuperKernel *superKernel = nullptr; | |||||
if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { | if (factory->FuseKernels(skt_kernel_list, skt_arg_list, skt_info_.last_block_dim, superKernel) != SUCCESS) { | ||||
GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); | GELOGE(RT_FAILED, "SuperKernelLaunch: fuse call failed"); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
// Launch a super kernel | // Launch a super kernel | ||||
if (superKernel->Launch(skt_info_.last_stream, true) != SUCCESS) { | |||||
if (superKernel->Launch(skt_info_.last_stream, RT_KERNEL_DUMPFLAG) != SUCCESS) { | |||||
GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); | GELOGE(RT_FAILED, "SuperKernelLaunch: launch failed"); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
@@ -341,6 +341,7 @@ Status KernelTaskInfo::Distribute() { | |||||
rtError_t rt_ret = RT_ERROR_NONE; | rtError_t rt_ret = RT_ERROR_NONE; | ||||
char *skt_enable_env = getenv("SKT_ENABLE"); | char *skt_enable_env = getenv("SKT_ENABLE"); | ||||
int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | ||||
bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | |||||
if (kernel_type_ == cce::ccKernelType::AI_CPU) { | if (kernel_type_ == cce::ccKernelType::AI_CPU) { | ||||
// blockDim is reserved parameter, set to 1 | // blockDim is reserved parameter, set to 1 | ||||
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(so_name_.c_str()), | ||||
@@ -348,11 +349,10 @@ Status KernelTaskInfo::Distribute() { | |||||
nullptr, stream_, dump_flag_); | nullptr, stream_, dump_flag_); | ||||
} else { | } else { | ||||
/* default: not skt launch */ | /* default: not skt launch */ | ||||
bool call_skt = ((env_flag != 0) || is_l1_fusion_enable_); | |||||
GELOGI( | GELOGI( | ||||
"KernelTaskInfo Distribute Start, sktenable:%ld taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " | |||||
"KernelTaskInfo Distribute Start, sktenable:%d taskid:%u sktid:%u last_sktid:%u stubfunc_name:%s " | |||||
"stubfunc:%p blockdim:%u stream:%p", | "stubfunc:%p blockdim:%u stream:%p", | ||||
env_flag, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
call_skt, task_id_, skt_id_, skt_info_.last_task_id, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
// l1 fusion enable and env flag open (kCloseSkt for skt debug) | // l1 fusion enable and env flag open (kCloseSkt for skt debug) | ||||
if (call_skt && (env_flag != kCloseSkt)) { | if (call_skt && (env_flag != kCloseSkt)) { | ||||
GE_RETURN_IF_ERROR(SuperKernelDistribute()); | GE_RETURN_IF_ERROR(SuperKernelDistribute()); | ||||
@@ -371,7 +371,7 @@ Status KernelTaskInfo::Distribute() { | |||||
GELOGI( | GELOGI( | ||||
"KernelTaskInfo Distribute Success. sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " | "KernelTaskInfo Distribute Success. sktenable:%d taskid:%d sktid:%d stubfunc_name:%s stubfunc:%p " | ||||
"blockdim:%d stream:%p", | "blockdim:%d stream:%p", | ||||
env_flag, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
call_skt, task_id_, skt_id_, stub_func_name_.c_str(), stub_func_, block_dim_, stream_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -423,12 +423,12 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, | |||||
stub_func_ = const_cast<char *>(bin_file_key); | stub_func_ = const_cast<char *>(bin_file_key); | ||||
} | } | ||||
const vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
const vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
const vector<void *> workspace_data_addrs = | |||||
ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
vector<void *> tensor_device_addrs; | |||||
const RuntimeParam &rts_param = davinci_model->GetRuntimeParam(); | |||||
const vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
const vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
const vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); | |||||
vector<void *> tensor_device_addrs; | |||||
tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | tensor_device_addrs.insert(tensor_device_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | ||||
tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); | tensor_device_addrs.insert(tensor_device_addrs.end(), output_data_addrs.begin(), output_data_addrs.end()); | ||||
tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | tensor_device_addrs.insert(tensor_device_addrs.end(), workspace_data_addrs.begin(), workspace_data_addrs.end()); | ||||
@@ -468,7 +468,13 @@ Status KernelTaskInfo::InitTVMTask(DavinciModel *davinci_model, uint16_t offset, | |||||
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + offset + sizeof(void *) * input_data_addrs.size()); | reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + offset + sizeof(void *) * input_data_addrs.size()); | ||||
} | } | ||||
davinci_model_->SetZeroCopyAddr(op_desc, tensor_device_addrs, static_cast<char *>(args_) + offset); | |||||
vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | |||||
const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | |||||
davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, static_cast<char *>(args_) + offset); | |||||
// update origin l2 data | // update origin l2 data | ||||
string sm_desc = kernel_def.sm_desc(); | string sm_desc = kernel_def.sm_desc(); | ||||
char *sm_contrl = nullptr; | char *sm_contrl = nullptr; | ||||
@@ -516,6 +522,7 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
} | } | ||||
auto op_desc = iter->second; | auto op_desc = iter->second; | ||||
const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | |||||
const domi::KernelContext &context = kernel_def.context(); | const domi::KernelContext &context = kernel_def.context(); | ||||
const uint32_t kCustomAicpuArgsLen = 5; | const uint32_t kCustomAicpuArgsLen = 5; | ||||
@@ -534,11 +541,8 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
ctx_.argsOffset[i] = (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[i]; | ctx_.argsOffset[i] = (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[i]; | ||||
} | } | ||||
const std::vector<void *> input_data_addrs = | |||||
ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
const std::vector<void *> output_data_addrs = | |||||
ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
const std::vector<void *> input_data_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
const std::vector<void *> output_data_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), | Status ret = StoreInputOutputTensor(input_data_addrs, output_data_addrs, ModelUtils::GetInputDescs(op_desc), | ||||
ModelUtils::GetOutputDescs(op_desc)); | ModelUtils::GetOutputDescs(op_desc)); | ||||
@@ -583,15 +587,15 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
} | } | ||||
} | } | ||||
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[0])) = | ||||
reinterpret_cast<uint64_t>(custom_info_.input_descs); // arg 0 | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_descs)); // arg 0 | |||||
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[1])) = | ||||
reinterpret_cast<uint64_t>(custom_info_.input_addrs); // arg 1 | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.input_addrs)); // arg 1 | |||||
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[2])) = | ||||
reinterpret_cast<uint64_t>(custom_info_.output_descs); // arg 2 | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_descs)); // arg 2 | |||||
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[3])) = | ||||
reinterpret_cast<uint64_t>(custom_info_.output_addrs); // arg 3 | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.output_addrs)); // arg 3 | |||||
*(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) = | *(reinterpret_cast<uint64_t *>(args + ctx_.argsOffset[4])) = | ||||
reinterpret_cast<uint64_t>(custom_info_.attr_handle); // arg 4 | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(custom_info_.attr_handle)); // arg 4 | |||||
rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | rt_ret = rtMalloc(&args_, args_size_, RT_MEMORY_HBM); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
@@ -606,8 +610,10 @@ Status KernelTaskInfo::InitAICPUCustomTask(const std::map<uint32_t, std::shared_ | |||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
davinci_model_->SetZeroCopyAddr(op_desc, input_data_addrs, custom_info_.input_addrs); | |||||
davinci_model_->SetZeroCopyAddr(op_desc, output_data_addrs, custom_info_.output_addrs); | |||||
const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
davinci_model_->SetZeroCopyAddr(op_desc, virtual_in_addrs, custom_info_.input_addrs); | |||||
davinci_model_->SetZeroCopyAddr(op_desc, virtual_out_addrs, custom_info_.output_addrs); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -714,8 +720,10 @@ Status KernelTaskInfo::InitAicpuTask(const std::map<uint32_t, OpDescPtr> &op_lis | |||||
} | } | ||||
OpDescPtr op_desc = iter->second; | OpDescPtr op_desc = iter->second; | ||||
vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | |||||
vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | |||||
vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | |||||
vector<void *> io_addrs; | vector<void *> io_addrs; | ||||
io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | io_addrs.insert(io_addrs.end(), input_addrs.begin(), input_addrs.end()); | ||||
io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | io_addrs.insert(io_addrs.end(), output_addrs.begin(), output_addrs.end()); | ||||
@@ -752,7 +760,13 @@ Status KernelTaskInfo::InitAicpuTask(const std::map<uint32_t, OpDescPtr> &op_lis | |||||
sizeof(void *) * input_addrs.size()); | sizeof(void *) * input_addrs.size()); | ||||
} | } | ||||
davinci_model_->SetZeroCopyAddr(op_desc, io_addrs, static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead)); | |||||
vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | |||||
const vector<void *> virtual_in_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc, false); | |||||
const vector<void *> virtual_out_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc, false); | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_in_addrs.begin(), virtual_in_addrs.end()); | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | |||||
davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, | |||||
static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead)); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -977,7 +991,7 @@ Status KernelTaskInfo::SetFlowtable(std::string &flowtable, const domi::KernelDe | |||||
*(reinterpret_cast<uint64_t *>( | *(reinterpret_cast<uint64_t *>( | ||||
args + (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[0])) = | args + (reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())))[0])) = | ||||
reinterpret_cast<uint64_t>(flowtable_); | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(flowtable_)); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -0,0 +1,149 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/load/new_model_manager/task_info/memcpy_addr_async_task_info.h" | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/load/new_model_manager/davinci_model.h" | |||||
namespace ge { | |||||
Status MemcpyAddrAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | |||||
GELOGI("MemcpyAddrAsyncTaskInfo Init Start."); | |||||
if (davinci_model == nullptr) { | |||||
GELOGE(PARAM_INVALID, "davinci_model is null!"); | |||||
return PARAM_INVALID; | |||||
} | |||||
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
auto memcpy_async_def = task_def.memcpy_async(); | |||||
uint64_t logic_dst = memcpy_async_def.dst(); | |||||
uint64_t logic_src = memcpy_async_def.src(); | |||||
dst_max_ = memcpy_async_def.dst_max(); | |||||
uint64_t update_base_addr = 0; | |||||
ret = GetUpdateBaseAddr(davinci_model, logic_src, update_base_addr); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
src_ = reinterpret_cast<uint8_t *>(update_base_addr + logic_src); | |||||
if (src_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "src_ is null!"); | |||||
return PARAM_INVALID; | |||||
} | |||||
uint64_t mem_base = reinterpret_cast<uint64_t>(davinci_model->MemBase()); | |||||
uint64_t logic_mem_base = davinci_model->GetRtBaseAddr(); | |||||
dst_ = reinterpret_cast<uint8_t *>(mem_base + (logic_dst - logic_mem_base)); | |||||
if (dst_ == nullptr) { | |||||
GELOGE(PARAM_INVALID, "dst_ is null!"); | |||||
return PARAM_INVALID; | |||||
} | |||||
count_ = memcpy_async_def.count(); | |||||
kind_ = memcpy_async_def.kind(); | |||||
// malloc args memory | |||||
size_t args_size = sizeof(void *); | |||||
rtError_t rt_ret = rtMalloc(&args_, args_size * 2, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
// copy orign src | |||||
GELOGI("src_args:%p, destMax:%zu, src_:%p, count=%zu, kind=%u", args_, args_size, src_, args_size, | |||||
RT_MEMCPY_HOST_TO_DEVICE); | |||||
rt_ret = rtMemcpy(args_, args_size, &src_, args_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api for src failed, ret: 0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
// copy orign dst | |||||
GELOGI("dst_args:%p, destMax:%zu, dst_:%p, count=%zu, kind=%u", | |||||
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + args_size), args_size, dst_, args_size, | |||||
RT_MEMCPY_HOST_TO_DEVICE); | |||||
rt_ret = rtMemcpy(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + args_size), args_size, &dst_, | |||||
args_size, RT_MEMCPY_HOST_TO_DEVICE); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api for dst failed, ret: 0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
GELOGI("InitMemcpyAddrAsyncTaskInfo, logic_src:%p, logic_dst:%p, src:%p, dst:%p, src_args:%p, dst_args:%p", | |||||
reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_src)), | |||||
reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_dst)), src_, dst_, args_, | |||||
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + args_size)); | |||||
return SUCCESS; | |||||
} | |||||
Status MemcpyAddrAsyncTaskInfo::Distribute() { | |||||
GELOGI("MemcpyAddrAsyncTaskInfo Distribute Start."); | |||||
GELOGI("Distribute MemcpyAddrAsync, dst_max:%lu, count:%lu, kind:%u.", dst_max_, count_, kind_); | |||||
rtError_t rt_ret = rtMemcpyAsync(reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(args_) + sizeof(void *)), | |||||
dst_max_, args_, count_, static_cast<rtMemcpyKind_t>(kind_), stream_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status MemcpyAddrAsyncTaskInfo::GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, | |||||
uint64_t &base_addr) { | |||||
GE_CHECK_NOTNULL(davinci_model); | |||||
uint64_t data_base_addr = | |||||
reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(davinci_model->MemBase())) - davinci_model->GetRtBaseAddr(); | |||||
uint64_t weight_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(davinci_model->WeightsMemBase())) - | |||||
davinci_model->GetRtWeightAddr(); | |||||
uint64_t var_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(davinci_model->VarMemBase())) - | |||||
davinci_model->GetRtVarAddr(); | |||||
uint64_t data_base_addr_start = davinci_model->GetRtBaseAddr(); | |||||
uint64_t data_base_addr_end = davinci_model->GetRtBaseAddr() + davinci_model->TotalMemSize(); | |||||
uint64_t wight_base_addr_start = davinci_model->GetRtWeightAddr(); | |||||
uint64_t wight_base_addr_end = davinci_model->GetRtWeightAddr() + davinci_model->TotalWeightsMemSize(); | |||||
uint64_t varible_base_addr_start = davinci_model->GetRtVarAddr(); | |||||
uint64_t varible_base_addr_end = davinci_model->GetRtVarAddr() + davinci_model->TotalVarMemSize(); | |||||
if ((data_base_addr_start <= update_addr) && (update_addr <= data_base_addr_end)) { | |||||
base_addr = data_base_addr; | |||||
GELOGI("The update_addr is data address."); | |||||
} else if ((wight_base_addr_start <= update_addr) && (update_addr <= wight_base_addr_end)) { | |||||
base_addr = weight_base_addr; | |||||
GELOGI("The update_addr is weight address."); | |||||
} else if ((varible_base_addr_start <= update_addr) && (update_addr <= varible_base_addr_end)) { | |||||
base_addr = var_base_addr; | |||||
GELOGI("The update_addr is variable address."); | |||||
} else if (update_addr != 0) { | |||||
base_addr = 0; | |||||
GELOGE(PARAM_INVALID, "The update_addr is abnormal."); | |||||
return PARAM_INVALID; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
REGISTER_TASK_INFO(RT_MODEL_TASK_MEMCPY_ADDR_ASYNC, MemcpyAddrAsyncTaskInfo); | |||||
} // namespace ge |
@@ -0,0 +1,55 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ | |||||
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ | |||||
#include "graph/load/new_model_manager/task_info/task_info.h" | |||||
namespace ge { | |||||
class MemcpyAddrAsyncTaskInfo : public TaskInfo { | |||||
public: | |||||
MemcpyAddrAsyncTaskInfo() : dst_(nullptr), dst_max_(0), src_(nullptr), args_(nullptr), count_(0), kind_(0) {} | |||||
~MemcpyAddrAsyncTaskInfo() override { | |||||
src_ = nullptr; | |||||
dst_ = nullptr; | |||||
if (args_ != nullptr) { | |||||
rtError_t ret = rtFree(args_); | |||||
if (ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", ret); | |||||
} | |||||
} | |||||
args_ = nullptr; | |||||
} | |||||
Status Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
Status Distribute() override; | |||||
private: | |||||
Status GetUpdateBaseAddr(DavinciModel *davinci_model, uint64_t update_addr, uint64_t &base_addr); | |||||
void *dst_; | |||||
uint64_t dst_max_; | |||||
void *src_; | |||||
void *args_; | |||||
uint64_t count_; | |||||
uint32_t kind_; | |||||
}; | |||||
} // namespace ge | |||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_MEMCPY_ADDR_ASYNC_TASK_INFO_H_ |
@@ -51,6 +51,9 @@ Status MemcpyAsyncTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *da | |||||
count_ = memcpy_async_def.count(); | count_ = memcpy_async_def.count(); | ||||
kind_ = memcpy_async_def.kind(); | kind_ = memcpy_async_def.kind(); | ||||
GELOGI("MemcpyAsyncTaskInfo Init Success, logic_src:%p, logic_dst:%p, src:%p, dst:%p", | |||||
reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_src)), | |||||
reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(logic_dst)), src_, dst_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -63,6 +63,8 @@ Status StreamActiveTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d | |||||
active_stream_ = davinci_model->GetStreamList()[active_stream_index_list[internal_index]]; | active_stream_ = davinci_model->GetStreamList()[active_stream_index_list[internal_index]]; | ||||
active_stream_id_ = stream_active_def.active_stream_id(); | active_stream_id_ = stream_active_def.active_stream_id(); | ||||
GELOGI("InitStreamActiveTaskInfo Init Success, index:%u, activeStream:%p, activeStreamID:%u.", internal_index, | |||||
active_stream_, active_stream_id_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -74,6 +76,8 @@ Status StreamActiveTaskInfo::Distribute() { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
GELOGI("StreamActiveTaskInfo Distribute Success. activeStreamID:%p.", active_stream_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -95,6 +95,10 @@ Status StreamSwitchTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *d | |||||
} | } | ||||
data_type_ = static_cast<rtSwitchDataType_t>(data_type); | data_type_ = static_cast<rtSwitchDataType_t>(data_type); | ||||
} | } | ||||
GELOGI("InitStreamSwitchTaskInfo Init Success, cond:%d, trueStream:%p, trueStreamID:%u, datatype:%d.", cond_, | |||||
true_stream_, true_stream_id_, data_type_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -105,6 +109,8 @@ Status StreamSwitchTaskInfo::Distribute() { | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
GELOGI("StreamSwitchTaskInfo Distribute Success. cond:%d, stream:%p, datatype:%d.", cond_, true_stream_, data_type_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -19,17 +19,17 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace skt { | namespace skt { | ||||
Status SuperKernel::Launch(rtStream_t stream, bool dump_flag) { | |||||
Status SuperKernel::Launch(rtStream_t stream, uint32_t dump_flag) { | |||||
const void *func_stub_ = this->GetFuncStub(); | const void *func_stub_ = this->GetFuncStub(); | ||||
const void *args[] = {this->GetNavTablePtr(), (const void *)this->GetNavTableSize()}; | |||||
const void *args[] = {this->GetNavTablePtr(), | |||||
reinterpret_cast<const void *>(reinterpret_cast<uintptr_t>(this->GetNavTableSize()))}; | |||||
void *device_args_addr = nullptr; | |||||
rtError_t rt_ret = rtMalloc((void **)&(device_args_addr), sizeof(args), RT_MEMORY_HBM); | |||||
rtError_t rt_ret = rtMalloc((void **)&(device_args_addr_), sizeof(args), RT_MEMORY_HBM); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | ||||
rt_ret = rtMemcpy((void *)device_args_addr, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); | |||||
rt_ret = rtMemcpy((void *)device_args_addr_, sizeof(args), (void *)args, sizeof(args), RT_MEMCPY_HOST_TO_DEVICE); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | ||||
rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr, sizeof(args), NULL, stream, | |||||
rt_ret = rtKernelLaunchWithFlag((void *const)func_stub_, block_dim_, device_args_addr_, sizeof(args), NULL, stream, | |||||
dump_flag); | dump_flag); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelLaunchWithFlag failied. error: 0x%X", rt_ret); | ||||
return FAILED;) | return FAILED;) | ||||
@@ -25,6 +25,7 @@ namespace ge { | |||||
namespace skt { | namespace skt { | ||||
class SuperKernel { | class SuperKernel { | ||||
private: | private: | ||||
void *device_args_addr_ = nullptr; | |||||
const void *func_stub_; | const void *func_stub_; | ||||
void *dev_nav_table_; | void *dev_nav_table_; | ||||
uint64_t nav_table_size_; | uint64_t nav_table_size_; | ||||
@@ -33,8 +34,18 @@ class SuperKernel { | |||||
public: | public: | ||||
SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) | SuperKernel(const void *stub, void *ptr, uint64_t sz, uint32_t dim) | ||||
: func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} | : func_stub_(stub), dev_nav_table_(ptr), nav_table_size_(sz), block_dim_(dim) {} | ||||
~SuperKernel() {} | |||||
Status Launch(rtStream_t stream, bool dump_flag); | |||||
~SuperKernel() { | |||||
// free memory when all releasing | |||||
if (device_args_addr_ != nullptr) { | |||||
GE_CHK_RT(rtFree(device_args_addr_)); | |||||
GELOGI("SKT: super_kernel args addr free."); | |||||
} | |||||
if (dev_nav_table_ != nullptr) { | |||||
GE_CHK_RT(rtFree(dev_nav_table_)); | |||||
GELOGI("SKT: super_kernel args addr free."); | |||||
} | |||||
} | |||||
Status Launch(rtStream_t stream, uint32_t dump_flag); | |||||
const void *GetFuncStub() const { return func_stub_; } | const void *GetFuncStub() const { return func_stub_; } | ||||
const void *GetNavTablePtr() const { return dev_nav_table_; } | const void *GetNavTablePtr() const { return dev_nav_table_; } | ||||
uint64_t GetNavTableSize() const { return nav_table_size_; } | uint64_t GetNavTableSize() const { return nav_table_size_; } | ||||
@@ -30,26 +30,26 @@ Status SuperKernelFactory::Init() { | |||||
rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); | rt_ret = rtGetFunctionByName(this->sk_stub_name_.c_str(), &this->func_stub_); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, | ||||
"rtGetFunctionByName " | "rtGetFunctionByName " | ||||
"failied. stub_func: %s", | |||||
"failed. stub_func: %s", | |||||
this->sk_stub_name_.c_str()); | this->sk_stub_name_.c_str()); | ||||
return FAILED;) | return FAILED;) | ||||
rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); | rt_ret = rtGetAddrByFun(this->func_stub_, &this->func_ptr_); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
if (this->use_physical_address_ != nullptr) { | if (this->use_physical_address_ != nullptr) { | ||||
void *skt_func = nullptr; | void *skt_func = nullptr; | ||||
rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); | rt_ret = rtKernelConfigTransArg(this->func_ptr_, sizeof(uint64_t), 0, &skt_func); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
GELOGD( | GELOGD( | ||||
"SKT: fuseKernels super_kernel_template subFunc %p, device func " | "SKT: fuseKernels super_kernel_template subFunc %p, device func " | ||||
"address %p, device physic PC %p", | "address %p, device physic PC %p", | ||||
(uint64_t)this->func_stub_, (uint64_t)this->func_ptr_, (uint64_t)skt_func); | |||||
this->func_stub_, this->func_ptr_, skt_func); | |||||
} else { | } else { | ||||
GELOGD( | GELOGD( | ||||
"SKT: fuseKernels super_kernel_template subFunc %p, device func " | "SKT: fuseKernels super_kernel_template subFunc %p, device func " | ||||
"address %p", | "address %p", | ||||
(uint64_t)this->func_stub_, (uint64_t)this->func_ptr_); | |||||
this->func_stub_, this->func_ptr_); | |||||
} | } | ||||
} | } | ||||
is_init_ = true; | is_init_ = true; | ||||
@@ -94,63 +94,66 @@ Status SuperKernelFactory::FuseKernels(const std::vector<void *> &stub_func_list | |||||
uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); | uint64_t nav_table_size = 2 * stub_func_list.size() * sizeof(int64_t); | ||||
rtError_t rt_ret; | rtError_t rt_ret; | ||||
void *hbm_nav_table_addr = nullptr; | |||||
if (this->use_physical_address_ != nullptr) { | if (this->use_physical_address_ != nullptr) { | ||||
for (unsigned i = 0; i < stub_func_list.size(); i++) { | for (unsigned i = 0; i < stub_func_list.size(); i++) { | ||||
void *sub_device_func = nullptr; | void *sub_device_func = nullptr; | ||||
rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
void *sub_device_func_pys = nullptr; | void *sub_device_func_pys = nullptr; | ||||
void *args_addr_pys = nullptr; | void *args_addr_pys = nullptr; | ||||
rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); | rt_ret = rtKernelConfigTransArg(sub_device_func, sizeof(uint64_t), 0, &sub_device_func_pys); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); | rt_ret = rtKernelConfigTransArg(args_addr_list[i], sizeof(uint64_t), 0, &args_addr_pys); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
GELOGD( | GELOGD( | ||||
"SKT: fuseKernels subFunc %p, device func address %p, device " | "SKT: fuseKernels subFunc %p, device func address %p, device " | ||||
"physic func address %p", | "physic func address %p", | ||||
stub_func_list[i], (uint64_t)sub_device_func, (uint64_t)sub_device_func_pys); | |||||
nav_table[i * 2] = (uint64_t)sub_device_func_pys / 4; | |||||
GELOGD("SKT: CALL offet %p", nav_table[i * 2]); | |||||
nav_table[i * 2 + 1] = (uint64_t)args_addr_pys; | |||||
stub_func_list[i], sub_device_func, sub_device_func_pys); | |||||
// store two uint64_t address | |||||
// address divided by 4 because of 32bits encoding, call offset will *4 when calculating | |||||
nav_table[i * 2] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func_pys)) / 4; | |||||
GELOGD("SKT: CALL offset %p", nav_table[i * 2]); | |||||
nav_table[i * 2 + 1] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_pys)); | |||||
GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | ||||
} | } | ||||
void *hbm_nav_table_addr = nullptr; | |||||
void *hbm_nav_table_addr_pys = nullptr; | void *hbm_nav_table_addr_pys = nullptr; | ||||
rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) | |||||
rt_ret = | rt_ret = | ||||
rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) | |||||
rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); | rt_ret = rtKernelConfigTransArg(hbm_nav_table_addr, sizeof(uint64_t), 0, &hbm_nav_table_addr_pys); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtKernelConfigTransArg failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", (uint64_t)hbm_nav_table_addr, | |||||
(uint64_t)hbm_nav_table_addr_pys); | |||||
GELOGD("SKT: hbm_nav_table_addr %p, hbm_nav_table_addr_pys %p", hbm_nav_table_addr, hbm_nav_table_addr_pys); | |||||
// Create the necessary metadata for the super kernel | // Create the necessary metadata for the super kernel | ||||
h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); | h = new SuperKernel(this->func_stub_, hbm_nav_table_addr_pys, nav_table_size, block_dim); | ||||
} else { | } else { | ||||
for (unsigned i = 0; i < stub_func_list.size(); i++) { | for (unsigned i = 0; i < stub_func_list.size(); i++) { | ||||
void *sub_device_func = nullptr; | void *sub_device_func = nullptr; | ||||
rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | rt_ret = rtGetAddrByFun(stub_func_list[i], &sub_device_func); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failied. error: 0x%X", rt_ret); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtGetAddrByFun failed. error: 0x%X", rt_ret); | |||||
return FAILED;) | return FAILED;) | ||||
GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], (uint64_t)sub_device_func); | |||||
nav_table[i * 2] = (uint64_t)sub_device_func / 4; | |||||
GELOGD("SKT: fuseKernels subFunc %p, device func address %p", stub_func_list[i], sub_device_func); | |||||
// store two uint64_t address | |||||
// address divided by 4 because of 32bits encoding, call offset will *4 when calculating | |||||
nav_table[i * 2] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(sub_device_func)) / 4; | |||||
GELOGD("SKT: CALL offet %p", nav_table[i * 2]); | GELOGD("SKT: CALL offet %p", nav_table[i * 2]); | ||||
nav_table[i * 2 + 1] = (uint64_t)args_addr_list[i]; | |||||
nav_table[i * 2 + 1] = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(args_addr_list[i])); | |||||
GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | GELOGD("SKT: fuseKernels args base address %p", nav_table[i * 2 + 1]); | ||||
} | } | ||||
void *hbm_nav_table_addr = nullptr; | |||||
rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | rt_ret = rtMalloc((void **)&hbm_nav_table_addr, nav_table_size, RT_MEMORY_HBM); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failied. error: 0x%X", rt_ret); return FAILED;) | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc failed. error: 0x%X", rt_ret); return FAILED;) | |||||
rt_ret = | rt_ret = | ||||
rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | rtMemcpy((void *)hbm_nav_table_addr, nav_table_size, (void *)nav_table, nav_table_size, RT_MEMCPY_HOST_TO_DEVICE); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failied. error: 0x%X", rt_ret); return FAILED;) | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy failed. error: 0x%X", rt_ret); return FAILED;) | |||||
// Create the necessary metadata for the super kernel | // Create the necessary metadata for the super kernel | ||||
h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); | h = new SuperKernel(this->func_stub_, hbm_nav_table_addr, nav_table_size, block_dim); | ||||
} | } | ||||
@@ -31,12 +31,12 @@ class SuperKernelFactory { | |||||
const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); | const char *use_physical_address_ = getenv("GE_USE_PHYSICAL_ADDRESS"); | ||||
bool is_init_ = false; | bool is_init_ = false; | ||||
SuperKernelFactory(){}; | SuperKernelFactory(){}; | ||||
~SuperKernelFactory(){}; | |||||
public: | public: | ||||
SuperKernelFactory(SuperKernelFactory const &) = delete; | SuperKernelFactory(SuperKernelFactory const &) = delete; | ||||
void operator=(SuperKernelFactory const &) = delete; | void operator=(SuperKernelFactory const &) = delete; | ||||
static SuperKernelFactory &GetInstance(); | static SuperKernelFactory &GetInstance(); | ||||
SuperKernelFactory(const std::string &sk_stub_name_, const std::string &bin_file); | |||||
Status Init(); | Status Init(); | ||||
Status Uninitialize(); | Status Uninitialize(); | ||||
Status FuseKernels(const std::vector<void *> &stub_func_list, const std::vector<void *> &args_addr_list, | Status FuseKernels(const std::vector<void *> &stub_func_list, const std::vector<void *> &args_addr_list, | ||||
@@ -33,6 +33,7 @@ | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
#include "graph/manager/util/rt_context_util.h" | |||||
#include "graph/common/transop_util.h" | #include "graph/common/transop_util.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
@@ -117,6 +118,7 @@ Status GraphManager::Initialize(const std::map<string, string> &options) { | |||||
} | } | ||||
graph_map_.clear(); | graph_map_.clear(); | ||||
cache_helper_map_.clear(); | |||||
init_flag_ = true; | init_flag_ = true; | ||||
thread_run_flag_ = true; | thread_run_flag_ = true; | ||||
@@ -180,6 +182,7 @@ Status GraphManager::Finalize() { | |||||
} | } | ||||
} | } | ||||
graph_map_.clear(); | graph_map_.clear(); | ||||
cache_helper_map_.clear(); | |||||
// graph context | // graph context | ||||
if (graph_context_ != nullptr) { | if (graph_context_ != nullptr) { | ||||
@@ -426,6 +429,13 @@ Status GraphManager::PreRun(const GraphNodePtr &graph_node, const std::vector<Ge | |||||
sub_graph_list[0]->SetSubGraph(merged_compute_graph); | sub_graph_list[0]->SetSubGraph(merged_compute_graph); | ||||
// set subgraphlist to graphnode | // set subgraphlist to graphnode | ||||
graph_node->SetSubGraph(sub_graph_list); | graph_node->SetSubGraph(sub_graph_list); | ||||
// when set incre build, save om model and var manager | |||||
auto save_ret = SaveCacheAfterBuild(graph_node->GetGraphId(), merged_compute_graph, ge_model); | |||||
if (save_ret != SUCCESS) { | |||||
GELOGW("Fail to save cache."); | |||||
} | |||||
// release rts generate context | |||||
RtContextUtil::GetInstance().DestroyrtContexts(); | |||||
GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun"); | GE_TIMESTAMP_END(PreRun, "GraphManager::PreRun"); | ||||
GEEVENT("[GEPERFTRACE] GE PreRun End"); | GEEVENT("[GEPERFTRACE] GE PreRun End"); | ||||
return ret; | return ret; | ||||
@@ -444,10 +454,14 @@ Status GraphManager::StartForRunGraph(const GraphNodePtr &graph_node, const std: | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
GeModelPtr ge_model = nullptr; | GeModelPtr ge_model = nullptr; | ||||
ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); | |||||
// check need incre build. | |||||
ret = IncreBuild(graph_node, ge_model); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "PreRun Failed."); | |||||
return ret; | |||||
ret = PreRun(graph_node, inputs, ge_models, ge_model, session_id); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "PreRun Failed."); | |||||
return ret; | |||||
} | |||||
} | } | ||||
ret = LoadGraph(ge_model, graph_node); | ret = LoadGraph(ge_model, graph_node); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
@@ -492,6 +506,90 @@ Status GraphManager::LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &g | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, | |||||
GeModelPtr &ge_model) { | |||||
auto graph_id = graph_node->GetGraphId(); | |||||
auto ret = cache_helper->LoadOmModelFromCache(ge_model); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to load om model from cache."); | |||||
if (cache_helper->ClearCache(graph_id) != SUCCESS) { | |||||
GELOGW("Fail to clear cache of graph %u.", graph_id); | |||||
} | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->RecoverVarManagerFromCache(); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to recover VarManager from cache."); | |||||
if (cache_helper->ClearCache(graph_id) != SUCCESS) { | |||||
GELOGW("Fail to clear cache of graph %u.", graph_id); | |||||
} | |||||
return FAILED; | |||||
} | |||||
ComputeGraphPtr compute_graph_in_model = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
if (compute_graph_in_model == nullptr) { | |||||
GELOGW("Error occurred when get compute graph from om, abandon."); | |||||
return FAILED; | |||||
} else { | |||||
graph_node->SetComputeGraph(compute_graph_in_model); | |||||
graph_node->SetGeModel(ge_model); | |||||
GELOGI("Load model and graph form cache om file."); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper) { | |||||
auto ret = cache_helper->SaveCacheInfoToCache(); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to save cache info of graph[%d] to cache.", graph_id); | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->SaveVarManagerToCache(true); | |||||
if (ret != SUCCESS) { | |||||
GELOGW("Fail to save var manager to cache."); | |||||
cache_helper->ClearCache(graph_id); | |||||
return FAILED; | |||||
} | |||||
GELOGI("Cache files have been saved."); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::SaveCacheAfterBuild(uint32_t graph_id, ge::ComputeGraphPtr graph, GeModelPtr &ge_model) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if ((instance_ptr == nullptr) || !instance_ptr->InitFlag()) { | |||||
GELOGW("GELib not initialized."); | |||||
return FAILED; | |||||
} | |||||
if (instance_ptr->IsIncreBuild()) { | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter == cache_helper_map_.end()) { | |||||
GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); | |||||
return FAILED; | |||||
} else { | |||||
ModelCacheHelperPtr cache_helper = iter->second; | |||||
auto ret = cache_helper->RefreshComputeGraph(graph); | |||||
if (ret != SUCCESS) { | |||||
cache_helper->ClearCache(graph_id); | |||||
GELOGW("Fail to refresh cache helper's compute graph"); | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->SaveVarManagerToCache(false); | |||||
if (ret != SUCCESS) { | |||||
cache_helper->ClearCache(graph_id); | |||||
GELOGW("Fail to save VarManager to cache"); | |||||
return FAILED; | |||||
} | |||||
ret = cache_helper->SaveOmModelToCache(ge_model); | |||||
if (ret != SUCCESS) { | |||||
cache_helper->ClearCache(graph_id); | |||||
GELOGW("Fail to save om model to cache"); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, | Status GraphManager::InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, | ||||
const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { | const std::vector<GeTensor> &inputs, std::vector<GeTensor> &outputs) { | ||||
Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); | Status ret = graph_executor_.SetCondition(&sync_run_mutex_, &condition_, graph_run_listener_); | ||||
@@ -551,6 +649,9 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector<GeTenso | |||||
GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[RunGraph] compute_graph_tmp is NULL, graph id = %u.", graph_id); | GELOGE(GE_GRAPH_GRAPH_NODE_NULL, "[RunGraph] compute_graph_tmp is NULL, graph id = %u.", graph_id); | ||||
return GE_GRAPH_GRAPH_NODE_NULL;)) | return GE_GRAPH_GRAPH_NODE_NULL;)) | ||||
// when set incre build, add cache helper map | |||||
AddModelCacheHelperToMap(graph_id, session_id, compute_graph_tmp); | |||||
std::vector<GeModelPtr> ge_models; | std::vector<GeModelPtr> ge_models; | ||||
if (options_.local_fmk_op_flag) { | if (options_.local_fmk_op_flag) { | ||||
@@ -583,7 +684,7 @@ Status GraphManager::RunGraph(const GraphId &graph_id, const std::vector<GeTenso | |||||
if (!all_sub_graph.empty()) { | if (!all_sub_graph.empty()) { | ||||
auto checkPointGraph = all_sub_graph[0]->GetSubGraph(); | auto checkPointGraph = all_sub_graph[0]->GetSubGraph(); | ||||
if (IsCheckpointGraph(checkPointGraph)) { | if (IsCheckpointGraph(checkPointGraph)) { | ||||
ret = CheckpointHandle(graph_id, outputs); | |||||
ret = CheckpointHandle(graph_id, checkPointGraph, outputs); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "[RunGraph] CheckpointHandle failed!"); | GELOGE(ret, "[RunGraph] CheckpointHandle failed!"); | ||||
} | } | ||||
@@ -667,6 +768,15 @@ Status GraphManager::SaveParams(ge::GeModel &model, const std::string &type, con | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void GraphManager::RemoveModelCacheHelper(const GraphId &graph_id) { | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter != cache_helper_map_.end()) { | |||||
cache_helper_map_.erase(iter); | |||||
} else { | |||||
GELOGW("[GraphManager] cache helper does not exist, graph_id = %u", graph_id); | |||||
} | |||||
} | |||||
Status GraphManager::RemoveGraph(const GraphId &graph_id) { | Status GraphManager::RemoveGraph(const GraphId &graph_id) { | ||||
auto it = graph_map_.find(graph_id); | auto it = graph_map_.find(graph_id); | ||||
if (it == graph_map_.end()) { | if (it == graph_map_.end()) { | ||||
@@ -716,6 +826,9 @@ Status GraphManager::RemoveGraph(const GraphId &graph_id) { | |||||
} | } | ||||
var_acc_ctrl_.RemoveGraph(graph_id); | var_acc_ctrl_.RemoveGraph(graph_id); | ||||
graph_map_.erase(it); | graph_map_.erase(it); | ||||
RemoveModelCacheHelper(graph_id); | |||||
auto ge_model = graph_node->GetGeModel(); | auto ge_model = graph_node->GetGeModel(); | ||||
if (ge_model != nullptr) { | if (ge_model != nullptr) { | ||||
GELOGI("Unload model %u.", ge_model->GetModelId()); | GELOGI("Unload model %u.", ge_model->GetModelId()); | ||||
@@ -1106,21 +1219,15 @@ Status GraphManager::SummaryHandle(const GraphId &graph_id, std::vector<GeTensor | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphManager::CheckpointHandle(const GraphId &graph_id, const std::vector<GeTensor> &outputs) { | |||||
Status GraphManager::CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, | |||||
const std::vector<GeTensor> &outputs) { | |||||
GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size()); | GELOGI("[GraphManager] CheckpointHandle, outputsSize=%zu.", outputs.size()); | ||||
std::vector<InputOutputDescInfo> outputs_desc = graph_executor_.GetOutputsDesc(); | std::vector<InputOutputDescInfo> outputs_desc = graph_executor_.GetOutputsDesc(); | ||||
GELOGI("[GraphManager] CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size()); | GELOGI("[GraphManager] CheckpointHandle, outputsDescSize=%zu.", outputs_desc.size()); | ||||
// find graph | |||||
GraphNodePtr graph_node = nullptr; | |||||
Status ret = GetGraphNode(graph_id, graph_node); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[CheckpointHandle] graph not exist, graph_id = %u.", graph_id); | |||||
return ret; | |||||
} | |||||
ComputeGraphPtr compute_graph_ptr = GraphUtils::GetComputeGraph(*(graph_node->GetGraph())); | |||||
std::map<string, Tensor> save_results; | std::map<string, Tensor> save_results; | ||||
NodePtr netoutput = nullptr; | NodePtr netoutput = nullptr; | ||||
for (const auto &node : compute_graph_ptr->GetDirectNode()) { | |||||
for (const auto &node : compute_graph->GetDirectNode()) { | |||||
if (node->GetType() == kNetOutput) { | if (node->GetType() == kNetOutput) { | ||||
netoutput = node; | netoutput = node; | ||||
break; | break; | ||||
@@ -1248,6 +1355,8 @@ bool GraphManager::CheckTransOpForCheckpointGraph(NodePtr &node) { | |||||
return true; | return true; | ||||
} | } | ||||
static inline bool CheckConstanOpForCheckpointGraph(NodePtr &node) { return node->GetOutDataNodes().empty(); } | |||||
bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { | bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { | ||||
if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[IsCheckpointGraph] computeGraph is nullptr."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "[IsCheckpointGraph] computeGraph is nullptr."); | ||||
@@ -1268,6 +1377,10 @@ bool GraphManager::IsCheckpointGraph(ComputeGraphPtr &compute_graph) { | |||||
if (!CheckTransOpForCheckpointGraph(node)) { | if (!CheckTransOpForCheckpointGraph(node)) { | ||||
return false; | return false; | ||||
} | } | ||||
} else if (op->GetType() == CONSTANTOP) { | |||||
if (!CheckConstanOpForCheckpointGraph(node)) { | |||||
return false; | |||||
} | |||||
} else if (op->GetType() != kSend && op->GetType() != kRecv) { | } else if (op->GetType() != kSend && op->GetType() != kRecv) { | ||||
GELOGI("this node is not allow in checkpoint sub graph, node_type: %s, node_name: %s.", op->GetType().c_str(), | GELOGI("this node is not allow in checkpoint sub graph, node_type: %s, node_name: %s.", op->GetType().c_str(), | ||||
op->GetName().c_str()); | op->GetName().c_str()); | ||||
@@ -1439,8 +1552,6 @@ Status GraphManager::OptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_gra | |||||
names_to_passes.emplace_back("ReshapeRemovePass", &trans_op_nearby_allreduce_fusion_pass); | names_to_passes.emplace_back("ReshapeRemovePass", &trans_op_nearby_allreduce_fusion_pass); | ||||
ReshapeRemovePass reshape_remove_pass; | ReshapeRemovePass reshape_remove_pass; | ||||
names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | names_to_passes.emplace_back("ReshapeRemovePass", &reshape_remove_pass); | ||||
ReplaceWithEmptyConstPass replace_with_empty_const_pass; | |||||
names_to_passes.emplace_back("ReplaceWithEmptyConstPass", &replace_with_empty_const_pass); | |||||
ConstantFoldingPass constant_folding_pass; | ConstantFoldingPass constant_folding_pass; | ||||
names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | names_to_passes.emplace_back("ConstantFoldingPass", &constant_folding_pass); | ||||
DimensionAdjustPass dimension_adjust_pass; | DimensionAdjustPass dimension_adjust_pass; | ||||
@@ -1632,6 +1743,51 @@ Status GraphManager::RunGraphAsync(const GraphId &graph_id, const std::vector<ge | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void GraphManager::AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, | |||||
ComputeGraphPtr &compute_graph) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr != nullptr && instance_ptr->IsIncreBuild()) { | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter == cache_helper_map_.end()) { | |||||
ModelCacheHelperPtr cache_helper = MakeShared<ge::ModelCacheHelper>(session_id, graph_id, compute_graph); | |||||
if (cache_helper != nullptr) { | |||||
cache_helper_map_.emplace(std::make_pair(graph_id, cache_helper)); | |||||
} else { | |||||
GELOGW("Cache helper make shared failed, graph_id = %u.", graph_id); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
Status GraphManager::IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model) { | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->IsIncreBuild()) { | |||||
return FAILED; | |||||
} | |||||
const uint32_t graph_id = graph_node->GetGraphId(); | |||||
auto iter = cache_helper_map_.find(graph_id); | |||||
if (iter == cache_helper_map_.end()) { | |||||
GELOGW("Can not find ModelCacheHelper of graph[%u]", graph_id); | |||||
return FAILED; | |||||
} | |||||
ModelCacheHelperPtr cache_helper = iter->second; | |||||
if (cache_helper->IsModelCacheHit()) { | |||||
GEEVENT("Model cache hit."); | |||||
Status ret = LoadFromCache(graph_node, cache_helper, ge_model); | |||||
if (ret == SUCCESS) { | |||||
return SUCCESS; | |||||
} else { | |||||
GELOGW("Error occurred when load from cache, abandon."); | |||||
} | |||||
} else { | |||||
GEEVENT("Model cache miss."); | |||||
} | |||||
if (SaveCacheBeforeBuild(graph_node->GetGraphId(), cache_helper) != SUCCESS) { | |||||
GELOGW("Error occurred when save cache."); | |||||
} | |||||
return FAILED; | |||||
} | |||||
void GraphManager::PreRunThread(GraphManager *graph_manager) { | void GraphManager::PreRunThread(GraphManager *graph_manager) { | ||||
if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | if (prctl(PR_SET_NAME, ("GE_PreRun")) != 0) { | ||||
GELOGW("Set thread name failed."); | GELOGW("Set thread name failed."); | ||||
@@ -1685,6 +1841,8 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
return; | return; | ||||
} | } | ||||
} | } | ||||
// when set incre build, save cache helper. | |||||
graph_manager->AddModelCacheHelperToMap(args.graph_id, args.session_id, compute_graph_tmp); | |||||
std::vector<GeModelPtr> ge_models; | std::vector<GeModelPtr> ge_models; | ||||
@@ -1707,12 +1865,15 @@ void GraphManager::PreRunThread(GraphManager *graph_manager) { | |||||
return; | return; | ||||
} | } | ||||
ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); | |||||
if (ret != SUCCESS) { | |||||
graph_node->SetRunFlag(false); | |||||
ReturnError(graph_manager, args.callback, ret, "PreRun failed, thread exit."); | |||||
graph_node->Unlock(); | |||||
return; | |||||
// check need incre build. | |||||
if (graph_manager->IncreBuild(graph_node, ge_model) != SUCCESS) { | |||||
ret = graph_manager->PreRun(graph_node, ge_inputs, ge_models, ge_model, args.session_id); | |||||
if (ret != SUCCESS) { | |||||
graph_node->SetRunFlag(false); | |||||
ReturnError(graph_manager, args.callback, ret, "PreRun Failed, thread exit.."); | |||||
graph_node->Unlock(); | |||||
return; | |||||
} | |||||
} | } | ||||
graph_node->SetBuildFlag(true); | graph_node->SetBuildFlag(true); | ||||
graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | graph_manager->var_acc_ctrl_.SetGraphBuildEnd(graph_node->GetGraphId()); | ||||
@@ -27,6 +27,7 @@ | |||||
#include "common/blocking_queue.h" | #include "common/blocking_queue.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "common/helper/model_cache_helper.h" | |||||
#include "external/graph/types.h" | #include "external/graph/types.h" | ||||
#include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
#include "graph/build/graph_builder.h" | #include "graph/build/graph_builder.h" | ||||
@@ -211,7 +212,8 @@ class GraphManager { | |||||
Status SummaryHandle(const GraphId &graph_id, std::vector<GeTensor> &outputs); | Status SummaryHandle(const GraphId &graph_id, std::vector<GeTensor> &outputs); | ||||
Status CheckpointHandle(const GraphId &graph_id, const std::vector<GeTensor> &outputs); | |||||
Status CheckpointHandle(const GraphId &graph_id, const ComputeGraphPtr &compute_graph, | |||||
const std::vector<GeTensor> &outputs); | |||||
// call the callback function of ME to push summary result data to ME | // call the callback function of ME to push summary result data to ME | ||||
Status PushSummaryData2ME(const GraphId &graph_id, const std::map<std::string, ge::Tensor> &summary_data); | Status PushSummaryData2ME(const GraphId &graph_id, const std::map<std::string, ge::Tensor> &summary_data); | ||||
@@ -260,6 +262,13 @@ class GraphManager { | |||||
bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | ||||
Status LoadFromCache(const GraphNodePtr &graph_node, const ModelCacheHelperPtr &cache_helper, GeModelPtr &ge_model); | |||||
Status SaveCacheBeforeBuild(uint32_t graph_id, const ModelCacheHelperPtr &cache_helper); | |||||
Status SaveCacheAfterBuild(uint32_t graph_id, ComputeGraphPtr graph, GeModelPtr &ge_model); | |||||
void AddModelCacheHelperToMap(const GraphId &graph_id, uint64_t session_id, ComputeGraphPtr &compute_graph); | |||||
Status IncreBuild(const GraphNodePtr &graph_node, GeModelPtr &ge_model); | |||||
void RemoveModelCacheHelper(const GraphId &graph_id); | |||||
static void PreRunThread(GraphManager *graph_manager); | static void PreRunThread(GraphManager *graph_manager); | ||||
static void RunThread(GraphManager *graph_manager); | static void RunThread(GraphManager *graph_manager); | ||||
static void StopQueue(GraphManager *graph_manager); | static void StopQueue(GraphManager *graph_manager); | ||||
@@ -274,6 +283,8 @@ class GraphManager { | |||||
std::map<GraphId, GraphNodePtr> graph_map_; | std::map<GraphId, GraphNodePtr> graph_map_; | ||||
std::map<GraphId, ModelCacheHelperPtr> cache_helper_map_; | |||||
// for run graph synchronous return | // for run graph synchronous return | ||||
std::mutex sync_run_mutex_; | std::mutex sync_run_mutex_; | ||||
std::condition_variable condition_; | std::condition_variable condition_; | ||||
@@ -64,6 +64,10 @@ ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTens | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void VarResource::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) { | |||||
var_addr_mgr_map = var_addr_mgr_map_; | |||||
} | |||||
void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | void VarResource::SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ||||
rtMemType_t memory_type) { | rtMemType_t memory_type) { | ||||
std::string var_key = VarKey(var_name, tensor_desc); | std::string var_key = VarKey(var_name, tensor_desc); | ||||
@@ -170,6 +174,14 @@ void VarResource::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &b | |||||
var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info; | var_broad_cast_info_[graph_id][broad_cast_info.var_name] = broad_cast_info; | ||||
} | } | ||||
ge::Status VarResource::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { | |||||
if (var_broad_cast_info_.count(graph_id) == 0 || var_broad_cast_info_[graph_id].count(var_name) == 0) { | |||||
return FAILED; | |||||
} | |||||
broad_cast_info = var_broad_cast_info_[graph_id][var_name]; | |||||
return SUCCESS; | |||||
} | |||||
ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ge::Status VarResource::SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ||||
const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { | const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr) { | ||||
if (var_op_desc == nullptr) { | if (var_op_desc == nullptr) { | ||||
@@ -282,11 +294,17 @@ Status MemResource::AssignVarMem(const std::string &var_name, uint64_t size, uin | |||||
// align 512 BYTE | // align 512 BYTE | ||||
var_mem_size_ = var_mem_size_ + kSessionMemAlignSize; | var_mem_size_ = var_mem_size_ + kSessionMemAlignSize; | ||||
GELOGI( | |||||
"[IMAS]AssignVarMem Set session_%lu name[%s] output[%d]" | |||||
"offset to [%zu] size[%lu] realsize[%lu].", | |||||
session_id, var_name.c_str(), 0, mem_offset, (var_mem_size_ - mem_offset), real_size); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
int64_t MemResource::GetVarMemSize() const { return var_mem_size_; } | int64_t MemResource::GetVarMemSize() const { return var_mem_size_; } | ||||
void MemResource::UpdateVarMemSize(int64_t mem_size) { var_mem_size_ = mem_size; }; | |||||
VarManager::VarManager(uint64_t session_id) | VarManager::VarManager(uint64_t session_id) | ||||
: version_(SessionVersion::OTHER_VERSION), | : version_(SessionVersion::OTHER_VERSION), | ||||
session_id_(session_id), | session_id_(session_id), | ||||
@@ -363,6 +381,21 @@ ge::Status VarManager::SetVarAddr(const std::string &var_name, const ge::GeTenso | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
ge::Status VarManager::SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, | |||||
rtMemType_t memory_type) { | |||||
GELOGI("VarManager::SaveVarAddr var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(), | |||||
ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
if (var_resource_ == nullptr) { | |||||
GELOGW("VarManager has not been init."); | |||||
return ge::INTERNAL_ERROR; | |||||
} | |||||
var_resource_->SaveVarAddr(var_name, tensor_desc, address, memory_type); | |||||
return ge::SUCCESS; | |||||
} | |||||
ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
rtMemType_t &memory_type) { | rtMemType_t &memory_type) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
@@ -388,6 +421,10 @@ ge::Status VarManager::GetVarAddr(const std::string &var_name, const ge::GeTenso | |||||
return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); | return GetVarAddr(var_name, tensor_desc, dev_ptr, memory_type); | ||||
} | } | ||||
void VarManager::GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map) { | |||||
var_resource_->GetAllVarAddrMgr(var_addr_mgr_map); | |||||
} | |||||
int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
MemResource *mem_resource = nullptr; | MemResource *mem_resource = nullptr; | ||||
@@ -405,14 +442,36 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | |||||
return mem_resource->GetVarMemSize(); | return mem_resource->GetVarMemSize(); | ||||
} | } | ||||
Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
MemResource *mem_resource = nullptr; | |||||
auto iter = mem_resource_map_.find(memory_type); | |||||
if (iter == mem_resource_map_.end()) { | |||||
mem_resource = new (std::nothrow) MemResource(); | |||||
if (mem_resource == nullptr) { | |||||
GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type); | |||||
return ge::INTERNAL_ERROR; | |||||
} else { | |||||
mem_resource_map_[memory_type] = mem_resource; | |||||
} | |||||
} else { | |||||
mem_resource = iter->second; | |||||
} | |||||
if (mem_resource == nullptr) { | |||||
GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid."); | |||||
return FAILED; | |||||
} | |||||
mem_resource->UpdateVarMemSize(mem_size); | |||||
return SUCCESS; | |||||
} | |||||
ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, | ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, | ||||
rtMemType_t memory_type) { | rtMemType_t memory_type) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
GELOGI( | |||||
"VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = " | |||||
"%s.", | |||||
var_name.c_str(), ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); | |||||
GELOGI("VarManager::AssignVarMem var_name = %s, data_type = %s, data_format = %s.", var_name.c_str(), | |||||
ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str()); | |||||
int64_t tensor_desc_size = 0; | int64_t tensor_desc_size = 0; | ||||
size_t mem_offset = 0; | size_t mem_offset = 0; | ||||
@@ -475,14 +534,13 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen | |||||
if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() || | if (cur_tensor_desc.GetFormat() != tensor_desc.GetFormat() || | ||||
cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() || | cur_tensor_desc.GetDataType() != tensor_desc.GetDataType() || | ||||
cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) { | cur_tensor_desc.GetShape().GetDims() != tensor_desc.GetShape().GetDims()) { | ||||
GELOGI( | |||||
"var %s assigned new memory (format, data type, shape) (%s, %s, " | |||||
"%zu) from (%s, %s, %zu)", | |||||
var_name.c_str(), ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), tensor_desc.GetShape().GetDims().size(), | |||||
ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(), | |||||
cur_tensor_desc.GetShape().GetDims().size()); | |||||
GELOGI("var %s assigned new memory (format, data type, shape) (%s, %s, %zu) from (%s, %s, %zu)", var_name.c_str(), | |||||
ge::TypeUtils::DataTypeToSerialString(tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(tensor_desc.GetFormat()).c_str(), | |||||
tensor_desc.GetShape().GetDims().size(), | |||||
ge::TypeUtils::DataTypeToSerialString(cur_tensor_desc.GetDataType()).c_str(), | |||||
ge::TypeUtils::FormatToSerialString(cur_tensor_desc.GetFormat()).c_str(), | |||||
cur_tensor_desc.GetShape().GetDims().size()); | |||||
var_resource_->SetVarAddr(var_name, tensor_desc, | var_resource_->SetVarAddr(var_name, tensor_desc, | ||||
reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(mem_offset)), memory_type); | reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(mem_offset)), memory_type); | ||||
} | } | ||||
@@ -550,6 +608,16 @@ ge::Status VarManager::SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastIn | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::Status VarManager::GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info) { | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
if (var_resource_ == nullptr) { | |||||
GELOGW("VarManager has not been init."); | |||||
return ge::INTERNAL_ERROR; | |||||
} | |||||
return var_resource_->GetBroadCastInfo(graph_id, var_name, broad_cast_info); | |||||
} | |||||
ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { | ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc) { | ||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); | GELOGD("VarManager::RenewCurVarDesc var_name = %s.", var_name.c_str()); | ||||
@@ -672,6 +740,7 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); | GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); | ||||
return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
} | } | ||||
GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); | |||||
} | } | ||||
it = options.find(VARIABLE_MEMORY_MAX_SIZE); | it = options.find(VARIABLE_MEMORY_MAX_SIZE); | ||||
@@ -101,6 +101,8 @@ class VarResource { | |||||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
rtMemType_t &memory_type); | rtMemType_t &memory_type); | ||||
void GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map); | |||||
void SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | void SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ||||
rtMemType_t rtMemType_t); | rtMemType_t rtMemType_t); | ||||
@@ -113,6 +115,8 @@ class VarResource { | |||||
void SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | void SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||||
ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ge::Status SyncVarData2BroadCast(uint32_t graph_id, const std::string &var_name, | ||||
const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr); | const ge::ConstOpDescPtr &var_op_desc, uint8_t *base_ptr); | ||||
@@ -175,6 +179,8 @@ class MemResource { | |||||
int64_t GetVarMemSize() const; | int64_t GetVarMemSize() const; | ||||
void UpdateVarMemSize(int64_t mem_size); | |||||
private: | private: | ||||
uint64_t total_size_; | uint64_t total_size_; | ||||
uint64_t var_mem_size_; | uint64_t var_mem_size_; | ||||
@@ -196,9 +202,14 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
ge::Status SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ge::Status SetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *dev_ptr, | ||||
rtMemType_t memory_type); | rtMemType_t memory_type); | ||||
ge::Status SaveVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t *address, | |||||
rtMemType_t memory_type); | |||||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr, | ||||
rtMemType_t &memory_type); | rtMemType_t &memory_type); | ||||
void GetAllVarAddrMgr(std::unordered_map<std::string, VarAddrMgr> &var_addr_mgr_map); | |||||
ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ge::Status GetVarAddr(const std::string &var_name, const ge::GeTensorDesc &tensor_desc, uint8_t **dev_ptr); | ||||
ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ge::Status SyncVarData(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ||||
@@ -206,6 +217,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ge::Status SaveBroadCastInfo(uint32_t graph_id, const VarBroadCastInfo &broad_cast_info); | ||||
ge::Status GetBroadCastInfo(uint32_t graph_id, const string &var_name, VarBroadCastInfo &broad_cast_info); | |||||
ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ge::Status SyncBroadCastData2Var(uint32_t graph_id, const std::string &var_name, ge::ConstOpDescPtr var_op_desc, | ||||
uint8_t *base_ptr); | uint8_t *base_ptr); | ||||
@@ -251,6 +264,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
int64_t GetVarMemSize(rtMemType_t memory_type); | int64_t GetVarMemSize(rtMemType_t memory_type); | ||||
Status UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size); | |||||
bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); | bool IsVarExist(const std::string &var_name, const ge::GeTensorDesc &tensor_desc); | ||||
bool IsVarExist(const std::string &var_name); | bool IsVarExist(const std::string &var_name); | ||||
@@ -238,6 +238,14 @@ Status ge::GraphPartitioner::MergeSubGraph(ge::ComputeGraphPtr &output_merged_co | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); | GE_TIMESTAMP_END(MergeGraphTopologicalSorting, "GraphPartitioner::MergeGraphTopologicalSorting"); | ||||
// flush all nodes' engine of merged graph | |||||
GE_TIMESTAMP_START(MergeGraphEnginePlacerRun); | |||||
graph_info_.engine_placer_.SetComputeGraph(output_merged_compute_graph); | |||||
if (graph_info_.engine_placer_.Run() != SUCCESS) { | |||||
GELOGE(GE_GRAPH_INIT_FAILED, "[GraphPartitioner]: engine_placer run failed"); | |||||
return FAILED; | |||||
} | |||||
GE_TIMESTAMP_END(MergeGraphEnginePlacerRun, "GraphPartitioner::MergeGraphEnginePlacerRun"); | |||||
GELOGI("Graph merge ends."); | GELOGI("Graph merge ends."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -200,7 +200,18 @@ bool AtomicAddrCleanPass::IsAtomicOp(const NodePtr &node) { | |||||
vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); | vector<OpInfo> op_info_vec = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); | ||||
for (const auto &op_info : op_info_vec) { | for (const auto &op_info : op_info_vec) { | ||||
if (op_info.isAtomic) { | if (op_info.isAtomic) { | ||||
GELOGI("Recognized atomic op %s from HCCL engine.", op_desc->GetName().c_str()); | |||||
GELOGI("Recognized atomic op %s from DNN_HCCL engine.", op_desc->GetName().c_str()); | |||||
// check peer input is DATA | |||||
for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | |||||
if (in_data_anchor->GetPeerOutAnchor() != nullptr && | |||||
in_data_anchor->GetPeerOutAnchor()->GetOwnerNode() != nullptr) { | |||||
auto peer_in_node = in_data_anchor->GetPeerOutAnchor()->GetOwnerNode(); | |||||
if (peer_in_node->GetType() == DATA) { | |||||
GELOGI("Recognized atomic op %s from DNN_HCCL engine and input is DATA.", op_desc->GetName().c_str()); | |||||
return false; | |||||
} | |||||
} | |||||
} | |||||
hcom_node_vec_.push_back(node); | hcom_node_vec_.push_back(node); | ||||
return true; | return true; | ||||
} | } | ||||
@@ -49,9 +49,11 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstG | |||||
GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr."); | GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr."); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
const uint8_t *src_data = const_weight_ptr->GetData().data(); | const uint8_t *src_data = const_weight_ptr->GetData().data(); | ||||
if (op_desc_ptr == nullptr || src_data == nullptr) { | |||||
GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr or src_data is nullptr."); | |||||
// src_data == nullptr is supported | |||||
if (op_desc_ptr == nullptr) { | |||||
GELOGE(PARAM_INVALID, "Parameter's invalid, Input opDescPtr is nullptr."); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
GeTensorDesc op_desc = op_desc_ptr->GetOutputDesc(0); | GeTensorDesc op_desc = op_desc_ptr->GetOutputDesc(0); | ||||
@@ -73,7 +75,7 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstG | |||||
TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), | TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), | ||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
GE_CHECK_SIZE(const_weight_ptr->GetData().GetSize()); | |||||
// const_weight_ptr->GetData().GetSize() == 0 is supported | |||||
auto src_data_size = src_shape.GetShapeSize(); | auto src_data_size = src_shape.GetShapeSize(); | ||||
if (src_data_size == 0 && | if (src_data_size == 0 && | ||||
static_cast<int>(const_weight_ptr->GetData().GetSize()) == GetSizeByDataType(src_data_type)) { | static_cast<int>(const_weight_ptr->GetData().GetSize()) == GetSizeByDataType(src_data_type)) { | ||||
@@ -113,7 +115,6 @@ Status CastKernel::Compute(const OpDescPtr op_desc_ptr, const std::vector<ConstG | |||||
} | } | ||||
if (output_ptr->SetData(trans_result.data.get(), trans_result.length) != SUCCESS) { | if (output_ptr->SetData(trans_result.data.get(), trans_result.length) != SUCCESS) { | ||||
GELOGW("Compute: SetData failed"); | GELOGW("Compute: SetData failed"); | ||||
return FAILED; | |||||
} | } | ||||
v_output.push_back(output_ptr); | v_output.push_back(output_ptr); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -113,12 +113,26 @@ bool KernelUtils::CheckSizeForTransOp(const ge::ConstGeTensorPtr &const_weight_p | |||||
GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, cal_size, | GELOGI("Const real value Size:%zu, op_desc Shape Size:%ld, data_type:%s.", data_size, cal_size, | ||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
if ((shape_size != 0) || (length != 0 && (data_size / static_cast<size_t>(length) != 1))) { | |||||
if (!(data_size == static_cast<size_t>(cal_size) && data_size != 0)) { | |||||
if (shape_size != 0) { | |||||
// Standard tensor | |||||
if (data_size != static_cast<size_t>(cal_size) || data_size == 0) { | |||||
GELOGW("Const input data size is not equal with tensor desc shape"); | |||||
return false; | |||||
} | |||||
} else if (data_shape.GetDimNum() != 0) { | |||||
// Empty tensor, has zero in shape vector | |||||
if (data_size != 0) { | |||||
GELOGW("Const input data size is not equal with tensor desc shape"); | |||||
return false; | |||||
} | |||||
} else { | |||||
// Scalar tensor, has only one element in tensor | |||||
if (length != 0 && (data_size / static_cast<size_t>(length) != 1)) { | |||||
GELOGW("Const input data size is not equal with tensor desc shape"); | GELOGW("Const input data size is not equal with tensor desc shape"); | ||||
return false; | return false; | ||||
} | } | ||||
} | } | ||||
return true; | return true; | ||||
} | } | ||||
@@ -29,6 +29,7 @@ namespace ge { | |||||
class KernelUtils { | class KernelUtils { | ||||
public: | public: | ||||
KernelUtils() = delete; | KernelUtils() = delete; | ||||
~KernelUtils() = delete; | |||||
static Status CheckDimensionNodeInfo(const NodePtr &node_ptr); | static Status CheckDimensionNodeInfo(const NodePtr &node_ptr); | ||||
static bool CheckFormatSupported(const NodePtr &node_ptr); | static bool CheckFormatSupported(const NodePtr &node_ptr); | ||||
static bool CheckSizeForTransOp(const ConstGeTensorPtr &const_weight_ptr, const OpDescPtr &op_desc_ptr); | static bool CheckSizeForTransOp(const ConstGeTensorPtr &const_weight_ptr, const OpDescPtr &op_desc_ptr); | ||||
@@ -41,7 +42,7 @@ class KernelUtils { | |||||
* @param [out] output the tensor for save sequence of numbers | * @param [out] output the tensor for save sequence of numbers | ||||
* @author | * @author | ||||
*/ | */ | ||||
template<typename T> | |||||
template <typename T> | |||||
static Status GenData(const int64_t data_num, const T value, const GeTensorPtr &output) { | static Status GenData(const int64_t data_num, const T value, const GeTensorPtr &output) { | ||||
if (data_num > 0) { | if (data_num > 0) { | ||||
if (!CheckInt64MulOverflow(data_num, static_cast<int64_t>(sizeof(T)))) { | if (!CheckInt64MulOverflow(data_num, static_cast<int64_t>(sizeof(T)))) { | ||||
@@ -69,12 +70,12 @@ class KernelUtils { | |||||
} | } | ||||
/** | /** | ||||
* Calculate dimension | |||||
* @param [in] dims save the tensor of the dimension | |||||
* @param [in] vec_dim results of each dimension | |||||
* @param [out] data_num total size of data | |||||
* @author | |||||
*/ | |||||
* Calculate dimension | |||||
* @param [in] dims save the tensor of the dimension | |||||
* @param [in] vec_dim results of each dimension | |||||
* @param [out] data_num total size of data | |||||
* @author | |||||
*/ | |||||
template <typename T> | template <typename T> | ||||
static Status CalcDims(const ConstGeTensorPtr dims, std::vector<int64_t> &vec_dim, int64_t &data_num) { | static Status CalcDims(const ConstGeTensorPtr dims, std::vector<int64_t> &vec_dim, int64_t &data_num) { | ||||
data_num = 1; | data_num = 1; | ||||
@@ -67,8 +67,8 @@ Status PackKernel::ValidateKernelParams(const ge::OpDescPtr &op_desc_ptr, | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) { | if (!(AttrUtils::GetInt(op_desc_ptr, PACK_ATTR_NAME_NUM, n_))) { | ||||
GELOGE(PARAM_INVALID, "Attr %s is not exist.", PACK_ATTR_NAME_NUM.c_str()); | |||||
return PARAM_INVALID; | |||||
n_ = 0; | |||||
GELOGD("Attr %s is not set, default value %ld is used.", PACK_ATTR_NAME_NUM.c_str(), n_); | |||||
} | } | ||||
if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) { | if (!(AttrUtils::GetInt(op_desc_ptr, ATTR_NAME_AXIS, axis_))) { | ||||
GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); | GELOGE(PARAM_INVALID, "Attr %s is not exist.", ATTR_NAME_AXIS.c_str()); | ||||
@@ -105,11 +105,7 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v | |||||
GELOGW("Input %ld of pack kernel %s is null.", i, op_desc_ptr->GetName().c_str()); | GELOGW("Input %ld of pack kernel %s is null.", i, op_desc_ptr->GetName().c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
// check if tensor contains data | |||||
if (input[i]->GetData().size() == 0) { | |||||
GELOGW("Inputs %ld do not have value.", i); | |||||
return NOT_CHANGED; | |||||
} | |||||
if (i == 0) { | if (i == 0) { | ||||
// get first input shape | // get first input shape | ||||
shape = input[0]->GetTensorDesc().GetShape(); | shape = input[0]->GetTensorDesc().GetShape(); | ||||
@@ -127,8 +123,8 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v | |||||
auto dst_shape = tensor_desc.GetShape(); | auto dst_shape = tensor_desc.GetShape(); | ||||
int64_t num = 1; | int64_t num = 1; | ||||
for (auto dim : dst_shape.GetDims()) { | for (auto dim : dst_shape.GetDims()) { | ||||
if (dim < 1) { | |||||
GELOGW("Invalid zero dim in the shape %s", formats::ShapeToString(shape).c_str()); | |||||
if (dim < 0) { | |||||
GELOGW("Invalid dim ld% in the shape %s", dim, formats::ShapeToString(shape).c_str()); | |||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
num *= dim; | num *= dim; | ||||
@@ -141,6 +137,12 @@ Status PackKernel::ValidateInputs(const ge::OpDescPtr &op_desc_ptr, const std::v | |||||
GELOGW("Shape of input %ld is not equal wiht input 0.", i); | GELOGW("Shape of input %ld is not equal wiht input 0.", i); | ||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
// check tensor data size is zero ot not | |||||
if (input[i]->GetData().size() == 0 && num != 0) { | |||||
GELOGW("Inputs %ld do not have value.", i); | |||||
return NOT_CHANGED; | |||||
} | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -167,6 +169,13 @@ void PackKernel::ExpandDims(const int64_t axis, const std::vector<ge::ConstGeTen | |||||
Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector<ge::ConstGeTensorPtr> &input, | Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector<ge::ConstGeTensorPtr> &input, | ||||
ge::GeTensorPtr &output_ptr) { | ge::GeTensorPtr &output_ptr) { | ||||
output_ptr->MutableTensorDesc().SetShape(final_shape); | |||||
output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_)); | |||||
if (final_shape.GetShapeSize() == 0 && final_shape.GetDims().size() != 0) { | |||||
// means has zero in shape list, output tnesor data is []. | |||||
return SUCCESS; | |||||
} | |||||
int64_t times = 1; | int64_t times = 1; | ||||
int64_t unit = 1; | int64_t unit = 1; | ||||
// calculate data unit | // calculate data unit | ||||
@@ -210,8 +219,6 @@ Status PackKernel::CopyOutputData(const GeShape &final_shape, const std::vector< | |||||
if (output_ptr->SetData(buf.get(), static_cast<size_t>(output_size * data_size)) != GRAPH_SUCCESS) { | if (output_ptr->SetData(buf.get(), static_cast<size_t>(output_size * data_size)) != GRAPH_SUCCESS) { | ||||
GELOGW("CopyOutputData: SetData failed"); | GELOGW("CopyOutputData: SetData failed"); | ||||
} | } | ||||
output_ptr->MutableTensorDesc().SetShape(final_shape); | |||||
output_ptr->MutableTensorDesc().SetDataType(DataType(data_type_)); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -63,10 +63,7 @@ Status ReduceProdKernel::ReduceProdCheck(const ge::OpDescPtr &op_desc_ptr, | |||||
GELOGE(PARAM_INVALID, "Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); | GELOGE(PARAM_INVALID, "Axis must be at most rank 1, node node: %s", op_desc_ptr->GetName().c_str()); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
if (data_tensor->GetData().size() == 0 || axis_tensor->GetData().size() == 0) { | |||||
GELOGE(PARAM_INVALID, "ReduceProdKernel data size of inputs is 0, node node: %s", op_desc_ptr->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
DataType data_type = data_tensor->GetTensorDesc().GetDataType(); | DataType data_type = data_tensor->GetTensorDesc().GetDataType(); | ||||
if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { | if (kReduceProdSupportedType.find(data_type) == kReduceProdSupportedType.end()) { | ||||
GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", | GELOGE(PARAM_INVALID, "ReduceProdKernel data type %s not support, node name: %s", | ||||
@@ -151,7 +148,6 @@ Status ReduceProdKernel::DataCal(const std::vector<ge::ConstGeTensorPtr> &input, | |||||
static_cast<size_t>(head_dim_ * end_dim_ * sizeof(int32_t))) != GRAPH_SUCCESS, | static_cast<size_t>(head_dim_ * end_dim_ * sizeof(int32_t))) != GRAPH_SUCCESS, | ||||
GELOGW("set data failed"); | GELOGW("set data failed"); | ||||
return INTERNAL_ERROR); | return INTERNAL_ERROR); | ||||
output_ptr->MutableTensorDesc().SetDataType(data_dtype); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -260,19 +256,32 @@ Status ReduceProdKernel::Compute(const ge::OpDescPtr op_desc_ptr, const std::vec | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
} else if (input.at(kReduceProdAxisIndex)->GetData().size() == 0) { | |||||
// axis tensor value is [], means no process for input | |||||
output_ptr->MutableTensorDesc().SetShape(input.at(kReduceProdDataIndex)->GetTensorDesc().GetShape()); | |||||
output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType()); | |||||
if (output_ptr->SetData(input.at(kReduceProdDataIndex)->GetData()) != GRAPH_SUCCESS) { | |||||
GELOGW("Compute: SetData failed"); | |||||
} | |||||
} else { | } else { | ||||
// calculate axis to reduce | // calculate axis to reduce | ||||
ret = AxisCal(input); | ret = AxisCal(input); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return NOT_CHANGED; | return NOT_CHANGED; | ||||
} | } | ||||
// calculate data and data type | |||||
ret = DataCal(input, output_ptr); | |||||
if (ret != SUCCESS) { | |||||
return NOT_CHANGED; | |||||
} | |||||
// calculate shape | |||||
// calculate and set shape | |||||
ShapeCal(op_desc_ptr, input, output_ptr); | ShapeCal(op_desc_ptr, input, output_ptr); | ||||
// set data type | |||||
output_ptr->MutableTensorDesc().SetDataType(input.at(kReduceProdDataIndex)->GetTensorDesc().GetDataType()); | |||||
// data size == 0 means input tensor has zero in shape, and tensor value is []. | |||||
if (input.at(kReduceProdDataIndex)->GetData().size() != 0) { | |||||
// calculate data and data type | |||||
ret = DataCal(input, output_ptr); | |||||
if (ret != SUCCESS) { | |||||
return NOT_CHANGED; | |||||
} | |||||
} | |||||
} | } | ||||
// print output tensor information, and will be deleted | // print output tensor information, and will be deleted | ||||
@@ -48,8 +48,9 @@ Status TransdataKernel::ValidateInput(const OpDescPtr &op_desc_ptr, const std::v | |||||
GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr."); | GELOGE(PARAM_INVALID, "Input const_weight_ptr is nullptr."); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
const uint8_t *src_data = const_weight_ptr->GetData().data(); | |||||
if (op_desc_ptr == nullptr || src_data == nullptr) { | |||||
// src_data == nullptr is supported | |||||
if (op_desc_ptr == nullptr) { | |||||
GELOGE(PARAM_INVALID, "Input opDescPtr is nullptr."); | GELOGE(PARAM_INVALID, "Input opDescPtr is nullptr."); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
@@ -26,6 +26,7 @@ namespace ge { | |||||
class PassUtils { | class PassUtils { | ||||
public: | public: | ||||
PassUtils() = delete; | PassUtils() = delete; | ||||
~PassUtils() = delete; | |||||
static NodePtr GetInDataNode(const ConstNodePtr &node, int index); | static NodePtr GetInDataNode(const ConstNodePtr &node, int index); | ||||
@@ -137,7 +137,7 @@ Status SwitchOpPass::ReplaceSwitchNode(ComputeGraphPtr &graph, NodePtr &switch_n | |||||
NodePtr out_node = peer_in_anchor->GetOwnerNode(); | NodePtr out_node = peer_in_anchor->GetOwnerNode(); | ||||
GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); | GE_CHK_STATUS_RET(GetOriginalType(out_node, type), "Get node type fail."); | ||||
if ((type == MERGE) || (type == REFMERGE)) { | if ((type == MERGE) || (type == REFMERGE)) { | ||||
NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor); | |||||
NodePtr memcpy_node = CreateMemcpyAsyncNode(graph, peer_data_anchor, false); | |||||
GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); | GE_CHK_BOOL_EXEC(memcpy_node != nullptr, return FAILED, "Create memcpy_async node fail."); | ||||
GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::AddEdge(peer_data_anchor, memcpy_node->GetInDataAnchor(0)), | ||||
"MemcpyAsync node add edge fail."); | "MemcpyAsync node add edge fail."); | ||||
@@ -234,16 +234,18 @@ Status SwitchOpPass::ReplaceMergeNode(ComputeGraphPtr &graph, NodePtr &merge_nod | |||||
need_label_nodes_.emplace_back(stream_merge); | need_label_nodes_.emplace_back(stream_merge); | ||||
} | } | ||||
bool multi_batch_flag = false; | |||||
if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | if (merge_op_desc->HasAttr(ATTR_INSERT_BY_MBATCH)) { | ||||
if (!ge::AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { | if (!ge::AttrUtils::SetBool(op_desc, ATTR_INSERT_BY_MBATCH, true)) { | ||||
GELOGE(FAILED, "Set attr ATTR_INSERT_BY_MBATCH fail, StreamMerge:%s.", node_name.c_str()); | GELOGE(FAILED, "Set attr ATTR_INSERT_BY_MBATCH fail, StreamMerge:%s.", node_name.c_str()); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
multi_batch_flag = true; | |||||
} | } | ||||
(void)bypass_nodes_.insert(merge_node); | (void)bypass_nodes_.insert(merge_node); | ||||
GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge), "StreamMerge add memcpy node fail."); | |||||
GE_CHK_STATUS_RET(AddMemcpyAsyncNodes(graph, stream_merge, multi_batch_flag), "StreamMerge add memcpy node fail."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -302,17 +304,20 @@ NodePtr SwitchOpPass::CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodeP | |||||
/// @brief Add MemcpyAsync Node | /// @brief Add MemcpyAsync Node | ||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] in_node | /// @param [in] in_node | ||||
/// @param [in] multi_batch_flag | |||||
/// @return ge::NodePtr | /// @return ge::NodePtr | ||||
/// | /// | ||||
NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor) { | |||||
NodePtr SwitchOpPass::CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, | |||||
bool multi_batch_flag) { | |||||
GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); | GE_CHK_BOOL_EXEC(out_data_anchor != nullptr, return nullptr, "Param of input node is null."); | ||||
OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | OpDescPtr pre_op_desc = out_data_anchor->GetOwnerNode()->GetOpDesc(); | ||||
GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | GE_CHK_BOOL_EXEC(pre_op_desc != nullptr, return nullptr, "OpDesc of pre node is invalid."); | ||||
std::string node_name = pre_op_desc->GetName() + "_" + MEMCPYASYNC; | |||||
std::string memcpy_type = multi_batch_flag ? MEMCPYADDRASYNC : MEMCPYASYNC; | |||||
std::string node_name = pre_op_desc->GetName() + "_" + memcpy_type; | |||||
node_name = CheckDuplicateName(node_name); | node_name = CheckDuplicateName(node_name); | ||||
GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); | GELOGI("Create MemcpyAsync op:%s.", node_name.c_str()); | ||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, MEMCPYASYNC); | |||||
OpDescPtr op_desc = MakeShared<OpDesc>(node_name, memcpy_type); | |||||
if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str()); | GELOGE(FAILED, "Create op_desc fail, MemcpyAsync:%s.", node_name.c_str()); | ||||
return nullptr; | return nullptr; | ||||
@@ -432,9 +437,10 @@ NodePtr SwitchOpPass::CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node) { | |||||
/// @brief Add MemcpyAsync Op as StreamMerge in_node | /// @brief Add MemcpyAsync Op as StreamMerge in_node | ||||
/// @param [in] graph | /// @param [in] graph | ||||
/// @param [in] node | /// @param [in] node | ||||
/// @param [in] multi_batch_flag | |||||
/// @return Status | /// @return Status | ||||
/// | /// | ||||
Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node) { | |||||
Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node, bool multi_batch_flag) { | |||||
GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | GE_CHK_BOOL_EXEC(node != nullptr, return FAILED, "Param of pre node is null."); | ||||
for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
@@ -447,7 +453,7 @@ Status SwitchOpPass::AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &node) | |||||
continue); | continue); | ||||
GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { | GE_IF_BOOL_EXEC(type != MEMCPYASYNC, { | ||||
in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor); | |||||
in_node = CreateMemcpyAsyncNode(graph, peer_out_anchor, multi_batch_flag); | |||||
GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); | GE_CHK_BOOL_EXEC(in_node != nullptr, return FAILED, "Create MemcpyAsync node fail."); | ||||
GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); | GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_out_anchor, in_data_anchor), "MemcpyAsync node remove edge fail."); | ||||
GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::AddEdge(peer_out_anchor, in_node->GetInDataAnchor(0)), | ||||
@@ -103,13 +103,13 @@ class SwitchOpPass : public GraphPass { | |||||
NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, | NodePtr CreateStreamSwitchNode(ComputeGraphPtr &graph, const NodePtr &switch_node, const std::string &suffix, | ||||
OutDataAnchorPtr &peer_cond_anchor); | OutDataAnchorPtr &peer_cond_anchor); | ||||
NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor); | |||||
NodePtr CreateMemcpyAsyncNode(ComputeGraphPtr &graph, const OutDataAnchorPtr &out_data_anchor, bool multi_batch_flag); | |||||
Status CombineSwitchNode(ComputeGraphPtr &graph); | Status CombineSwitchNode(ComputeGraphPtr &graph); | ||||
NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); | NodePtr CreateActiveNode(ComputeGraphPtr &graph, NodePtr &node); | ||||
Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node); | |||||
Status AddMemcpyAsyncNodes(ComputeGraphPtr &graph, NodePtr &stream_merge_node, bool multi_batch_flag); | |||||
Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); | Status BypassSwitchNode(NodePtr &switch_node, OutDataAnchorPtr &peer_data_anchor, OutDataAnchorPtr &peer_cond_anchor); | ||||
@@ -22,11 +22,14 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "external/graph/graph.h" | #include "external/graph/graph.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/common/omg_util.h" | |||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
namespace ge { | namespace ge { | ||||
std::map<std::string, std::map<int, int>> VariablePrepareOpPass::ref_node_without_prototype_map_{ | |||||
{REFSWITCH, {{0, 0}, {0, 1}}}}; | |||||
Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { | Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { | ||||
GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
for (const auto &node : graph->GetDirectNode()) { | for (const auto &node : graph->GetDirectNode()) { | ||||
@@ -43,9 +46,7 @@ Status VariablePrepareOpPass::Run(ComputeGraphPtr graph) { | |||||
for (auto &node : graph->GetDirectNode()) { | for (auto &node : graph->GetDirectNode()) { | ||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
bool is_variable = node->GetOpDesc()->GetType() == VARIABLE; | |||||
bool is_deal = has_dealed_variable_.find(node->GetName()) == has_dealed_variable_.end(); | |||||
if (is_variable && is_deal) { | |||||
if (node->GetOpDesc()->GetType() == VARIABLE) { | |||||
Status ret = DealVariableNode(node); | Status ret = DealVariableNode(node); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "variable add back edge failed"); | GELOGE(ret, "variable add back edge failed"); | ||||
@@ -149,7 +150,7 @@ NodePtr VariablePrepareOpPass::GetFinalWritableNode(ge::NodePtr &writable_node, | |||||
} | } | ||||
} | } | ||||
if (!found_writeable_node) { | if (!found_writeable_node) { | ||||
GELOGI("final writable node is %s", current_node->GetName().c_str()); | |||||
GELOGD("final writable node is %s", current_node->GetName().c_str()); | |||||
return current_node; | return current_node; | ||||
} | } | ||||
} | } | ||||
@@ -159,53 +160,54 @@ Status VariablePrepareOpPass::AddVariableRef(ge::NodePtr &final_writable_node, g | |||||
GE_CHECK_NOTNULL(final_writable_node); | GE_CHECK_NOTNULL(final_writable_node); | ||||
GE_CHECK_NOTNULL(var_node); | GE_CHECK_NOTNULL(var_node); | ||||
NodePtr var_ref_node = CreatVariableRef(final_writable_node, var_node); | |||||
GE_CHECK_NOTNULL(var_ref_node); | |||||
// add control anchor between var_ref_node and final peer node | |||||
// var_ref_node need to execute before other nodes | |||||
if (final_writable_node->GetType() == FRAMEWORKOP) { | |||||
GELOGD("No need to add variable_ref for frameworkop"); | |||||
return SUCCESS; | |||||
} | |||||
std::stringstream variable_ref_name; | |||||
variable_ref_name << "_TO_" << final_writable_node->GetName() << "_REF_" << index; | |||||
ge::NodePtr find_node = var_node->GetOwnerComputeGraph()->FindNode(var_node->GetName() + variable_ref_name.str()); | |||||
if (find_node != nullptr) { | |||||
GELOGD("The corresponding variable_ref [%s] has been added to this connection.", find_node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
NodePtr variable_ref_node = CreatVariableRef(var_node->GetName() + variable_ref_name.str(), var_node); | |||||
GELOGI("Add variable_ref between [%s] and [%s]", var_node->GetName().c_str(), variable_ref_node->GetName().c_str()); | |||||
GE_CHECK_NOTNULL(variable_ref_node); | |||||
// add control anchor between variable_ref and final peer node | |||||
// variable_ref_node need to execute before other nodes | |||||
auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); | auto final_writable_outAnchors = final_writable_node->GetAllOutAnchors(); | ||||
for (auto &final_writable_outAnchor : final_writable_outAnchors) { | for (auto &final_writable_outAnchor : final_writable_outAnchors) { | ||||
GE_CHECK_NOTNULL(final_writable_outAnchor); | GE_CHECK_NOTNULL(final_writable_outAnchor); | ||||
for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { | for (auto &final_writable_peerAnchor : final_writable_outAnchor->GetPeerAnchors()) { | ||||
GE_CHECK_NOTNULL(final_writable_peerAnchor); | GE_CHECK_NOTNULL(final_writable_peerAnchor); | ||||
NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); | NodePtr peer_node = final_writable_peerAnchor->GetOwnerNode(); | ||||
graphStatus ret = ge::GraphUtils::AddEdge(var_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); | |||||
graphStatus ret = | |||||
ge::GraphUtils::AddEdge(variable_ref_node->GetOutControlAnchor(), peer_node->GetInControlAnchor()); | |||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
GELOGE(FAILED, "add control anchor between var_ref_node and final_writable peer_node failed"); | |||||
GELOGE(FAILED, "add control anchor between variable_ref and final_writable peer node failed"); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} | } | ||||
} | } | ||||
// add edge final node:index ---> var_ref_node:0 | |||||
graphStatus ret = | graphStatus ret = | ||||
ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), var_ref_node->GetInDataAnchor(0)); | |||||
ge::GraphUtils::AddEdge(final_writable_node->GetOutDataAnchor(index), variable_ref_node->GetInDataAnchor(0)); | |||||
if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
GELOGE(FAILED, "add data anchor between var_ref_node and final_writable peer_node failed"); | |||||
GELOGE(FAILED, "add data anchor between variable_ref and final_writable peer node failed"); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_node, ge::NodePtr &var_node) { | |||||
if ((final_writable_node == nullptr) || (var_node == nullptr) || (var_node->GetOwnerComputeGraph() == nullptr)) { | |||||
GELOGE(FAILED, "parameter ptr is null."); | |||||
return nullptr; | |||||
} | |||||
GELOGD("Create VarRef Op: final_writable_node: [%s] var_node: [%s]>>>>", final_writable_node->GetName().c_str(), | |||||
var_node->GetName().c_str()); | |||||
static uint32_t var_ref_count = 0; | |||||
std::stringstream var_ref_name; | |||||
var_ref_name << "_to_" << final_writable_node->GetName() << "_REF_" << var_ref_count++; | |||||
ge::NodePtr VariablePrepareOpPass::CreatVariableRef(const std::string &variable_ref_name, ge::NodePtr &var_node) { | |||||
OpDescPtr var_op_desc = var_node->GetOpDesc(); | OpDescPtr var_op_desc = var_node->GetOpDesc(); | ||||
if (var_op_desc == nullptr) { | if (var_op_desc == nullptr) { | ||||
GELOGE(FAILED, "get var opdesc is nullptr"); | GELOGE(FAILED, "get var opdesc is nullptr"); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
OpDescPtr var_ref_op_desc = | |||||
MakeShared<OpDesc>(var_node->GetName() + var_ref_name.str().c_str(), var_op_desc->GetType()); | |||||
OpDescPtr var_ref_op_desc = MakeShared<OpDesc>(variable_ref_name.c_str(), var_op_desc->GetType()); | |||||
if (var_ref_op_desc == nullptr) { | if (var_ref_op_desc == nullptr) { | ||||
GELOGE(FAILED, "var_ref opdesc is nullptr"); | GELOGE(FAILED, "var_ref opdesc is nullptr"); | ||||
return nullptr; | return nullptr; | ||||
@@ -217,15 +219,15 @@ ge::NodePtr VariablePrepareOpPass::CreatVariableRef(ge::NodePtr &final_writable_ | |||||
GE_IF_BOOL_EXEC(var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS, | GE_IF_BOOL_EXEC(var_ref_op_desc->AddInputDesc(var_op_desc->GetOutputDesc(0)) != SUCCESS, | ||||
GELOGW("add input desc edge failed"); | GELOGW("add input desc edge failed"); | ||||
return nullptr); | return nullptr); | ||||
NodePtr var_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc); | |||||
GE_IF_BOOL_EXEC(var_ref_node == nullptr, GELOGW("var_ref_node is null"); return nullptr); | |||||
has_dealed_variable_.insert(var_node->GetName()); | |||||
NodePtr variable_ref_node = var_node->GetOwnerComputeGraph()->AddNode(var_ref_op_desc); | |||||
GE_IF_BOOL_EXEC(variable_ref_node == nullptr, GELOGW("variable_ref_node is null"); return nullptr); | |||||
bool is_set_str = ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName()); | bool is_set_str = ge::AttrUtils::SetStr(var_ref_op_desc, REF_VAR_SRC_VAR_NAME, var_op_desc->GetName()); | ||||
if (is_set_str) { | if (is_set_str) { | ||||
GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", var_ref_node->GetName().c_str(), var_op_desc->GetName().c_str()); | |||||
GELOGD("Set node [%s] REF_VAR_SRC_VAR_NAME [%s]", variable_ref_node->GetName().c_str(), | |||||
var_op_desc->GetName().c_str()); | |||||
} | } | ||||
return var_ref_node; | |||||
return variable_ref_node; | |||||
} | } | ||||
int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index) { | int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int input_index) { | ||||
@@ -240,16 +242,13 @@ int VariablePrepareOpPass::GetWritableNodeOutIndex(const NodePtr &node, int inpu | |||||
} | } | ||||
} | } | ||||
auto node_iter = ref_input_output_map_.find(node_type); | |||||
if (node_iter == ref_input_output_map_.end()) { | |||||
return -1; | |||||
} | |||||
auto index_iter = node_iter->second.find(input_index); | |||||
if (index_iter == node_iter->second.end()) { | |||||
return -1; | |||||
if (node_type == FRAMEWORKOP) { | |||||
std::string original_type; | |||||
GE_IF_BOOL_EXEC(GetOriginalType(node, original_type) != SUCCESS, GELOGW("Get node original type fail")); | |||||
GELOGI("find frameworkop: [%s], original type is %s", node->GetName().c_str(), original_type.c_str()); | |||||
return FindRefOutIndex(original_type, input_index, ref_node_without_prototype_map_); | |||||
} | } | ||||
return index_iter->second; | |||||
return FindRefOutIndex(node_type, input_index, ref_input_output_map_); | |||||
} | } | ||||
void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) { | void VariablePrepareOpPass::GenerateRefTypeAndInputOutputMap(const NodePtr &node) { | ||||
@@ -301,4 +300,18 @@ Status VariablePrepareOpPass::UpdateAssignOpDesc(const ge::NodePtr &node) { | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
int VariablePrepareOpPass::FindRefOutIndex(const std::string &node_type, int input_index, | |||||
const std::map<std::string, std::map<int, int>> &ref_map) { | |||||
auto node_iter = ref_map.find(node_type); | |||||
if (node_iter == ref_map.end()) { | |||||
return -1; | |||||
} | |||||
auto index_iter = node_iter->second.find(input_index); | |||||
if (index_iter == node_iter->second.end()) { | |||||
return -1; | |||||
} | |||||
return index_iter->second; | |||||
} | |||||
} // namespace ge | } // namespace ge |