@@ -42,7 +42,7 @@ class BlockingQueue { | |||||
return false; | return false; | ||||
} | } | ||||
item = queue_.front(); | |||||
item = std::move(queue_.front()); | |||||
queue_.pop_front(); | queue_.pop_front(); | ||||
full_cond_.notify_one(); | full_cond_.notify_one(); | ||||
@@ -71,6 +71,27 @@ class BlockingQueue { | |||||
return true; | return true; | ||||
} | } | ||||
bool Push(T &&item, bool is_wait = true) { | |||||
std::unique_lock<std::mutex> lock(mutex_); | |||||
while (queue_.size() >= max_size_ && !is_stoped_) { | |||||
if (!is_wait) { | |||||
return false; | |||||
} | |||||
full_cond_.wait(lock); | |||||
} | |||||
if (is_stoped_) { | |||||
return false; | |||||
} | |||||
queue_.emplace_back(std::move(item)); | |||||
empty_cond_.notify_one(); | |||||
return true; | |||||
} | |||||
void Stop() { | void Stop() { | ||||
{ | { | ||||
std::unique_lock<std::mutex> lock(mutex_); | std::unique_lock<std::mutex> lock(mutex_); | ||||
@@ -26,6 +26,7 @@ using std::string; | |||||
namespace ge { | namespace ge { | ||||
// when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD | // when need to eliminate GETaskKernelHcclInfo, so not need DAVINCI_TRAIN/DAVINCI_CLOUD | ||||
struct GETaskKernelHcclInfo { | struct GETaskKernelHcclInfo { | ||||
string input_name; | |||||
string hccl_type; | string hccl_type; | ||||
void *inputDataAddr; | void *inputDataAddr; | ||||
void *outputDataAddr; | void *outputDataAddr; | ||||
@@ -35,6 +36,7 @@ struct GETaskKernelHcclInfo { | |||||
int32_t opType; | int32_t opType; | ||||
int64_t rootId; | int64_t rootId; | ||||
uint64_t workSpaceMemSize; | uint64_t workSpaceMemSize; | ||||
std::vector<int64_t> dims; | |||||
std::vector<rtStream_t> hcclStreamList; | std::vector<rtStream_t> hcclStreamList; | ||||
}; | }; | ||||
@@ -48,7 +50,7 @@ struct GETaskInfo { | |||||
uint32_t privateDefLen; | uint32_t privateDefLen; | ||||
void *opsKernelStorePtr; | void *opsKernelStorePtr; | ||||
GETaskKernelHcclInfo kernelHcclInfo; | |||||
std::vector<GETaskKernelHcclInfo> kernelHcclInfo; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ | #endif // INC_COMMON_OPSKERNEL_GE_TASK_INFO_H_ |
@@ -73,7 +73,7 @@ class OpsKernelInfoStore { | |||||
// only call fe engine interface to compile single op | // only call fe engine interface to compile single op | ||||
virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | virtual Status CompileOp(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | ||||
virtual Status CompileOpRun(vector<ge::NodePtr> &node_vec) { return SUCCESS; } | |||||
// load task for op | // load task for op | ||||
virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | virtual Status LoadTask(GETaskInfo &task) { return SUCCESS; } | ||||
@@ -33,6 +33,7 @@ const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId"; | |||||
const char *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId"; | const char *const OPTION_EXEC_DEVICE_ID = "ge.exec.deviceId"; | ||||
const char *const OPTION_EXEC_JOB_ID = "ge.exec.jobId"; | const char *const OPTION_EXEC_JOB_ID = "ge.exec.jobId"; | ||||
const char *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom"; | const char *const OPTION_EXEC_IS_USEHCOM = "ge.exec.isUseHcom"; | ||||
const char *const OPTION_EXEC_IS_USEHVD = "ge.exec.isUseHvd"; | |||||
const char *const OPTION_EXEC_RANK_ID = "ge.exec.rankId"; | const char *const OPTION_EXEC_RANK_ID = "ge.exec.rankId"; | ||||
const char *const OPTION_EXEC_POD_NAME = "ge.exec.podName"; | const char *const OPTION_EXEC_POD_NAME = "ge.exec.podName"; | ||||
const char *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode"; | const char *const OPTION_EXEC_DEPLOY_MODE = "ge.exec.deployMode"; | ||||
@@ -52,6 +53,7 @@ const char *const OPTION_EXEC_PROFILING_OPTIONS = "ge.exec.profilingOptions"; | |||||
const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | const char *const OPTION_EXEC_HCCL_FLAG = "ge.exec.hcclFlag"; | ||||
const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | const char *const OPTION_EXEC_ATOMIC_FLAG = "ge.exec.enable_atomic"; | ||||
const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; | const char *const OPTION_EXEC_DISABLE_REUSED_MEMORY = "ge.exec.disableReuseMemory"; | ||||
const char *const OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION = "ge.exec.isTailingOptimization"; | |||||
// Option key: memory init | // Option key: memory init | ||||
const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | const char *const GRAPH_MEMORY_MAX_SIZE = "ge.graphMemoryMaxSize"; | ||||
@@ -153,7 +155,7 @@ const std::string STREAM_MAX_PARALLEL_NUM = "ge.streamMaxParallelNum"; | |||||
const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | const std::string OUTPUT_DATATYPE = "ge.outputDatatype"; | ||||
// congigure opSelectImplmode to setting op select implmode | // congigure opSelectImplmode to setting op select implmode | ||||
const std::string kOpSelectImplmode = "ge.opSelectImplmode"; | |||||
const std::string OP_SELECT_IMPL_MODE = "ge.opSelectImplmode"; | |||||
// configure whether to enable hcom parallel by session constructor options param, | // configure whether to enable hcom parallel by session constructor options param, | ||||
// its value should be "0" or "1", default value is "0" | // its value should be "0" or "1", default value is "0" | ||||
@@ -214,6 +216,9 @@ const char *const ENABLE_PRINT_OP_PASS = "ge.enablePrintOpPass"; | |||||
// Its value should be "true" or "false", default value is "false" | // Its value should be "true" or "false", default value is "false" | ||||
const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | const char *const ENABLE_SINGLE_STREAM = "ge.enableSingleStream"; | ||||
// Configure input fp16 nodes | |||||
const std::string INPUT_FP16_NODES = "ge.INPUT_NODES_SET_FP16"; | |||||
// Graph run mode | // Graph run mode | ||||
enum GraphRunMode { PREDICTION = 0, TRAIN }; | enum GraphRunMode { PREDICTION = 0, TRAIN }; | ||||
@@ -263,14 +268,37 @@ static const char *const AUTO_TUNE_MODE = ge::AUTO_TUNE_MODE.c_str(); | |||||
static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | static const char *const CORE_TYPE = ge::CORE_TYPE.c_str(); | ||||
static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | static const char *const SOC_VERSION = ge::SOC_VERSION.c_str(); | ||||
static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; | static const char *const ENABLE_SINGLE_STREAM = ge::ENABLE_SINGLE_STREAM; | ||||
static const char *const AICORE_NUM = ge::AICORE_NUM.c_str(); | |||||
static const char *const FUSION_SWITCH_FILE = ge::FUSION_SWITCH_FILE.c_str(); | |||||
static const char *const ENABLE_SMALL_CHANNEL = ge::ENABLE_SMALL_CHANNEL.c_str(); | |||||
static const char *const QUANT_OPTIMIZE = ge::QUANT_OPTIMIZE.c_str(); | |||||
static const char *const OP_SELECT_IMPL_MODE = ge::OP_SELECT_IMPL_MODE.c_str(); | |||||
static const char *const OUTPUT_TYPE = ge::OUTPUT_DATATYPE.c_str(); | |||||
static const char *const BUFFER_OPTIMIZE = ge::BUFFER_OPTIMIZE.c_str(); | |||||
static const char *const ENABLE_COMPRESS_WEIGHT = ge::ENABLE_COMPRESS_WEIGHT.c_str(); | |||||
static const char *const COMPRESS_WEIGHT_CONF = "compress_weight_conf"; | |||||
static const char *const OUT_NODES = ge::OUTPUT_NODE_NAME.c_str(); | |||||
static const char *const INPUT_FP16_NODES = ge::INPUT_FP16_NODES.c_str(); | |||||
static const char *const LOG_LEVEL = "log"; | |||||
// for interface: aclgrphBuildModel | // for interface: aclgrphBuildModel | ||||
const std::set<std::string> ir_builder_suppported_options = {INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, | |||||
DYNAMIC_IMAGE_SIZE, INSERT_OP_FILE}; | |||||
const std::set<std::string> ir_builder_suppported_options = { | |||||
INPUT_FORMAT, INPUT_SHAPE, DYNAMIC_BATCH_SIZE, DYNAMIC_IMAGE_SIZE, | |||||
INSERT_OP_FILE, OUTPUT_TYPE, BUFFER_OPTIMIZE, ENABLE_COMPRESS_WEIGHT, | |||||
COMPRESS_WEIGHT_CONF, OUT_NODES, INPUT_FP16_NODES, LOG_LEVEL}; | |||||
// for interface: aclgrphBuildInitialize | // for interface: aclgrphBuildInitialize | ||||
const std::set<std::string> global_options = { | |||||
HEAD_STREAM, CORE_TYPE, SOC_VERSION, PRECISION_MODE, EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE, ENABLE_SINGLE_STREAM}; | |||||
const std::set<std::string> global_options = {HEAD_STREAM, | |||||
CORE_TYPE, | |||||
SOC_VERSION, | |||||
PRECISION_MODE, | |||||
EXEC_DISABLE_REUSED_MEMORY, | |||||
AUTO_TUNE_MODE, | |||||
ENABLE_SINGLE_STREAM, | |||||
AICORE_NUM, | |||||
FUSION_SWITCH_FILE, | |||||
ENABLE_SMALL_CHANNEL, | |||||
QUANT_OPTIMIZE, | |||||
OP_SELECT_IMPL_MODE}; | |||||
} // namespace ir_option | } // namespace ir_option | ||||
} // namespace ge | } // namespace ge | ||||
@@ -48,12 +48,9 @@ class NamedAttrs; | |||||
class Graph; | class Graph; | ||||
class AttrValue; | class AttrValue; | ||||
using SubgraphBuilder = std::function<Graph(const std::string &name)>; | |||||
using SubgraphBuilder = std::function<Graph()>; | |||||
using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | using OperatorImplPtr = std::shared_ptr<OperatorImpl>; | ||||
class Graph; | |||||
using GraphBuilderCallback = std::function<Graph()>; | |||||
class OpIO; | class OpIO; | ||||
using OutHandler = std::shared_ptr<OpIO>; | using OutHandler = std::shared_ptr<OpIO>; | ||||
using InHandler = std::shared_ptr<OpIO>; | using InHandler = std::shared_ptr<OpIO>; | ||||
@@ -139,12 +136,6 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
void SetInferenceContext(const InferenceContextPtr &inference_context); | void SetInferenceContext(const InferenceContextPtr &inference_context); | ||||
InferenceContextPtr GetInferenceContext() const; | InferenceContextPtr GetInferenceContext() const; | ||||
void SetGraphBuilder(const GraphBuilderCallback &builder); | |||||
graphStatus GetGraphBuilder(GraphBuilderCallback &builder) const; | |||||
void AddSubgraphName(const string &name); | |||||
string GetSubgraphName(int index) const; | |||||
graphStatus VerifyAllAttr(bool disable_common_verifier = false); | graphStatus VerifyAllAttr(bool disable_common_verifier = false); | ||||
size_t GetInputsSize() const; | size_t GetInputsSize() const; | ||||
@@ -265,9 +256,9 @@ class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY Operator { | |||||
Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | Operator &SetInput(const string &dst_name, uint32_t dst_index, const Operator &src_oprt, const string &name); | ||||
void SubgraphRegister(const std::string &name, bool dynamic); | |||||
void SubgraphCountRegister(const std::string &name, uint32_t count); | |||||
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder); | |||||
void SubgraphRegister(const string &ir_name, bool dynamic); | |||||
void SubgraphCountRegister(const string &ir_name, uint32_t count); | |||||
void SetSubgraphBuilder(const string &ir_name, uint32_t index, const SubgraphBuilder &builder); | |||||
private: | private: | ||||
Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | Operator &SetInput(const string &dst_name, const OutHandler &out_handler); | ||||
@@ -186,56 +186,54 @@ class OpReg { | |||||
Operator::OutputRegister(#x); \ | Operator::OutputRegister(#x); \ | ||||
(void)OpReg() | (void)OpReg() | ||||
#define DYNAMIC_INPUT(x, t) \ | |||||
N(); \ | |||||
__dy_input_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
_THIS_TYPE &create_dynamic_input_##x(unsigned int num, bool isPushBack = true) { \ | |||||
Operator::DynamicInputRegister(#x, num, isPushBack); \ | |||||
return *this; \ | |||||
} \ | |||||
_THIS_TYPE &create_dynamic_input_byindex_##x(unsigned int num, size_t index) { \ | |||||
Operator::DynamicInputRegisterByIndex(#x, num, index); \ | |||||
return *this; \ | |||||
} \ | |||||
TensorDesc get_dynamic_input_desc_##x(unsigned int index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | |||||
graphStatus update_dynamic_input_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ | |||||
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | |||||
} \ | |||||
_THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v) { \ | |||||
Operator::SetInput(#x, dstIndex, v); \ | |||||
return *this; \ | |||||
} \ | |||||
_THIS_TYPE &set_dynamic_input_##x(unsigned int dstIndex, Operator &v, const string &srcName) { \ | |||||
Operator::SetInput(#x, dstIndex, v, srcName); \ | |||||
return *this; \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __dy_input_##x() { \ | |||||
#define DYNAMIC_INPUT(x, t) \ | |||||
N(); \ | |||||
__dy_input_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
_THIS_TYPE &create_dynamic_input_##x(uint32_t num, bool isPushBack = true) { \ | |||||
Operator::DynamicInputRegister(#x, num, isPushBack); \ | |||||
return *this; \ | |||||
} \ | |||||
_THIS_TYPE &create_dynamic_input_byindex_##x(uint32_t num, size_t index) { \ | |||||
Operator::DynamicInputRegisterByIndex(#x, num, index); \ | |||||
return *this; \ | |||||
} \ | |||||
TensorDesc get_dynamic_input_desc_##x(uint32_t index) const { return Operator::GetDynamicInputDesc(#x, index); } \ | |||||
graphStatus update_dynamic_input_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ | |||||
return Operator::UpdateDynamicInputDesc(#x, index, tensorDesc); \ | |||||
} \ | |||||
_THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v) { \ | |||||
Operator::SetInput(#x, dstIndex, v); \ | |||||
return *this; \ | |||||
} \ | |||||
_THIS_TYPE &set_dynamic_input_##x(uint32_t dstIndex, Operator &v, const string &srcName) { \ | |||||
Operator::SetInput(#x, dstIndex, v, srcName); \ | |||||
return *this; \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __dy_input_##x() { \ | |||||
(void)OpReg() | (void)OpReg() | ||||
#define DYNAMIC_OUTPUT(x, t) \ | |||||
N(); \ | |||||
__dy_output_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
_THIS_TYPE &create_dynamic_output_##x(unsigned int num, bool isPushBack = true) { \ | |||||
Operator::DynamicOutputRegister(#x, num, isPushBack); \ | |||||
return *this; \ | |||||
} \ | |||||
TensorDesc get_dynamic_output_desc_##x(unsigned int index) const { \ | |||||
return Operator::GetDynamicOutputDesc(#x, index); \ | |||||
} \ | |||||
graphStatus update_dynamic_output_desc_##x(unsigned int index, const TensorDesc &tensorDesc) { \ | |||||
return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __dy_output_##x() { \ | |||||
#define DYNAMIC_OUTPUT(x, t) \ | |||||
N(); \ | |||||
__dy_output_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
_THIS_TYPE &create_dynamic_output_##x(uint32_t num, bool isPushBack = true) { \ | |||||
Operator::DynamicOutputRegister(#x, num, isPushBack); \ | |||||
return *this; \ | |||||
} \ | |||||
TensorDesc get_dynamic_output_desc_##x(uint32_t index) const { return Operator::GetDynamicOutputDesc(#x, index); } \ | |||||
graphStatus update_dynamic_output_desc_##x(uint32_t index, const TensorDesc &tensorDesc) { \ | |||||
return Operator::UpdateDynamicOutputDesc(#x, index, tensorDesc); \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __dy_output_##x() { \ | |||||
(void)OpReg() | (void)OpReg() | ||||
#define GRAPH(x) \ | #define GRAPH(x) \ | ||||
@@ -258,29 +256,29 @@ class OpReg { | |||||
Operator::SubgraphCountRegister(#x, 1); \ | Operator::SubgraphCountRegister(#x, 1); \ | ||||
(void)OpReg() | (void)OpReg() | ||||
#define DYNAMIC_GRAPH(x) \ | |||||
N(); \ | |||||
__graph_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
static const string name_graph_##x() { return #x; } \ | |||||
_THIS_TYPE &create_dynamic_subgraph_##x(unsigned int num) { \ | |||||
Operator::SubgraphCountRegister(#x, num); \ | |||||
return *this; \ | |||||
} \ | |||||
SubgraphBuilder get_dynamic_subgraph_builder_##x(unsigned int index) const { \ | |||||
return Operator::GetDynamicSubgraphBuilder(#x, index); \ | |||||
} \ | |||||
Graph get_dynamic_subgraph_##x(unsigned int index) const { return Operator::GetDynamicSubgraph(#x, index); } \ | |||||
_THIS_TYPE &set_dynamic_subgraph_builder_##x(unsigned int index, const SubgraphBuilder &v) { \ | |||||
Operator::SetSubgraphBuilder(#x, index, v); \ | |||||
return *this; \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __graph_##x() { \ | |||||
Operator::SubgraphRegister(#x, true); \ | |||||
#define DYNAMIC_GRAPH(x) \ | |||||
N(); \ | |||||
__graph_##x(); \ | |||||
} \ | |||||
\ | |||||
public: \ | |||||
static const string name_graph_##x() { return #x; } \ | |||||
_THIS_TYPE &create_dynamic_subgraph_##x(uint32_t num) { \ | |||||
Operator::SubgraphCountRegister(#x, num); \ | |||||
return *this; \ | |||||
} \ | |||||
SubgraphBuilder get_dynamic_subgraph_builder_##x(uint32_t index) const { \ | |||||
return Operator::GetDynamicSubgraphBuilder(#x, index); \ | |||||
} \ | |||||
Graph get_dynamic_subgraph_##x(uint32_t index) const { return Operator::GetDynamicSubgraph(#x, index); } \ | |||||
_THIS_TYPE &set_dynamic_subgraph_builder_##x(uint32_t index, const SubgraphBuilder &v) { \ | |||||
Operator::SetSubgraphBuilder(#x, index, v); \ | |||||
return *this; \ | |||||
} \ | |||||
\ | |||||
private: \ | |||||
void __graph_##x() { \ | |||||
Operator::SubgraphRegister(#x, true); \ | |||||
(void)OpReg() | (void)OpReg() | ||||
#define PASTE(g_register, y) g_register##y | #define PASTE(g_register, y) g_register##y | ||||
@@ -24,7 +24,7 @@ | |||||
namespace ge { | namespace ge { | ||||
static const int64_t UNKNOWN_DIM = -1; | static const int64_t UNKNOWN_DIM = -1; | ||||
static const int64_t UNKNOWN_DIM_NUM = -2; | static const int64_t UNKNOWN_DIM_NUM = -2; | ||||
static const std::vector<int64_t> UNKNOWN_SHAPE = {0}; | |||||
static const std::vector<int64_t> UNKNOWN_SHAPE = {-1}; | |||||
static const std::vector<int64_t> UNKNOWN_RANK = {-2}; | static const std::vector<int64_t> UNKNOWN_RANK = {-2}; | ||||
#ifdef HOST_VISIBILITY | #ifdef HOST_VISIBILITY | ||||
@@ -40,6 +40,14 @@ enum FrameworkType { | |||||
FMK_TYPE_RESERVED, | FMK_TYPE_RESERVED, | ||||
}; | }; | ||||
enum OpEngineType { | |||||
ENGINE_SYS = 0, // default engine | |||||
ENGINE_AICORE = 1, | |||||
ENGINE_VECTOR = 2, | |||||
ENGINE_AICUBE = 3, // not support | |||||
ENGINE_AIVECTOR = 4 // not support | |||||
}; | |||||
const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; | const char *const GE_ENGINE_ATTR_MEM_TYPE_HBM = "HBM"; | ||||
// Data cache, including data address and length | // Data cache, including data address and length | ||||
@@ -141,6 +149,7 @@ struct Options { | |||||
int32_t device_id; | int32_t device_id; | ||||
std::string job_id; | std::string job_id; | ||||
bool isUseHcom; | bool isUseHcom; | ||||
bool isUseHvd; | |||||
bool deployMode; | bool deployMode; | ||||
bool isAICPUMode; | bool isAICPUMode; | ||||
bool enable_atomic; | bool enable_atomic; | ||||
@@ -442,6 +442,7 @@ REGISTER_OPTYPE_DECLARE(STREAMMERGE, "StreamMerge"); | |||||
REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DECLARE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DECLARE(SEND, "Send"); | REGISTER_OPTYPE_DECLARE(SEND, "Send"); | ||||
REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | REGISTER_OPTYPE_DECLARE(RECV, "Recv"); | ||||
REGISTER_OPTYPE_DECLARE(ENDOFSEQUENCE, "EndOfSequence"); | |||||
REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DECLARE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DECLARE(LABELGOTO, "LabelGoto"); | ||||
@@ -508,6 +509,12 @@ REGISTER_OPTYPE_DECLARE(DEPTHWISEWEIGHT6D24D, "depthwise_weight_6d_2_4d"); | |||||
REGISTER_OPTYPE_DECLARE(SQRTGRAD, "SqrtGrad"); | REGISTER_OPTYPE_DECLARE(SQRTGRAD, "SqrtGrad"); | ||||
REGISTER_OPTYPE_DECLARE(SIGMOIDGRAD, "SigmoidGrad"); | REGISTER_OPTYPE_DECLARE(SIGMOIDGRAD, "SigmoidGrad"); | ||||
// Horovod operator | |||||
REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLREDUCE, "HorovodAllreduce"); | |||||
REGISTER_OPTYPE_DECLARE(HVDCALLBACKALLGATHER, "HorovodAllgather"); | |||||
REGISTER_OPTYPE_DECLARE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); | |||||
REGISTER_OPTYPE_DECLARE(HVDWAIT, "HorovodWait"); | |||||
enum InputMode { INPUT = 0, CONST }; | enum InputMode { INPUT = 0, CONST }; | ||||
// Definition of the processing status enum of the process module | // Definition of the processing status enum of the process module | ||||
@@ -1,61 +0,0 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_DLOG_LOG_H_ | |||||
#define INC_FRAMEWORK_DLOG_LOG_H_ | |||||
#include <string> | |||||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||||
#include "toolchain/slog.h" | |||||
#else | |||||
#include <android/log.h> | |||||
#endif | |||||
#ifdef _MSC_VER | |||||
#define FUNC_NAME __FUNCTION__ | |||||
#else | |||||
#define FUNC_NAME __PRETTY_FUNCTION__ | |||||
#endif | |||||
#if !defined(__ANDROID__) && !defined(ANDROID) | |||||
#define DAV_LOGI(MOD_NAME, fmt, ...) dlog_info(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGW(MOD_NAME, fmt, ...) dlog_warn(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGE(MOD_NAME, fmt, ...) dlog_error(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_LOGD(MOD_NAME, fmt, ...) dlog_debug(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#define DAV_EVENT(MOD_NAME, fmt, ...) dlog_event(static_cast<int>(GE), "%s:" #fmt, __FUNCTION__, ##__VA_ARGS__) | |||||
#else | |||||
#define DAV_LOGI(MOD_NAME, fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_INFO, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define DAV_LOGW(MOD_NAME, fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_WARN, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define DAV_LOGE(MOD_NAME, fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_ERROR, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define DAV_LOGD(MOD_NAME, fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_DEBUG, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#define DAV_EVENT(MOD_NAME, fmt, ...) \ | |||||
__android_log_print(ANDROID_LOG_DEBUG, MOD_NAME, "%s %s(%d)::" #fmt, __FILE__, __FUNCTION__, __LINE__, ##__VA_ARGS__) | |||||
#endif | |||||
#define DLOG_DECLARE(level) \ | |||||
void Log_##level(const char *mod_name, const char *func, const char *file, int line, const char *format, ...) | |||||
namespace domi { | |||||
DLOG_DECLARE(INFO); | |||||
DLOG_DECLARE(WARNING); | |||||
DLOG_DECLARE(ERROR); | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_DLOG_LOG_H_ |
@@ -0,0 +1,113 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||||
#define INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ | |||||
#include <memory> | |||||
#include <vector> | |||||
#include "ge_runtime/op_info.h" | |||||
#include "ge_runtime/task_info.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
class DavinciModel { | |||||
public: | |||||
DavinciModel(const std::vector<std::shared_ptr<TaskInfo>> &task_info_list, | |||||
const std::vector<std::shared_ptr<OpInfo>> &data_info_list, | |||||
const std::vector<std::shared_ptr<OpInfo>> &output_info_list, | |||||
const std::vector<std::shared_ptr<OpInfo>> &constant_info_list, | |||||
const std::vector<model_runner::OpInfoPtr> &variable_info_list, | |||||
const std::vector<uint32_t> &wait_active_stream_list, | |||||
const std::vector<uint32_t> &force_copy_stream_list, uint64_t mem_size = 0, uint64_t weight_size = 0, | |||||
uint64_t var_size = 0, uintptr_t logic_mem_base = 0, uintptr_t logic_weight_base = 0, | |||||
uintptr_t logic_var_base = 0, uint32_t stream_num = 0, uint32_t batch_num = 0, uint32_t event_num = 0, | |||||
int32_t priority = 0) | |||||
: task_info_list_(task_info_list), | |||||
data_info_list_(data_info_list), | |||||
output_info_list_(output_info_list), | |||||
constant_info_list_(constant_info_list), | |||||
variable_info_list_(variable_info_list), | |||||
wait_active_stream_list_(wait_active_stream_list), | |||||
force_copy_stream_list_(force_copy_stream_list), | |||||
mem_size_(mem_size), | |||||
weight_size_(weight_size), | |||||
var_size_(var_size), | |||||
logic_mem_base_(logic_mem_base), | |||||
logic_weight_base_(logic_weight_base), | |||||
logic_var_base_(logic_var_base), | |||||
stream_num_(stream_num), | |||||
batch_num_(batch_num), | |||||
event_num_(event_num), | |||||
priority_(priority) {} | |||||
~DavinciModel() {} | |||||
uint64_t GetMemSize() const { return mem_size_; } | |||||
uint64_t GetWeightSize() const { return weight_size_; } | |||||
uint64_t GetVarSize() const { return var_size_; } | |||||
uintptr_t GetLogicMemBase() const { return logic_mem_base_; } | |||||
uintptr_t GetLogicWeightBase() const { return logic_weight_base_; } | |||||
uintptr_t GetLogicVarBase() const { return logic_var_base_; } | |||||
uint32_t GetStreamNum() const { return stream_num_; } | |||||
uint32_t GetBatchNum() const { return batch_num_; } | |||||
uint32_t GetEventNum() const { return event_num_; } | |||||
const std::vector<uint32_t> &GetWaitActiveStreams() const { return wait_active_stream_list_; } | |||||
const std::vector<uint32_t> &GetForceCopyStreams() const { return force_copy_stream_list_; } | |||||
int32_t GetPriority() const { return priority_; } | |||||
const std::vector<std::shared_ptr<TaskInfo>> &GetTaskInfoList() const { return task_info_list_; } | |||||
const std::vector<std::shared_ptr<OpInfo>> &GetDataInfoList() const { return data_info_list_; } | |||||
const std::vector<std::shared_ptr<OpInfo>> &GetOutputInfoList() const { return output_info_list_; } | |||||
const std::vector<std::shared_ptr<OpInfo>> &GetConstantInfoList() const { return output_info_list_; } | |||||
const std::vector<model_runner::OpInfoPtr> &GetVariableInfoList() const { return variable_info_list_; } | |||||
private: | |||||
std::vector<std::shared_ptr<TaskInfo>> task_info_list_; | |||||
std::vector<std::shared_ptr<OpInfo>> data_info_list_; | |||||
std::vector<std::shared_ptr<OpInfo>> output_info_list_; | |||||
std::vector<std::shared_ptr<OpInfo>> constant_info_list_; | |||||
std::vector<model_runner::OpInfoPtr> variable_info_list_; | |||||
std::vector<uint32_t> wait_active_stream_list_; | |||||
std::vector<uint32_t> force_copy_stream_list_; | |||||
uint64_t mem_size_; | |||||
uint64_t weight_size_; | |||||
uint64_t var_size_; | |||||
uintptr_t logic_mem_base_; | |||||
uintptr_t logic_weight_base_; | |||||
uintptr_t logic_var_base_; | |||||
uint32_t stream_num_; | |||||
uint32_t batch_num_; | |||||
uint32_t event_num_; | |||||
int32_t priority_; | |||||
// Disable to copy constructor and assignment operator | |||||
DavinciModel &operator=(const DavinciModel &) = delete; | |||||
DavinciModel(const DavinciModel &) = delete; | |||||
}; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_GE_RUNTIME_DAVINCI_MODEL_H_ |
@@ -0,0 +1,58 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||||
#define INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ | |||||
#include <memory> | |||||
#include <unordered_map> | |||||
#include <vector> | |||||
#include "common/ge_inner_error_codes.h" | |||||
#include "common/ge_types.h" | |||||
#include "ge_runtime/davinci_model.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
class RuntimeModel; | |||||
class ModelRunner { | |||||
public: | |||||
static ModelRunner &Instance(); | |||||
bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, | |||||
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); | |||||
const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; | |||||
bool UnloadModel(uint32_t model_id); | |||||
bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); | |||||
bool GetInputOutputDescInfo(uint32_t model_id, bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, | |||||
std::vector<InputOutputDescInfo> *output_desc, std::vector<uint32_t> *input_format, | |||||
std::vector<uint32_t> *output_format); | |||||
private: | |||||
ModelRunner() = default; | |||||
~ModelRunner() = default; | |||||
std::unordered_map<uint32_t, std::shared_ptr<RuntimeModel>> runtime_models_; | |||||
}; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_GE_RUNTIME_MODEL_RUNNER_H_ |
@@ -0,0 +1,72 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||||
#define INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ | |||||
#include <memory> | |||||
#include <string> | |||||
#include <vector> | |||||
namespace ge { | |||||
namespace model_runner { | |||||
// Plain descriptor of one tensor (shape, dtype/format codes, byte size).
// Aggregate type: fields are filled in directly by the producer.
struct TensorInfo {
  // Returns the element count (product of all dims), or 0 for an empty
  // (rank-0 / unset) shape. No overflow check is performed; dims are
  // assumed small enough that the product fits in int64_t.
  int64_t GetShapeSize() const {
    if (dims.empty()) {
      return 0;
    }
    int64_t res = 1;
    for (auto dim : dims) {
      res *= dim;
    }
    return res;
  }

  // Returns the dimension at `index`, or 0 when the index is out of range.
  // Made const: it does not modify the descriptor (const-correctness fix).
  int64_t GetDim(uint32_t index) const {
    if (index >= dims.size()) {
      return 0;
    }
    return dims[index];
  }

  std::vector<int64_t> dims;  // tensor shape
  uint32_t datatype;          // data-type code (numeric id; meaning defined by the runtime)
  uint32_t format;            // layout/format code (numeric id)
  uint32_t real_dim_cnt;      // number of valid dims (may differ from dims.size())
  uint32_t size;              // tensor size in bytes
  bool is_output;             // true when this tensor is a model output
};
// Aggregate description of a single operator instance inside a model.
// Filled in by the model loader; consumed by the runtime when issuing tasks.
struct OpInfo {
  uint32_t index;                          // operator index within the model
  std::string name;                        // operator instance name
  std::string type;                        // operator type (op kind)
  bool var_is_broadcast;                   // whether the variable is broadcast (semantics defined by producer)
  std::vector<uintptr_t> input_addrs;      // device addresses of inputs
  std::vector<uintptr_t> output_addrs;     // device addresses of outputs
  std::vector<TensorInfo> input_tensors;   // descriptors for each input
  std::vector<TensorInfo> output_tensors;  // descriptors for each output
  std::vector<TensorInfo> weight_tensors;  // descriptors for each weight
  std::vector<std::string> src_name;       // names of source (upstream) ops
  std::vector<int64_t> src_index;          // matching output indices on the source ops
  std::string weight_data;                 // raw weight bytes (string used as a byte buffer)
};
using TensorInfoPtr = std::shared_ptr<TensorInfo>; | |||||
using OpInfoPtr = std::shared_ptr<OpInfo>; | |||||
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_GE_RUNTIME_OP_INFO_H_ |
@@ -0,0 +1,394 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||||
#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ | |||||
#include <stdint.h> | |||||
#include <functional> | |||||
#include <memory> | |||||
#include <string> | |||||
#include <vector> | |||||
#include "cce/taskdown_api.h" | |||||
namespace ge { | |||||
namespace model_runner { | |||||
// Discriminator for the concrete TaskInfo subclass a task descriptor belongs to.
// Values up to STREAM_ACTIVE are assigned consecutively from 0.
enum TaskInfoType {
  CCE = 0,
  TBE,
  AICPU,
  LABEL_SET,
  LABEL_SWITCH,
  LABEL_GOTO,
  EVENT_RECORD,
  EVENT_WAIT,
  FUSION_START,
  FUSION_END,
  HCCL,
  PROFILER_TRACE,
  MEMCPY_ASYNC,
  STREAM_SWITCH,
  STREAM_ACTIVE,
  // Insert new task type here
  // NOTE(review): "REVSERVED" looks like a typo for "RESERVED"; the name is
  // part of the public interface, so it is kept as-is — confirm before renaming.
  REVSERVED = 23
};
class TaskInfo { | |||||
public: | |||||
virtual ~TaskInfo() {} | |||||
uint32_t stream_id() const { return stream_id_; } | |||||
TaskInfoType type() const { return type_; } | |||||
protected: | |||||
TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {} | |||||
private: | |||||
uint32_t stream_id_; | |||||
TaskInfoType type_; | |||||
}; | |||||
class CceTaskInfo : public TaskInfo { | |||||
public: | |||||
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim, | |||||
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, | |||||
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable) | |||||
: TaskInfo(stream_id, TaskInfoType::CCE), | |||||
ctx_(ctx), | |||||
stub_func_(stub_func), | |||||
block_dim_(block_dim), | |||||
args_(args), | |||||
args_size_(args_size), | |||||
sm_desc_(sm_desc), | |||||
flow_table_(flow_table), | |||||
args_offset_(args_offset), | |||||
is_flowtable_(is_flowtable) {} | |||||
~CceTaskInfo() override {} | |||||
cce::ccOpContext cc_context() const { return ctx_; } | |||||
std::string stub_func() const { return stub_func_; } | |||||
uint32_t block_dim() const { return block_dim_; } | |||||
const std::vector<uint8_t> &args() const { return args_; } | |||||
uint32_t args_size() const { return args_size_; } | |||||
const std::vector<uint8_t> &sm_desc() const { return sm_desc_; } | |||||
const std::vector<uint8_t> &flow_table() const { return flow_table_; } | |||||
const std::vector<uint8_t> &args_offset() const { return args_offset_; } | |||||
bool is_flowtable() const { return is_flowtable_; } | |||||
private: | |||||
cce::ccOpContext ctx_; | |||||
std::string stub_func_; | |||||
uint32_t block_dim_; | |||||
std::vector<uint8_t> args_; | |||||
uint32_t args_size_; | |||||
std::vector<uint8_t> sm_desc_; | |||||
std::vector<uint8_t> flow_table_; | |||||
std::vector<uint8_t> args_offset_; | |||||
bool is_flowtable_; | |||||
}; | |||||
// Task descriptor for a TBE kernel launch. Pure data holder: all state is set
// in the constructor; SetBinary() is the one mutator, letting the loader patch
// in the kernel binary after registration.
class TbeTaskInfo : public TaskInfo {
 public:
  // `binary`/`binary_size` point at the kernel binary; ownership is NOT taken
  // (stored as a raw pointer — caller keeps the buffer alive).
  TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
              uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
              const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
              const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
      : TaskInfo(stream_id, TaskInfoType::TBE),
        stub_func_(stub_func),
        block_dim_(block_dim),
        args_(args),
        args_size_(args_size),
        sm_desc_(sm_desc),
        binary_(binary),
        binary_size_(binary_size),
        meta_data_(meta_data),
        input_data_addrs_(input_data_addrs),
        output_data_addrs_(output_data_addrs),
        workspace_addrs_(workspace_addrs) {}
  ~TbeTaskInfo() override {}

  const std::string &stub_func() const { return stub_func_; }
  uint32_t block_dim() const { return block_dim_; }
  const std::vector<uint8_t> &args() const { return args_; }
  uint32_t args_size() const { return args_size_; }
  const std::vector<uint8_t> &sm_desc() const { return sm_desc_; }
  void *binary() const { return binary_; }
  uint32_t binary_size() const { return binary_size_; }
  const std::vector<uint8_t> &meta_data() const { return meta_data_; }
  const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; }
  const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; }
  const std::vector<void *> &workspace_addrs() const { return workspace_addrs_; }

  // Replace the kernel binary pointer/size (e.g. after the binary is loaded).
  void SetBinary(void *binary, uint32_t binary_size) {
    binary_ = binary;
    binary_size_ = binary_size;
  }

 private:
  std::string stub_func_;               // kernel stub symbol name
  uint32_t block_dim_;                  // launch block dimension
  std::vector<uint8_t> args_;           // raw kernel argument bytes
  uint32_t args_size_;                  // size of args_ actually used, in bytes
  std::vector<uint8_t> sm_desc_;        // SM descriptor blob
  void *binary_;                        // non-owning pointer to the kernel binary
  uint32_t binary_size_;                // kernel binary size in bytes
  std::vector<uint8_t> meta_data_;      // kernel metadata blob
  std::vector<void *> input_data_addrs_;   // device addresses of inputs
  std::vector<void *> output_data_addrs_;  // device addresses of outputs
  std::vector<void *> workspace_addrs_;    // device addresses of workspaces
};
class AicpuTaskInfo : public TaskInfo { | |||||
public: | |||||
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def, | |||||
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs) | |||||
: TaskInfo(stream_id, TaskInfoType::AICPU), | |||||
so_name_(so_name), | |||||
kernel_name_(kernel_name), | |||||
node_def_(node_def), | |||||
input_data_addrs_(input_data_addrs), | |||||
output_data_addrs_(output_data_addrs) {} | |||||
~AicpuTaskInfo() override {} | |||||
const std::string &so_name() const { return so_name_; } | |||||
const std::string &kernel_name() const { return kernel_name_; } | |||||
const std::string &node_def() const { return node_def_; } | |||||
const std::vector<void *> &input_data_addrs() const { return input_data_addrs_; } | |||||
const std::vector<void *> &output_data_addrs() const { return output_data_addrs_; } | |||||
private: | |||||
std::string so_name_; | |||||
std::string kernel_name_; | |||||
std::string node_def_; | |||||
std::vector<void *> input_data_addrs_; | |||||
std::vector<void *> output_data_addrs_; | |||||
}; | |||||
class LabelTaskInfo : public TaskInfo { | |||||
public: | |||||
uint32_t label_id() const { return label_id_; } | |||||
protected: | |||||
LabelTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t label_id) | |||||
: TaskInfo(stream_id, type), label_id_(label_id) {} | |||||
virtual ~LabelTaskInfo() override {} | |||||
uint32_t label_id_; | |||||
}; | |||||
// Task that defines (sets) a label point on the given stream.
class LabelSetTaskInfo : public LabelTaskInfo {
 public:
  LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SET, label_id) {}
  ~LabelSetTaskInfo() override {}
};
// Task that performs a conditional branch to the given label.
class LabelSwitchTaskInfo : public LabelTaskInfo {
 public:
  LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_id)
      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_SWITCH, label_id) {}
  ~LabelSwitchTaskInfo() override {}
};
// Task that performs an unconditional jump to the given label.
class LabelGotoTaskInfo : public LabelTaskInfo {
 public:
  LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
      : LabelTaskInfo(stream_id, TaskInfoType::LABEL_GOTO, label_id) {}
  ~LabelGotoTaskInfo() override {}
};
class EventTaskInfo : public TaskInfo { | |||||
public: | |||||
uint32_t event_id() const { return event_id_; } | |||||
protected: | |||||
EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id) | |||||
: TaskInfo(stream_id, type), event_id_(event_id) {} | |||||
virtual ~EventTaskInfo() override {} | |||||
uint32_t event_id_; | |||||
}; | |||||
// Task that records (signals) the given event on the stream.
class EventRecordTaskInfo : public EventTaskInfo {
 public:
  EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
  ~EventRecordTaskInfo() override {}
};
// Task that makes the stream wait until the given event is recorded.
class EventWaitTaskInfo : public EventTaskInfo {
 public:
  EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
      : EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
  ~EventWaitTaskInfo() override {}
};
// Marker task: beginning of a fused-op region on the stream. Carries no payload.
class FusionStartTaskInfo : public TaskInfo {
 public:
  explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
  ~FusionStartTaskInfo() override {}
};
// Marker task: end of a fused-op region on the stream. Carries no payload.
class FusionEndTaskInfo : public TaskInfo {
 public:
  explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
  ~FusionEndTaskInfo() override {}
};
class HcclTaskInfo : public TaskInfo { | |||||
public: | |||||
HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr, | |||||
void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, | |||||
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, | |||||
int64_t op_type, int64_t data_type, std::function<bool(void *, void *)> hcom_bind_model, | |||||
std::function<bool(void *)> hcom_unbind_model, | |||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task) | |||||
: TaskInfo(stream_id, TaskInfoType::HCCL), | |||||
hccl_type_(hccl_type), | |||||
input_data_addr_(input_data_addr), | |||||
output_data_addr_(output_data_addr), | |||||
workspace_addr_(workspace_addr), | |||||
workspace_size_(workspace_size), | |||||
hccl_stream_num_(hccl_stream_num), | |||||
private_def_(private_def), | |||||
ops_kernel_store_(ops_kernel_store), | |||||
count_(count), | |||||
root_id_(root_id), | |||||
op_type_(op_type), | |||||
data_type_(data_type), | |||||
hcom_bind_model_(hcom_bind_model), | |||||
hcom_unbind_model_(hcom_unbind_model), | |||||
hcom_distribute_task_(hcom_distribute_task) {} | |||||
~HcclTaskInfo() override {} | |||||
const std::string &hccl_type() const { return hccl_type_; } | |||||
void *input_data_addr() const { return input_data_addr_; } | |||||
void *output_data_addr() const { return output_data_addr_; } | |||||
void *workspace_addr() const { return workspace_addr_; } | |||||
int64_t workspace_size() const { return workspace_size_; } | |||||
int64_t hccl_stream_num() const { return hccl_stream_num_; } | |||||
const std::vector<uint8_t> &private_def() const { return private_def_; } | |||||
void *ops_kernel_store() const { return ops_kernel_store_; } | |||||
int32_t count() const { return count_; } | |||||
int64_t root_id() const { return root_id_; } | |||||
int64_t op_type() const { return op_type_; } | |||||
int64_t data_type() const { return data_type_; } | |||||
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; } | |||||
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; } | |||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const { | |||||
return hcom_distribute_task_; | |||||
} | |||||
private: | |||||
std::string hccl_type_; | |||||
void *input_data_addr_; | |||||
void *output_data_addr_; | |||||
void *workspace_addr_; | |||||
int64_t workspace_size_; | |||||
int64_t hccl_stream_num_; | |||||
std::vector<uint8_t> private_def_; | |||||
void *ops_kernel_store_; | |||||
int32_t count_; | |||||
int64_t root_id_; | |||||
int64_t op_type_; | |||||
int64_t data_type_; | |||||
std::function<bool(void *, void *)> hcom_bind_model_; | |||||
std::function<bool(void *)> hcom_unbind_model_; | |||||
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_; | |||||
}; | |||||
// Task that emits a profiler trace point with the given log id.
class ProfilerTraceTaskInfo : public TaskInfo {
 public:
  // NOTE(review): `flat` is presumably a typo for `flag`; the name is part of
  // the interface and is kept as-is — confirm with callers before renaming.
  ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
      : TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
  ~ProfilerTraceTaskInfo() override {}

  uint64_t log_id() const { return log_id_; }
  bool notify() const { return notify_; }
  uint32_t flat() const { return flat_; }

 private:
  uint64_t log_id_;  // profiler log identifier
  bool notify_;      // whether to notify on trace
  uint32_t flat_;    // pass-through value (see note on the constructor)
};
class MemcpyAsyncTaskInfo : public TaskInfo { | |||||
public: | |||||
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind) | |||||
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC), | |||||
dst_(dst), | |||||
dst_max_(dst_max), | |||||
src_(src), | |||||
count_(count), | |||||
kind_(kind) {} | |||||
~MemcpyAsyncTaskInfo() override {} | |||||
void *dst() const { return dst_; } | |||||
uint64_t dst_max() const { return dst_max_; } | |||||
void *src() const { return src_; } | |||||
uint64_t count() const { return count_; } | |||||
uint32_t kind() const { return kind_; } | |||||
private: | |||||
void *dst_; | |||||
uint64_t dst_max_; | |||||
void *src_; | |||||
uint64_t count_; | |||||
int32_t kind_; | |||||
}; | |||||
// Task that conditionally switches execution to another stream: compares the
// value at input_addr against the value at value_addr using condition `cond`
// (condition/data-type codes interpreted by the runtime).
class StreamSwitchTaskInfo : public TaskInfo {
 public:
  StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
                       int64_t data_type)
      : TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
        true_stream_id_(true_stream_id),
        input_addr_(input_addr),
        value_addr_(value_addr),
        cond_(cond),
        data_type_(data_type) {}
  ~StreamSwitchTaskInfo() override {}

  int64_t true_stream_id() const { return true_stream_id_; }
  void *input_addr() const { return input_addr_; }
  void *value_addr() const { return value_addr_; }
  int64_t cond() const { return cond_; }
  int64_t data_type() const { return data_type_; }

 private:
  int64_t true_stream_id_;  // stream taken when the condition holds
  void *input_addr_;        // device address of the comparison input
  void *value_addr_;        // device address of the comparison reference value
  int64_t cond_;            // comparison condition code
  int64_t data_type_;       // data type code of the compared values
};
// Task that activates another stream from the current one.
class StreamActiveTaskInfo : public TaskInfo {
 public:
  StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
      : TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
  ~StreamActiveTaskInfo() override {}

  uint32_t active_stream_id() const { return active_stream_id_; }

 private:
  uint32_t active_stream_id_;  // id of the stream to activate
};
} // namespace model_runner | |||||
} // namespace ge | |||||
#endif // INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ |
@@ -23,6 +23,7 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "ge/ge_ir_build.h" | #include "ge/ge_ir_build.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "common/ge_types.h" | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
@@ -30,9 +31,13 @@ | |||||
namespace ge { | namespace ge { | ||||
class GeGenerator { | class GeGenerator { | ||||
public: | public: | ||||
static GeGenerator &GetInstance() { | |||||
static GeGenerator Instance; | |||||
return Instance; | |||||
} | |||||
GeGenerator() = default; | GeGenerator() = default; | ||||
~GeGenerator() = default; | |||||
~GeGenerator() { (void)Finalize(); } | |||||
GeGenerator(const GeGenerator &) = delete; | GeGenerator(const GeGenerator &) = delete; | ||||
@@ -60,10 +65,25 @@ class GeGenerator { | |||||
/// | /// | ||||
Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs, | Status BuildSingleOpModel(OpDescPtr &op_desc, const std::vector<GeTensor> &inputs, | ||||
const std::vector<GeTensor> &outputs, const std::string &model_file_name); | const std::vector<GeTensor> &outputs, const std::string &model_file_name); | ||||
/// | |||||
/// @ingroup ge | |||||
/// @brief: Build single Op into model buff. | |||||
/// @param [in] op_desc: the OP description. | |||||
/// @param [in] inputs: input tensors. | |||||
/// @param [in] outputs: output tensors. | |||||
/// @param [in] engine_type: specific engine. | |||||
/// @param [out] model_buff: model buff of single op. | |||||
/// @return SUCCESS or FAILED | |||||
Status BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | |||||
OpEngineType engine_type, ModelBufferData &model_buff); | |||||
private: | private: | ||||
Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | Status GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | ||||
ge::ModelBufferData &model, bool is_offline = true); | ge::ModelBufferData &model, bool is_offline = true); | ||||
Status BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | |||||
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | |||||
bool is_offline = true); | |||||
class Impl; | class Impl; | ||||
std::shared_ptr<Impl> impl_; | std::shared_ptr<Impl> impl_; | ||||
@@ -0,0 +1,113 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_FRAMEWORK_OMG_OMG_H_ | |||||
#define INC_FRAMEWORK_OMG_OMG_H_ | |||||
#include <google/protobuf/message.h> | |||||
#include <string> | |||||
#include <unordered_map> | |||||
#include <vector> | |||||
#include "framework/common/types.h" | |||||
#include "framework/omg/omg_inner_types.h" | |||||
#include "proto/ge_ir.pb.h" | |||||
#include "proto/om.pb.h" | |||||
#include "graph/compute_graph.h" | |||||
#include "graph/graph.h" | |||||
#include "graph/model.h" | |||||
#include "runtime/kernel.h" | |||||
using domi::Status; | |||||
using std::pair; | |||||
using std::string; | |||||
using std::unordered_map; | |||||
using std::vector; | |||||
namespace ge { | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief init omg context | |||||
* @return void | |||||
*/ | |||||
Status InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | |||||
bool is_dynamic_input); | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief generate graph based on the input model file and weight file | |||||
* @param [out] graph graph | |||||
* @param [in] model_file path of model file | |||||
* @param [in] weights_file path of weight file | |||||
* @param [in] type type of the input model | |||||
* @param [in] op_conf op mapping configuration | |||||
* @param [in] target type of platform. If a tiny model is generated, set target to tiny | |||||
 * @param [in] run_mode run mode (see the RunMode enum)
 * @param [in] enable_l2dynamic enable l2 dynamic (NOTE: documented but absent from the signature below — stale, confirm)
* @param [in] is_dynamic_input dynamic input, true of false | |||||
* @param [in] atc_params multiply atc params | |||||
* @return Status result code | |||||
*/ | |||||
Status ParseGraph(ge::Graph &graph, const std::map<string, string> &atc_params, const char *model_file, | |||||
const char *weights_file, domi::FrameworkType type, const char *op_conf = nullptr, | |||||
const char *target = nullptr, RunMode run_mode = GEN_OM_MODEL, bool is_dynamic_input = false); | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief generates a simplified JSON file based on the key value of the offline model file in protobuf format | |||||
* @param [in] model_file path of offline model file | |||||
* @param [out] json_file path of json file | |||||
 * @param [in] key encrypted key
* @return Status result code | |||||
*/ | |||||
Status ConvertOmModelToJson(const char *model_file, const char *json_file); | |||||
Status ConvertPbtxtToJson(const char *model_file, const char *json_file); | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief convert the model file in protobuf format into a JSON file. | |||||
* @param [in] framework type of model | |||||
* @param [in] om model_file path of offline model file | |||||
* @param [out] json_file path of json file | |||||
 * @param [in] key encrypted key
* @return Status result code | |||||
*/ | |||||
Status ConvertFwkModelToJson(domi::FrameworkType framework, const char *model_file, const char *json_file); | |||||
void GetGroupName(ge::proto::ModelDef &model); | |||||
void FindParserSo(const string &path, vector<string> &fileList, string &caffe_parser_path); | |||||
Status CheckCustomAiCpuOpLib(); | |||||
Status DumpInfershapeJson(const ge::Graph &graph, const char *json_file); | |||||
Status SetOutputNodeInfo(ge::Graph &graph, const std::string &output_type, const std::string &output_format); | |||||
Status GetOutputLeaf(ge::NodePtr node, std::vector<std::pair<ge::NodePtr, int32_t>> &output_nodes_info, | |||||
std::vector<std::string> &output_nodes_name); | |||||
} // namespace ge | |||||
namespace domi { | |||||
/** | |||||
* @ingroup domi_omg | |||||
* @brief get omg context | |||||
* @return reference of OmgContext | |||||
*/ | |||||
ge::OmgContext &GetContext(); | |||||
} // namespace domi | |||||
#endif // INC_FRAMEWORK_OMG_OMG_H_ |
@@ -44,11 +44,10 @@ namespace ge { | |||||
* @brief run model | * @brief run model | ||||
*/ | */ | ||||
enum RunMode { | enum RunMode { | ||||
GEN_OM_MODEL = 0, // generate offline model file | |||||
MODEL_TO_JSON = 1, // convert to JSON file | |||||
MODEL_TO_JSON_WITH_SHAPE = 2, // convert to json file with shape | |||||
ONLY_PRE_CHECK = 3, // only for pre-check | |||||
PBTXT_TO_JSON = 5 // pbtxt to json | |||||
GEN_OM_MODEL = 0, // generate offline model file | |||||
MODEL_TO_JSON = 1, // convert to JSON file | |||||
ONLY_PRE_CHECK = 3, // only for pre-check | |||||
PBTXT_TO_JSON = 5 // pbtxt to json | |||||
}; | }; | ||||
/// | /// | ||||
@@ -93,6 +92,8 @@ struct OmgContext { | |||||
std::map<std::string, std::vector<int32_t>> out_nodes_map; | std::map<std::string, std::vector<int32_t>> out_nodes_map; | ||||
// user-designate out nodes (this is used for determing the orders) | // user-designate out nodes (this is used for determing the orders) | ||||
std::vector<std::pair<std::string, int32_t>> user_out_nodes; | std::vector<std::pair<std::string, int32_t>> user_out_nodes; | ||||
// net out nodes (where user_out_nodes or leaf nodes) | |||||
std::vector<std::string> net_out_nodes; | |||||
// path for the aicpu custom operator so_file | // path for the aicpu custom operator so_file | ||||
std::vector<std::string> aicpu_op_run_paths; | std::vector<std::string> aicpu_op_run_paths; | ||||
// ddk version | // ddk version | ||||
@@ -235,6 +235,8 @@ class ComputeGraph : public std::enable_shared_from_this<ComputeGraph>, public A | |||||
std::vector<NodePtr> &stack); | std::vector<NodePtr> &stack); | ||||
graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num, | graphStatus BFSTopologicalSorting(std::vector<NodePtr> &node_vec, std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::deque<NodePtr> &stack); | std::deque<NodePtr> &stack); | ||||
graphStatus BFSTopologicalSortingWithGroup(std::vector<NodePtr> &node_vec, | |||||
std::map<NodePtr, uint32_t> &map_in_edge_num, std::deque<NodePtr> &stack); | |||||
graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | graphStatus CollectBreadthOutNode(const NodePtr &node, std::map<NodePtr, uint32_t> &map_in_edge_num, | ||||
std::map<string, NodePtr> &breadth_node_map); | std::map<string, NodePtr> &breadth_node_map); | ||||
graphStatus TopologicalSortingGraph(); | graphStatus TopologicalSortingGraph(); | ||||
@@ -94,6 +94,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_FORMAT; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STORAGE_SHAPE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_FILTER_FORMAT; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LRN_K; | ||||
@@ -133,6 +137,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string NEW_AIPP_CONV_OP; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SESSION_GRAPH_ID; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_PARENT_GRAPH_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE; | ||||
@@ -692,6 +697,8 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MOD | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_OUT_NODES_NAME; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_WEIGHT_SIZE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR; | ||||
@@ -920,6 +927,7 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_STREAM_SWITCH_COND; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_ACTIVE_STREAM_LIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCHN_PRED_VALUE; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG; | ||||
@@ -999,14 +1007,17 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SWITCH_FOR_L2_FUSION; | ||||
// functional ops attr | // functional ops attr | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_THEN_BRANCH; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_IF_ELSE_BRANCH; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_COND; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_WHILE_BODY; | ||||
// used for label switch | // used for label switch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_INDEX; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_LABEL_SWITCH_LIST; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_SUBGRAPH_END_NODE; | |||||
// Varible | |||||
// Variable | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_SRC_VAR_NAME; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string VAR_ATTR_SRC_VAR_NAME; | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string REF_VAR_PRE_PEER_OUT_INDEX; | ||||
@@ -1032,6 +1043,13 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAM | |||||
// Dynamic stitch | // Dynamic stitch | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string DYNAMIC_STITCH_ATTR_NAME_NUM; | ||||
// Used for support Horovod | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_INTER_EVENT_IDENTIFY; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE; | |||||
// for gradient group | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_GROUP; | |||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY extern const std::string ATTR_NAME_HCCL_FUSED_FLAG; | |||||
} // namespace ge | } // namespace ge | ||||
#endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ | #endif // INC_GRAPH_DEBUG_GE_ATTR_DEFINE_H_ |
@@ -264,6 +264,8 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); | graphStatus SetSubgraphInstanceName(uint32_t index, const std::string &name); | ||||
void RemoveSubgraphInstanceName(const std::string &name); | void RemoveSubgraphInstanceName(const std::string &name); | ||||
graphStatus GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const; | |||||
protected: | protected: | ||||
ProtoAttrMapHelper MutableAttrMap() override; | ProtoAttrMapHelper MutableAttrMap() override; | ||||
ConstProtoAttrMapHelper GetAttrMap() const override; | ConstProtoAttrMapHelper GetAttrMap() const override; | ||||
@@ -288,7 +290,7 @@ class OpDesc : public std::enable_shared_from_this<OpDesc>, public AttrHolder { | |||||
// subgraph ir names to type, for a `if` operator: | // subgraph ir names to type, for a `if` operator: | ||||
// then_branch: static | // then_branch: static | ||||
// else_branch: dynamic | |||||
// else_branch: static | |||||
// or for a `case` op: | // or for a `case` op: | ||||
// branches: dynamic | // branches: dynamic | ||||
std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | std::map<std::string, SubgraphType> subgraph_ir_names_to_type_; | ||||
@@ -0,0 +1,46 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ | |||||
#define INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ | |||||
#include <map> | |||||
#include <memory> | |||||
#include <mutex> | |||||
#include <vector> | |||||
#include "external/graph/ge_error_codes.h" | |||||
#include "external/graph/tensor.h" | |||||
namespace ge { | |||||
class GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY RuntimeInferenceContext { | |||||
public: | |||||
static graphStatus GetContext(const std::string &context_id, RuntimeInferenceContext **ctx); | |||||
static graphStatus CreateContext(const std::string &context_id); | |||||
static void DestroyContext(const std::string &context_id); | |||||
graphStatus SetTensor(int64_t node_id, int output_id, Tensor &&tensor); | |||||
graphStatus GetTensor(int64_t node_id, int output_id, Tensor &tensor); | |||||
private: | |||||
std::map<int64_t, std::vector<Tensor>> tensors_; | |||||
std::mutex mu_; | |||||
static std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> contexts_; | |||||
static std::mutex ctx_mu_; | |||||
}; | |||||
} // namespace ge | |||||
#endif // INC_GRAPH_RUNTIME_INFERENCE_CONTEXT_H_ |
@@ -29,6 +29,18 @@ | |||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/model.h" | #include "graph/model.h" | ||||
// Dump `compute_graph` and every one of its subgraphs, both in GE binary form
// (DumpGEGraph) and as ONNX (DumpGEGraphToOnnx). Subgraph files are suffixed
// "_sub_graph_<i>".
// NOTE(review): `i` is a function-local static per expansion site, so subgraph
// numbering continues across repeated dumps from the same call site, and the
// int8_t counter wraps after 127 subgraph dumps — confirm this is intended.
// (Comments must stay outside the macro: a `//` comment before the trailing
// backslash would swallow the continuation.)
#define GE_DUMP(compute_graph, name) \
do { \
GraphUtils::DumpGEGraph(compute_graph, name); \
GraphUtils::DumpGEGraphToOnnx(*compute_graph, name); \
for (const auto &sub_graph_func : compute_graph->GetAllSubgraphs()) { \
static int8_t i = 0; \
auto sub_graph_func_name = std::string(name) + std::string("_sub_graph_") + std::to_string(i++); \
GraphUtils::DumpGEGraph(sub_graph_func, sub_graph_func_name); \
GraphUtils::DumpGEGraphToOnnx(*sub_graph_func, sub_graph_func_name); \
} \
} while (0)
#define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ | #define REFER_ATTR_VALUE(VT_ENUM, DataType, attr, ret) \ | ||||
do { \ | do { \ | ||||
DataType ret; \ | DataType ret; \ | ||||
@@ -155,6 +167,8 @@ class GraphUtils { | |||||
static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, | static graphStatus InsertNodeBetweenDataAnchors(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, | ||||
const NodePtr &new_node); | const NodePtr &new_node); | ||||
static graphStatus RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node); | |||||
static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); | static graphStatus RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node); | ||||
static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | static graphStatus InsertTransNode(ComputeGraphPtr compute_graph, const InDataAnchorPtr &in_data_anchor, | ||||
@@ -299,6 +313,24 @@ class GraphUtils { | |||||
std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | std::map<std::string, std::vector<NodeIndexIO>> &symbol_to_anchors, | ||||
std::map<std::string, std::string> &anchor_to_symbol); | std::map<std::string, std::string> &anchor_to_symbol); | ||||
/// | |||||
/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs | |||||
/// of the graph have UNKNOWN_SHAPE operators or not. | |||||
/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following | |||||
/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE | |||||
/// ROOT graph: A -----> B -----> C | |||||
/// K subgraph U | |||||
/// | | |||||
/// V | |||||
/// SUB graph: D --> E --> F | |||||
/// K K K | |||||
/// @param [in] graph | |||||
/// @return bool | |||||
/// | |||||
static bool IsUnknownShapeGraph(const ComputeGraphPtr &graph); | |||||
static NodePtr FindNodeFromAllNodes(ComputeGraphPtr &graph, const std::string &name); | |||||
private: | private: | ||||
/// | /// | ||||
/// Get reference-mapping for in_data_anchors of node | /// Get reference-mapping for in_data_anchors of node | ||||
@@ -438,6 +470,11 @@ class ComputeGraphBuilder { | |||||
/// | /// | ||||
NodePtr GetNode(const std::string &name); | NodePtr GetNode(const std::string &name); | ||||
/// @brief Get all nodes | |||||
/// @return std::vector<NodePtr> | |||||
/// | |||||
std::vector<NodePtr> GetAllNodes(); | |||||
protected: | protected: | ||||
/// | /// | ||||
/// @brief Build nodes | /// @brief Build nodes | ||||
@@ -535,6 +572,13 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
/// | /// | ||||
CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | CompleteGraphBuilder &AddOutput(const std::string &owner_node_name, uint32_t anchor_ind); | ||||
/// | |||||
/// @brief Add target for graph | |||||
/// @param [in] target_name | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &AddTarget(const std::string &target_name); | |||||
/// | /// | ||||
/// @brief Set parent-node of graph | /// @brief Set parent-node of graph | ||||
/// @param [in] parent_node | /// @param [in] parent_node | ||||
@@ -590,10 +634,19 @@ class CompleteGraphBuilder : public ComputeGraphBuilder { | |||||
/// | /// | ||||
void AddRetValNodes(graphStatus &error_code, std::string &error_msg); | void AddRetValNodes(graphStatus &error_code, std::string &error_msg); | ||||
/// | |||||
/// @brief Build target-nodes for graph | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void BuildGraphTargets(graphStatus &error_code, std::string &error_msg); | |||||
std::string name_; | std::string name_; | ||||
NodePtr parent_node_; | NodePtr parent_node_; | ||||
std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | std::map<uint32_t, std::pair<std::vector<std::string>, std::vector<uint32_t>>> graph_inputs_; | ||||
std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | std::vector<std::pair<std::string, uint32_t>> graph_outputs_; | ||||
std::vector<std::string> graph_targets_; | |||||
// index_of_graph_input -> in_anchor_index_of_parent_node | // index_of_graph_input -> in_anchor_index_of_parent_node | ||||
std::map<uint32_t, uint32_t> input_mapping_; | std::map<uint32_t, uint32_t> input_mapping_; | ||||
@@ -17,10 +17,23 @@ | |||||
#ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ | #ifndef INC_GRAPH_UTILS_NODE_UTILS_H_ | ||||
#define INC_GRAPH_UTILS_NODE_UTILS_H_ | #define INC_GRAPH_UTILS_NODE_UTILS_H_ | ||||
#include <set> | |||||
#include <map> | #include <map> | ||||
#include <vector> | #include <vector> | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
namespace ge { | namespace ge { | ||||
// Op types of Const like Opps. | |||||
extern const std::set<std::string> kConstOpTypes; | |||||
// Op types of If like Opps. | |||||
extern const std::set<std::string> kIfOpTypes; | |||||
// Op types of While like Opps. | |||||
extern const std::set<std::string> kWhileOpTypes; | |||||
// Op types of Case like Opps. | |||||
extern const std::set<std::string> kCaseOpTypes; | |||||
// Op types of For like Opps. | |||||
extern const std::set<std::string> kForOpTypes; | |||||
class NodeUtils { | class NodeUtils { | ||||
public: | public: | ||||
static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); | static graphStatus AddSendEventId(const NodePtr &node, const uint32_t &event_id); | ||||
@@ -94,6 +107,13 @@ class NodeUtils { | |||||
/// | /// | ||||
static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); | static bool GetConstOpType(const NodePtr &in_node, std::string &op_type); | ||||
/// | |||||
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. | |||||
/// @param [in] node | |||||
/// @return return GRAPH_SUCCESS if remove successfully, other for failed. | |||||
/// | |||||
static graphStatus RemoveSubgraphsOnNode(const NodePtr &node); | |||||
private: | private: | ||||
static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_send_info_; | ||||
static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | static std::map<NodePtr, std::vector<uint32_t>> map_recv_info_; | ||||
@@ -24,6 +24,7 @@ | |||||
#include "graph/ge_error_codes.h" | #include "graph/ge_error_codes.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/usr_types.h" | #include "graph/usr_types.h" | ||||
#include "register/register_types.h" | |||||
namespace ge { | namespace ge { | ||||
class TypeUtils { | class TypeUtils { | ||||
@@ -37,6 +38,7 @@ class TypeUtils { | |||||
static std::string FormatToSerialString(Format format); | static std::string FormatToSerialString(Format format); | ||||
static Format SerialStringToFormat(const std::string &str); | static Format SerialStringToFormat(const std::string &str); | ||||
static Format DataFormatToFormat(const std::string &str); | static Format DataFormatToFormat(const std::string &str); | ||||
static Format DomiFormatToFormat(domi::domiTensorFormat_t domi_format); | |||||
static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); | static graphStatus Usr2DefQuantizeFactorParams(const UsrQuantizeFactorParams &usr, QuantizeFactorParams &def); | ||||
static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); | static graphStatus Def2UsrQuantizeFactorParams(const QuantizeFactorParams &def, UsrQuantizeFactorParams &usr); | ||||
@@ -36,6 +36,75 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const size_t OUTPUT_PARAM_SIZE = 2; | const size_t OUTPUT_PARAM_SIZE = 2; | ||||
bool IsUseBFS() { | |||||
string run_mode; | |||||
const int base = 10; | |||||
if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { | |||||
if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { | |||||
return true; | |||||
} | |||||
} else { | |||||
GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||||
} | |||||
return false; | |||||
} | |||||
bool IsTailingOptimization() { | |||||
string is_tailing_optimization_option; | |||||
auto ret = GetContext().GetOption(ge::OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION, is_tailing_optimization_option); | |||||
if (ret == GRAPH_SUCCESS) { | |||||
GELOGI("Option ge.exec.isTailingOptimization is %s", is_tailing_optimization_option.c_str()); | |||||
// "1" means it's True from frontend option | |||||
return is_tailing_optimization_option == "1"; | |||||
} | |||||
GELOGW("OPTION_EXEC_ENABLE_TAILING_OPTIMIZATION not set, use BFSTopologicalSorting by default."); | |||||
return false; | |||||
} | |||||
// True when the node carries ATTR_NAME_HCCL_FUSED_FLAG set to true;
// false when the attribute is absent (GetBool leaves the default untouched).
bool IsFusedNode(const NodePtr &node) {
  bool fused = false;
  (void)AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_FLAG, fused);
  return fused;
}
// Return the node's ATTR_NAME_HCCL_FUSED_GROUP attribute, or an empty
// string when the node is not tagged with a fusion group.
string GetGroupId(const NodePtr &node) {
  string fused_group;
  (void)AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, fused_group);
  return fused_group;
}
bool IsGroupEnd(const NodePtr &node) { | |||||
if (GetGroupId(node).empty()) { | |||||
return false; | |||||
} | |||||
if (node->GetOutDataNodesSize() == 0) { | |||||
return true; | |||||
} | |||||
for (const auto &out_data_node : node->GetOutDataNodes()) { | |||||
if (IsFusedNode(out_data_node)) { | |||||
return true; | |||||
} | |||||
} | |||||
return false; | |||||
} | |||||
// Distribute freshly-ready nodes (breadth_node_map) onto the three work piles:
//  - nodes whose fused-group id matches current_group_id (or any group id when
//    no group is open yet) go to group_stack,
//  - nodes tagged with a *different* group id are deferred to the front of
//    stack_input,
//  - untagged nodes go to the ordinary stack.
// NOTE(review): current_group_id is taken by value, so the adoption
// `current_group_id = group_id` below only affects later entries of this same
// call; the caller's group id is unchanged — confirm this is intended.
void SplitNodeToStack(const std::map<string, NodePtr> &breadth_node_map, string current_group_id,
                      std::vector<NodePtr> &stack_input, std::deque<NodePtr> &group_stack, std::deque<NodePtr> &stack) {
  for (const auto &name_node : breadth_node_map) {
    // group first
    string group_id;
    if (AttrUtils::GetStr(name_node.second->GetOpDesc(), ATTR_NAME_HCCL_FUSED_GROUP, group_id)) {
      GELOGI("current node %s, group id: %s , current group id %s", name_node.second->GetName().c_str(),
             group_id.c_str(), current_group_id.c_str());
      if (!current_group_id.empty() && group_id != current_group_id) {
        // Belongs to another group: requeue until the open group is finished.
        GELOGI("node go to input_stack back: %s", name_node.second->GetName().c_str());
        (void)stack_input.insert(stack_input.begin(), name_node.second);
      } else {
        current_group_id = group_id;
        GELOGI("node go to group_stack: %s", name_node.second->GetName().c_str());
        (void)group_stack.push_front(name_node.second);
      }
      continue;
    }
    GELOGI("node go to stack: %s ", name_node.second->GetName().c_str());
    (void)stack.push_front(name_node.second);
  }
}
} // namespace | } // namespace | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY ComputeGraph::ComputeGraph(const std::string &name) | ||||
@@ -546,24 +615,21 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY void ComputeGraph::SetParentNode( | |||||
/// | /// | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | ComputeGraph::UpdateInputMapping(const std::map<uint32_t, uint32_t> &input_mapping) { | ||||
size_t update_num = 0; | |||||
for (auto &input : nodes_) { | for (auto &input : nodes_) { | ||||
if (update_num >= input_mapping.size()) { | |||||
break; | |||||
} | |||||
uint32_t cur_index = 0; | |||||
if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||||
continue; | |||||
} | |||||
auto iter = input_mapping.find(cur_index); | |||||
if (iter == input_mapping.end()) { | |||||
continue; | |||||
} | |||||
if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||||
GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||||
return GRAPH_FAILED; | |||||
if (input->GetType() == DATA) { | |||||
uint32_t cur_index = 0; | |||||
if (!ge::AttrUtils::GetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, cur_index)) { | |||||
continue; | |||||
} | |||||
auto iter = input_mapping.find(cur_index); | |||||
if (iter == input_mapping.end()) { | |||||
continue; | |||||
} | |||||
if (!ge::AttrUtils::SetInt(input->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX, iter->second)) { | |||||
GE_LOGE("UpdateInputMapping failed: set attr ATTR_NAME_PARENT_NODE_INDEX failed."); | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
update_num++; | |||||
} | } | ||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
@@ -719,10 +785,10 @@ graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
node = stack_input.back(); | node = stack_input.back(); | ||||
stack_input.pop_back(); | stack_input.pop_back(); | ||||
} | } | ||||
node_vec.push_back(node); | node_vec.push_back(node); | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); | GELOGD("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str()); | ||||
CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); | CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map); | ||||
for (const auto &name_node : breadth_node_map) { | for (const auto &name_node : breadth_node_map) { | ||||
@@ -730,7 +796,65 @@ graphStatus ComputeGraph::BFSTopologicalSorting(std::vector<NodePtr> &node_vec, | |||||
} | } | ||||
breadth_node_map.clear(); | breadth_node_map.clear(); | ||||
} | } | ||||
return GRAPH_SUCCESS; | |||||
} | |||||
// Breadth-first topological sort that keeps HCCL fusion groups contiguous in
// the emitted order: while a group is "open" (current_group_id non-empty),
// ready nodes belonging to other groups are deferred to the front of
// stack_input instead of being emitted.
// @param [out]    node_vec        nodes in sorted order
// @param [in,out] map_in_edge_num remaining in-edge count per node
// @param [in,out] stack           ready nodes outside any group
// @return GRAPH_SUCCESS, or GRAPH_FAILED when the initial SortNodes fails
graphStatus ComputeGraph::BFSTopologicalSortingWithGroup(std::vector<NodePtr> &node_vec,
                                                         std::map<NodePtr, uint32_t> &map_in_edge_num,
                                                         std::deque<NodePtr> &stack) {
  GELOGI("Runing_Bfs_Sort_With_Group");
  std::string current_group_id;
  std::vector<NodePtr> stack_input;
  std::deque<NodePtr> group_stack;
  // NOTE(review): fused_node_stack appears unused in this function —
  // candidate for removal.
  std::deque<NodePtr> fused_node_stack;
  std::map<string, NodePtr> breadth_node_map;
  // Record the number of non data nodes but no input nodes
  GE_CHK_BOOL_EXEC(SortNodes(stack_input, map_in_edge_num) == GRAPH_SUCCESS, return GRAPH_FAILED, "sort nodes failed");
  // Only data nodes here
  while (!stack_input.empty() || !stack.empty() || !group_stack.empty()) {
    NodePtr node = nullptr;
    if (!group_stack.empty()) {
      // Traversal node in group has priority
      node = group_stack.back();
      group_stack.pop_back();
    } else if (!stack.empty()) {
      node = stack.back();
      stack.pop_back();
    } else {
      node = stack_input.back();
      stack_input.pop_back();
    }
    // A fused node opens a group named after itself when none is open yet.
    if (IsFusedNode(node) && current_group_id.empty()) {
      current_group_id = node->GetName();
    }
    // Emit the node if it is group-less or belongs to the open group.
    if (GetGroupId(node).empty() || GetGroupId(node) == current_group_id) {
      node_vec.push_back(node);
      GE_CHECK_NOTNULL(node->GetOpDesc());
      GELOGI("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
    } else {
      if (current_group_id.empty()) {
        // No group open: adopt this node's group and emit it.
        current_group_id = GetGroupId(node);
        node_vec.push_back(node);
        GE_CHECK_NOTNULL(node->GetOpDesc());
        GELOGI("node_vec.push_back %s", node->GetOpDesc()->GetName().c_str());
      } else {
        // Different group while one is open: requeue and retry later.
        GELOGI("current group id is %s ,node go to input_stack back: %s", current_group_id.c_str(),
               node->GetName().c_str());
        (void)stack_input.insert(stack_input.begin(), node);
        continue;
      }
    }
    CollectBreadthOutNode(node, map_in_edge_num, breadth_node_map);
    // NOTE(review): current_group_id is passed by value, so any group id
    // adopted inside SplitNodeToStack is not visible here — confirm intended.
    SplitNodeToStack(breadth_node_map, current_group_id, stack_input, group_stack, stack);
    breadth_node_map.clear();
    // check the end of group
    if (IsGroupEnd(node)) {
      GELOGI("Current node %s is end of group %s.", node->GetName().c_str(), current_group_id.c_str());
      current_group_id = "";
    }
  }
  return GRAPH_SUCCESS;
}
@@ -751,15 +875,14 @@ graphStatus ComputeGraph::CollectBreadthOutNode(const NodePtr &node, std::map<No | |||||
} | } | ||||
} | } | ||||
} | } | ||||
GE_IF_BOOL_EXEC( | |||||
node->GetOutControlAnchor() != nullptr, for (AnchorPtr peer_in_anchor | |||||
: node->GetOutControlAnchor()->GetPeerAnchors()) { | |||||
if (node->GetOutControlAnchor() != nullptr) { | |||||
for (AnchorPtr peer_in_anchor : node->GetOutControlAnchor()->GetPeerAnchors()) { | |||||
auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | auto iter = map_in_edge_num.find(peer_in_anchor->GetOwnerNode()); | ||||
if (iter != map_in_edge_num.end() && 0 == --iter->second) { | if (iter != map_in_edge_num.end() && 0 == --iter->second) { | ||||
(void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | (void)breadth_node_map.emplace(peer_in_anchor->GetOwnerNode()->GetName(), peer_in_anchor->GetOwnerNode()); | ||||
} | } | ||||
}) | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
@@ -796,21 +919,18 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus ComputeGraph::Topolog | |||||
graphStatus ComputeGraph::TopologicalSortingGraph() { | graphStatus ComputeGraph::TopologicalSortingGraph() { | ||||
std::vector<NodePtr> node_vec; | std::vector<NodePtr> node_vec; | ||||
std::map<NodePtr, uint32_t> map_in_edge_num; | std::map<NodePtr, uint32_t> map_in_edge_num; | ||||
bool use_BFS = false; | |||||
string run_mode; | |||||
const int base = 10; | |||||
if (ge::GetContext().GetOption(ge::OPTION_GRAPH_RUN_MODE, run_mode) == GRAPH_SUCCESS && !run_mode.empty()) { | |||||
if (GraphRunMode(std::strtol(run_mode.c_str(), nullptr, base)) >= TRAIN) { | |||||
use_BFS = true; | |||||
} | |||||
} else { | |||||
GELOGW("OPTION_GRAPH_RUN_MODE not set, use BFSTopologicalSorting by default."); | |||||
} | |||||
bool use_BFS = IsUseBFS(); | |||||
bool is_tailing_optimization = IsTailingOptimization(); | |||||
if (use_BFS) { | if (use_BFS) { | ||||
std::deque<NodePtr> stack; | std::deque<NodePtr> stack; | ||||
if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
if (is_tailing_optimization) { | |||||
if (BFSTopologicalSortingWithGroup(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
} else { | |||||
if (BFSTopologicalSorting(node_vec, map_in_edge_num, stack) != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
} else { | } else { | ||||
std::vector<NodePtr> stack; | std::vector<NodePtr> stack; | ||||
@@ -48,6 +48,12 @@ GE_REGISTER_OPTYPE(VARIABLEV2, "VariableV2"); | |||||
GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); | GE_REGISTER_OPTYPE(INPUT_TYPE, "Input"); | ||||
// Horovod operator | |||||
GE_REGISTER_OPTYPE(HVDCALLBACKALLREDUCE, "hvdCallbackAllreduce"); | |||||
GE_REGISTER_OPTYPE(HVDCALLBACKALLGATHER, "hvdCallbackAllgather"); | |||||
GE_REGISTER_OPTYPE(HVDCALLBACKBROADCAST, "hvdCallbackBroadcast"); | |||||
GE_REGISTER_OPTYPE(HVDWAIT, "hvdWait"); | |||||
GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); | GE_REGISTER_OPTYPE(NODE_NAME_NET_OUTPUT, "Node_Output"); | ||||
GE_REGISTER_OPTYPE(RECV, "Recv"); | GE_REGISTER_OPTYPE(RECV, "Recv"); | ||||
@@ -76,6 +76,10 @@ const std::string ATTR_NAME_ALGO = "algo"; | |||||
const std::string ATTR_NAME_FORMAT = "format"; | const std::string ATTR_NAME_FORMAT = "format"; | ||||
const std::string ATTR_NAME_STORAGE_FORMAT = "storage_format"; | |||||
const std::string ATTR_NAME_STORAGE_SHAPE = "storage_shape"; | |||||
const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; | const std::string ATTR_NAME_FILTER_FORMAT = "filter_format"; | ||||
const std::string ATTR_NAME_LRN_K = "lrn_k"; | const std::string ATTR_NAME_LRN_K = "lrn_k"; | ||||
@@ -115,6 +119,7 @@ const std::string ATTR_NAME_AIPP = "aipp"; | |||||
const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | const std::string NEW_AIPP_CONV_OP = "new_conv_op_for_aipp"; | ||||
const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | const std::string ATTR_NAME_SESSION_GRAPH_ID = "_session_graph_id"; | ||||
const std::string ATTR_NAME_PARENT_GRAPH_NAME = "_parent_graph_name"; | |||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | const std::string ATTR_NAME_MULTISHAPE_BATCHLIST = "multi_shape_batchlist"; | ||||
const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | const std::string ATTR_NAME_MULTISHAPE_BATCHLIST_SIZE = "multi_shape_batchlist_size"; | ||||
@@ -697,6 +702,8 @@ const std::string ATTR_MODEL_MEMORY_SIZE = "memory_size"; | |||||
const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | const std::string ATTR_MODEL_ZERO_COPY_MEMORY_SIZE = "zero_copy_memory_size"; | ||||
const std::string ATTR_MODEL_OUT_NODES_NAME = "attr_model_out_nodes_name"; | |||||
const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | const std::string ATTR_MODEL_WEIGHT_SIZE = "weight_size"; | ||||
const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | const std::string ATTR_MODEL_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | ||||
@@ -895,6 +902,7 @@ const std::string ATTR_NAME_ACTIVE_STREAM_LIST = "active_stream_list"; | |||||
const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | const std::string ATTR_NAME_SWITCHN_PRED_VALUE = "switch_pred_value"; | ||||
const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | const std::string ATTR_NAME_ITERATORS_PER_LOOP = "iterations_per_loop"; | ||||
const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | const std::string ATTR_NAME_FLOW_CTRL_NODE_FLAG = "is_flow_ctrl_node"; | ||||
const std::string ATTR_NAME_SUBGRAPH_FIRST_ACTIVE = "subgraph_first_active"; | |||||
const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | const std::string ATTR_NAME_SWITCH_BRANCH_NODE_LABEL = "_switch_branch_node_label"; | ||||
const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | const std::string ATTR_NAME_SWITCH_TRUE_BRANCH_FLAG = "_switch_true_branch_flag"; | ||||
@@ -973,12 +981,15 @@ const std::string ATTR_NAME_DATA_DUMP_ORIGIN_FORMAT = "_datadump_origin_format"; | |||||
const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | const std::string ATTR_NAME_DATA_DUMP_ORIGIN_DATA_TYPE = "_datadump_origin_data_type"; | ||||
// functional ops attr | // functional ops attr | ||||
const std::string ATTR_NAME_IF_THEN_BRANCH = "then_branch"; | |||||
const std::string ATTR_NAME_IF_ELSE_BRANCH = "else_branch"; | |||||
const std::string ATTR_NAME_WHILE_COND = "cond"; | const std::string ATTR_NAME_WHILE_COND = "cond"; | ||||
const std::string ATTR_NAME_WHILE_BODY = "body"; | const std::string ATTR_NAME_WHILE_BODY = "body"; | ||||
// used for label switch | // used for label switch | ||||
const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | const std::string ATTR_NAME_LABEL_SWITCH_INDEX = "_label_switch_index"; | ||||
const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | const std::string ATTR_NAME_LABEL_SWITCH_LIST = "_label_switch_list"; | ||||
const std::string ATTR_NAME_SUBGRAPH_END_NODE = "_subgraph_end_node"; | |||||
const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | const std::string ATTR_NAME_INPUT_DATATYPE = "input_datatype"; | ||||
const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | const std::string ATTR_NAME_OUTPUT_DATATYPE = "output_datatype"; | ||||
@@ -990,4 +1001,11 @@ const std::string ATTR_NAME_VALID_INPUT_SHAPE_LIST_LIST = "_valid_input_shape_li | |||||
const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; | const std::string ATTR_NAME_VALID_OUTPUT_SHAPE_LIST_LIST = "_valid_output_shape_list_list"; | ||||
const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | const std::string ATTR_NAME_SLICE_INPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | ||||
const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | const std::string ATTR_NAME_SLICE_OUTPUT_OFFSET_LIST_LIST = "_input_offset_list_list"; | ||||
// used for Horovod | |||||
const std::string ATTR_INTER_EVENT_IDENTIFY = "event_id"; | |||||
const std::string ATTR_HOROVOD_ATTR_REDUCE_TYPE = "reduce_op"; | |||||
// used for allreduce tailing optimization | |||||
const std::string ATTR_NAME_HCCL_FUSED_GROUP = "_hccl_fused_group"; | |||||
const std::string ATTR_NAME_HCCL_FUSED_FLAG = "_hccl_fused_node"; | |||||
} // namespace ge | } // namespace ge |
@@ -878,7 +878,7 @@ graphStatus OpDesc::CommonVerify() const { | |||||
// Checking shape of all inputs | // Checking shape of all inputs | ||||
vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | vector<int64_t> ishape = GetInputDescPtr(iname)->GetShape().GetDims(); | ||||
for (int64_t dim : ishape) { | for (int64_t dim : ishape) { | ||||
GE_CHK_BOOL_RET_STATUS(dim >= -1, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | |||||
GE_CHK_BOOL_RET_STATUS(dim >= -2, GRAPH_FAILED, "operator input %s shape contains negative or zero dimension.", | |||||
iname.c_str()); | iname.c_str()); | ||||
} | } | ||||
} | } | ||||
@@ -1310,4 +1310,25 @@ OpDesc::GetSubgraphTypeByIrName(const std::string &name) const { | |||||
} | } | ||||
return iter->second; | return iter->second; | ||||
} | } | ||||
// Map a subgraph *instance* name back to the subgraph (ir) name it was
// registered under: locate the instance's index in subgraph_instance_names_,
// then find the entry of subgraph_names_to_index_ holding that index.
// @param [in]  instance_name  instance name to look up
// @param [out] subgraph_name  receives the registered subgraph name on success
// @return GRAPH_SUCCESS on a match, GRAPH_PARAM_INVALID otherwise
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
OpDesc::GetSubgraphNameByInstanceName(const std::string &instance_name, std::string &subgraph_name) const {
  for (size_t idx = 0; idx < subgraph_instance_names_.size(); ++idx) {
    if (subgraph_instance_names_[idx] != instance_name) {  // find subgraph index.
      continue;
    }

    // Take each map entry by const reference: the original copied a
    // pair<string, size_t> (and its string) on every iteration.
    for (const auto &name_to_index : subgraph_names_to_index_) {
      if (name_to_index.second != idx) {  // find subgraph name.
        continue;
      }

      subgraph_name = name_to_index.first;
      return GRAPH_SUCCESS;
    }
  }

  return GRAPH_PARAM_INVALID;
}
} // namespace ge | } // namespace ge |
@@ -30,9 +30,11 @@ | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
#include "graph/ge_attr_value.h" | #include "graph/ge_attr_value.h" | ||||
#include "graph/ge_context.h" | |||||
#include "graph/ge_tensor.h" | #include "graph/ge_tensor.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "graph/runtime_inference_context.h" | |||||
#include "graph/usr_types.h" | #include "graph/usr_types.h" | ||||
#include "utils/graph_utils.h" | #include "utils/graph_utils.h" | ||||
#include "utils/op_desc_utils.h" | #include "utils/op_desc_utils.h" | ||||
@@ -349,48 +351,54 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
InferenceContextPtr GetInferenceContext() const { return inference_context_; } | InferenceContextPtr GetInferenceContext() const { return inference_context_; } | ||||
void SubgraphRegister(const std::string &name, bool dynamic) { | |||||
op_desc_->RegisterSubgraphIrName(name, dynamic ? kDynamic : kStatic); | |||||
void SubgraphRegister(const std::string &ir_name, bool dynamic) { | |||||
op_desc_->RegisterSubgraphIrName(ir_name, dynamic ? kDynamic : kStatic); | |||||
} | } | ||||
void SubgraphCountRegister(const std::string &name, uint32_t count) { | |||||
if (op_desc_->GetSubgraphTypeByIrName(name) == kStatic) { | |||||
op_desc_->AddSubgraphName(name); | |||||
void SubgraphCountRegister(const std::string &ir_name, uint32_t count) { | |||||
if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kStatic) { | |||||
op_desc_->AddSubgraphName(ir_name); | |||||
subgraph_names_to_builders_[ir_name] = nullptr; | |||||
} else { | } else { | ||||
for (uint32_t i = 0; i < count; ++i) { | for (uint32_t i = 0; i < count; ++i) { | ||||
op_desc_->AddSubgraphName(name + std::to_string(i)); | |||||
string key_name = ir_name + std::to_string(i); | |||||
op_desc_->AddSubgraphName(key_name); | |||||
subgraph_names_to_builders_[key_name] = nullptr; | |||||
} | } | ||||
} | } | ||||
subgraph_names_to_builders_[name].resize(count, nullptr); | |||||
} | } | ||||
void SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||||
auto iter = subgraph_names_to_builders_.find(name); | |||||
if (iter == subgraph_names_to_builders_.end()) { | |||||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||||
return; | |||||
void SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { | |||||
string key_name = ir_name; | |||||
if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { | |||||
key_name += std::to_string(index); | |||||
} | } | ||||
if (iter->second.size() <= index) { | |||||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u, excceds the max size %zu", | |||||
name.c_str(), index, iter->second.size()); | |||||
auto it = subgraph_names_to_builders_.find(key_name); | |||||
if (it == subgraph_names_to_builders_.end()) { | |||||
GELOGE(PARAM_INVALID, "Failed to set subgraph builder for name %s index %u.", ir_name.c_str(), index); | |||||
return; | return; | ||||
} | } | ||||
iter->second[index] = builder; | |||||
it->second = builder; | |||||
} | } | ||||
SubgraphBuilder GetSubgraphBuilder(const std::string &name, uint32_t index) const { | |||||
SubgraphBuilder GetSubgraphBuilder(const std::string &ir_name, uint32_t index) const { | |||||
string key_name = ir_name; | |||||
if (op_desc_->GetSubgraphTypeByIrName(ir_name) == kDynamic) { | |||||
key_name += std::to_string(index); | |||||
} | |||||
return GetSubgraphBuilder(key_name); | |||||
} | |||||
SubgraphBuilder GetSubgraphBuilder(const std::string &name) const { | |||||
auto iter = subgraph_names_to_builders_.find(name); | auto iter = subgraph_names_to_builders_.find(name); | ||||
if (iter == subgraph_names_to_builders_.end()) { | if (iter == subgraph_names_to_builders_.end()) { | ||||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, invalid name", name.c_str(), index); | |||||
return nullptr; | |||||
} | |||||
if (iter->second.size() <= index) { | |||||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s index %u, excceds the max size %zu", | |||||
name.c_str(), index, iter->second.size()); | |||||
GELOGE(PARAM_INVALID, "Failed to get subgraph builder for name %s", name.c_str()); | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return iter->second[index]; | |||||
return iter->second; | |||||
} | } | ||||
std::vector<std::string> GetSubgraphNames() const { | std::vector<std::string> GetSubgraphNames() const { | ||||
@@ -408,12 +416,11 @@ class OperatorImpl : public std::enable_shared_from_this<OperatorImpl> { | |||||
private: | private: | ||||
ge::ConstNodePtr node_{nullptr}; | ge::ConstNodePtr node_{nullptr}; | ||||
ge::InferenceContextPtr inference_context_; | ge::InferenceContextPtr inference_context_; | ||||
GraphBuilderCallback graph_builder_callback_; | |||||
std::map<string, std::vector<OpIO>> output_links_{}; | std::map<string, std::vector<OpIO>> output_links_{}; | ||||
std::map<string, OpIO> input_link_{}; | std::map<string, OpIO> input_link_{}; | ||||
std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{}; | std::vector<std::weak_ptr<OperatorImpl>> control_input_link_{}; | ||||
std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{}; | std::vector<std::weak_ptr<OperatorImpl>> control_output_link_{}; | ||||
std::map<std::string, std::vector<SubgraphBuilder>> subgraph_names_to_builders_; | |||||
std::map<std::string, SubgraphBuilder> subgraph_names_to_builders_; | |||||
}; | }; | ||||
// Used to manage OperatorImpl instances created by ge api. | // Used to manage OperatorImpl instances created by ge api. | ||||
@@ -582,6 +589,17 @@ graphStatus Operator::GetInputConstData(const string &dst_name, Tensor &data) co | |||||
return const_op.GetAttr(op::Const::name_attr_value(), data); | return const_op.GetAttr(op::Const::name_attr_value(), data); | ||||
} | } | ||||
} | } | ||||
// Try get from runtime inference context | |||||
auto session_id = std::to_string(GetContext().SessionId()); | |||||
RuntimeInferenceContext *runtime_infer_ctx = nullptr; | |||||
if (RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx) == GRAPH_SUCCESS) { | |||||
GELOGD("To get constant from runtime inference context. session_id = %s", session_id.c_str()); | |||||
auto ret = runtime_infer_ctx->GetTensor(peer_node_ptr->GetOpDesc()->GetId(), out_data_anchor->GetIdx(), data); | |||||
if (ret == GRAPH_SUCCESS) { | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} | |||||
} else { | } else { | ||||
// For outer graph | // For outer graph | ||||
return GetInputConstDataOut(dst_name, data); | return GetInputConstDataOut(dst_name, data); | ||||
@@ -1204,25 +1222,27 @@ void Operator::SubgraphCountRegister(const std::string &name, uint32_t count) { | |||||
operator_impl_->SubgraphCountRegister(name, count); | operator_impl_->SubgraphCountRegister(name, count); | ||||
} | } | ||||
void Operator::SetSubgraphBuilder(const std::string &name, uint32_t index, const SubgraphBuilder &builder) { | |||||
void Operator::SetSubgraphBuilder(const std::string &ir_name, uint32_t index, const SubgraphBuilder &builder) { | |||||
if (operator_impl_ == nullptr) { | if (operator_impl_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", name.c_str()); | |||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr, name %s.", ir_name.c_str()); | |||||
return; | return; | ||||
} | } | ||||
operator_impl_->SetSubgraphBuilder(name, index, builder); | |||||
operator_impl_->SetSubgraphBuilder(ir_name, index, builder); | |||||
} | } | ||||
std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } | std::vector<std::string> Operator::GetSubgraphNames() const { return operator_impl_->GetSubgraphNames(); } | ||||
SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &name, uint32_t index) const { | |||||
SubgraphBuilder Operator::GetDynamicSubgraphBuilder(const string &ir_name, uint32_t index) const { | |||||
if (operator_impl_ == nullptr) { | if (operator_impl_ == nullptr) { | ||||
GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | GELOGE(GRAPH_FAILED, "operator impl is nullptr."); | ||||
return nullptr; | return nullptr; | ||||
} | } | ||||
return operator_impl_->GetSubgraphBuilder(name, index); | |||||
return operator_impl_->GetSubgraphBuilder(ir_name, index); | |||||
} | } | ||||
SubgraphBuilder Operator::GetSubgraphBuilder(const string &name) const { return GetDynamicSubgraphBuilder(name, 0); } | |||||
SubgraphBuilder Operator::GetSubgraphBuilder(const string &ir_name) const { | |||||
return GetDynamicSubgraphBuilder(ir_name, 0); | |||||
} | |||||
Graph Operator::GetSubgraph(const string &name) const { | Graph Operator::GetSubgraph(const string &name) const { | ||||
if (operator_impl_ == nullptr) { | if (operator_impl_ == nullptr) { | ||||
@@ -1307,8 +1327,8 @@ class GraphBuilderImpl { | |||||
} | } | ||||
} | } | ||||
GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, | GE_CHK_BOOL_EXEC(!vec_inputs.empty(), return nullptr, | ||||
"User Input do not include operator such as \ | |||||
Data, Variable operator or operator that has output but no input."); | |||||
"User Input do not include operator such as " | |||||
"Data, Variable operator or operator that has output but no input."); | |||||
auto ret = WalkAllOperators(vec_inputs); | auto ret = WalkAllOperators(vec_inputs); | ||||
GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | GE_CHK_BOOL_EXEC(ret == GRAPH_SUCCESS, return nullptr, "WalkAllOperators failed."); | ||||
@@ -1361,8 +1381,67 @@ class GraphBuilderImpl { | |||||
vec_op_back_forward.push_back(in_link.lock()); | vec_op_back_forward.push_back(in_link.lock()); | ||||
} | } | ||||
que.push(vec_op_back_forward); | que.push(vec_op_back_forward); | ||||
if (WalkAllSubgraphs(node_ptr, op_impl) != GRAPH_SUCCESS) { | |||||
return GRAPH_FAILED; | |||||
} | |||||
} | } | ||||
} | } | ||||
return MoveSubgraphToRoot(graph_); | |||||
} | |||||
graphStatus WalkAllSubgraphs(const NodePtr &node, const OperatorImplPtr &op_impl) {
  // For every subgraph registered on the op desc: invoke the user-supplied
  // builder, attach the produced graph under the current node/graph, and
  // record the instance name back on the op desc.
  const string node_name = node->GetName();
  for (const auto &name_idx : op_impl->op_desc_->GetSubgraphNameIndexes()) {
    const SubgraphBuilder &builder = op_impl->GetSubgraphBuilder(name_idx.first);
    GE_CHK_BOOL_EXEC(builder != nullptr, return GRAPH_FAILED, "Node: %s, Get builder failed.", node_name.c_str());
    Graph built = builder();  // Build subgraph from user define builder.
    const ComputeGraphPtr &subgraph = GraphUtils::GetComputeGraph(built);
    GE_CHK_BOOL_EXEC(subgraph != nullptr, return GRAPH_FAILED, "Node: %s, Build graph failed.", node_name.c_str());
    // Wire the new subgraph into the hierarchy before registering it.
    subgraph->SetParentNode(node);
    subgraph->SetParentGraph(graph_);
    if (graph_->AddSubgraph(subgraph->GetName(), subgraph) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
    if (op_impl->op_desc_->SetSubgraphInstanceName(name_idx.second, subgraph->GetName()) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Failed to set subgraph %s index %u", subgraph->GetName().c_str(), name_idx.second);
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
graphStatus MoveSubgraphToRoot(const ComputeGraphPtr &graph) {
  // Recursively hoist every nested subgraph so that all subgraphs end up
  // registered directly on the root graph. When `graph` already is the root,
  // only the recursion is needed; otherwise each direct subgraph is first
  // re-registered on the root and removed from its current owner.
  const ComputeGraphPtr &root_graph = GraphUtils::FindRootGraph(graph);
  if (root_graph == nullptr) {
    GELOGE(GRAPH_FAILED, "Graph: %s, Find root graph failed.", graph->GetName().c_str());
    return GRAPH_FAILED;
  }
  const bool is_root = (root_graph == graph);
  auto direct_subgraphs = graph->GetAllSubgraphs();
  for (const auto &sub : direct_subgraphs) {
    if (!is_root) {
      if (root_graph->AddSubgraph(sub->GetName(), sub) != GRAPH_SUCCESS) {
        return GRAPH_FAILED;
      }
      graph->RemoveSubgraph(sub->GetName());
    }
    if (MoveSubgraphToRoot(sub) != GRAPH_SUCCESS) {
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
@@ -1423,11 +1502,22 @@ class GraphBuilderImpl { | |||||
}; | }; | ||||
inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { | inline bool HasSameNameNode(const ComputeGraphPtr &compute_graph) { | ||||
for (const auto &graph : compute_graph->GetAllSubgraphs()) { | |||||
std::set<string> node_names; | |||||
for (auto const &node : graph->GetDirectNode()) { | |||||
node_names.insert(node->GetName()); | |||||
} | |||||
if (node_names.size() != graph->GetDirectNodesSize()) { | |||||
return true; | |||||
} | |||||
} | |||||
std::set<string> node_names; | std::set<string> node_names; | ||||
for (auto const &node : compute_graph->GetAllNodes()) { | |||||
for (auto const &node : compute_graph->GetDirectNode()) { | |||||
node_names.insert(node->GetName()); | node_names.insert(node->GetName()); | ||||
} | } | ||||
return node_names.size() != compute_graph->GetAllNodes().size(); | |||||
return node_names.size() != compute_graph->GetDirectNodesSize(); | |||||
} | } | ||||
ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | ComputeGraphPtr GraphUtils::CreateGraphFromOperator(const string &name, const vector<ge::Operator> &inputs) { | ||||
@@ -136,17 +136,11 @@ graphStatus RefRelations::Impl::BuildRefRelationsForBranch( | |||||
out_ref_i_all_refs.emplace_back(cell_root); | out_ref_i_all_refs.emplace_back(cell_root); | ||||
for (const auto &ele : ref_o_net_nodes) { | for (const auto &ele : ref_o_net_nodes) { | ||||
RefCell cell_netoutput_in; | RefCell cell_netoutput_in; | ||||
RefCell cell_netoutput_out; | |||||
cell_netoutput_in.node_name = (ele.first)->GetName(); | cell_netoutput_in.node_name = (ele.first)->GetName(); | ||||
cell_netoutput_in.node = ele.first; | cell_netoutput_in.node = ele.first; | ||||
cell_netoutput_in.in_out = NODE_IN; | cell_netoutput_in.in_out = NODE_IN; | ||||
cell_netoutput_in.in_out_idx = ele.second; | cell_netoutput_in.in_out_idx = ele.second; | ||||
cell_netoutput_out.node_name = (ele.first)->GetName(); | |||||
cell_netoutput_out.node = ele.first; | |||||
cell_netoutput_out.in_out = NODE_OUT; | |||||
cell_netoutput_out.in_out_idx = ele.second; | |||||
out_ref_i_all_refs.emplace_back(cell_netoutput_in); | out_ref_i_all_refs.emplace_back(cell_netoutput_in); | ||||
out_ref_i_all_refs.emplace_back(cell_netoutput_out); | |||||
} | } | ||||
node_refs.emplace_back(out_ref_i_all_refs); | node_refs.emplace_back(out_ref_i_all_refs); | ||||
ref_o++; | ref_o++; | ||||
@@ -155,6 +149,7 @@ graphStatus RefRelations::Impl::BuildRefRelationsForBranch( | |||||
} | } | ||||
graphStatus RefRelations::Impl::BuildLookUpTables() { | graphStatus RefRelations::Impl::BuildLookUpTables() { | ||||
GELOGD("start to build look up table!"); | |||||
for (size_t i = 0; i < values_.size(); i++) { | for (size_t i = 0; i < values_.size(); i++) { | ||||
vector<vector<RefCell>> &val = values_[i]; | vector<vector<RefCell>> &val = values_[i]; | ||||
for (const auto &ele : val) { | for (const auto &ele : val) { | ||||
@@ -216,12 +211,7 @@ graphStatus RefRelations::Impl::BuildRefRelationsForWhile( | |||||
cell_netoutput_in.node = ele.first; | cell_netoutput_in.node = ele.first; | ||||
cell_netoutput_in.in_out = NODE_IN; | cell_netoutput_in.in_out = NODE_IN; | ||||
cell_netoutput_in.in_out_idx = ele.second; | cell_netoutput_in.in_out_idx = ele.second; | ||||
cell_netoutput_out.node_name = (ele.first)->GetName(); | |||||
cell_netoutput_out.node = ele.first; | |||||
cell_netoutput_out.in_out = NODE_OUT; | |||||
cell_netoutput_out.in_out_idx = ele.second; | |||||
ref_i_all_refs.emplace_back(cell_netoutput_in); | ref_i_all_refs.emplace_back(cell_netoutput_in); | ||||
ref_i_all_refs.emplace_back(cell_netoutput_out); | |||||
} | } | ||||
node_refs.emplace_back(ref_i_all_refs); | node_refs.emplace_back(ref_i_all_refs); | ||||
ref_i++; | ref_i++; | ||||
@@ -237,13 +227,10 @@ graphStatus RefRelations::Impl::BuildRelationsWithFuncNodeType( | |||||
auto node_type = root_node->GetType(); | auto node_type = root_node->GetType(); | ||||
auto status = GRAPH_SUCCESS; | auto status = GRAPH_SUCCESS; | ||||
if (node_type == kIf || node_type == kCase) { | |||||
if (node_type != kWhile) { | |||||
status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | status = BuildRefRelationsForBranch(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | ||||
} else if (node_type == kWhile) { | |||||
status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
} else { | } else { | ||||
GELOGE(GRAPH_PARAM_INVALID, "Node type [%s] is not supported for build ref relations!", node_type.c_str()); | |||||
status = GRAPH_PARAM_INVALID; | |||||
status = BuildRefRelationsForWhile(root_node, classed_data_nodes, classed_netoutput_nodes, node_refs); | |||||
} | } | ||||
return status; | return status; | ||||
} | } | ||||
@@ -291,6 +278,7 @@ graphStatus RefRelations::Impl::GetRootGraph(ge::ComputeGraph &graph, ge::Comput | |||||
graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, | graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_nodes, | ||||
vector<vector<NodePtr>> &classed_data_nodes) { | vector<vector<NodePtr>> &classed_data_nodes) { | ||||
GELOGD("start to process subgraph data nodes!"); | |||||
int max_ref_idx = 0; | int max_ref_idx = 0; | ||||
for (const auto &e : data_nodes) { | for (const auto &e : data_nodes) { | ||||
int i; | int i; | ||||
@@ -315,6 +303,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphDataNodes(vector<NodePtr> &data_n | |||||
graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | ||||
const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) { | const vector<NodePtr> &netoutput_nodes, vector<vector<std::pair<NodePtr, size_t>>> &classed_netoutput_nodes) { | ||||
GELOGD("[RefRelations]Start to process subgraph netoutput!"); | |||||
for (const auto &sub_netoutput_node : netoutput_nodes) { | for (const auto &sub_netoutput_node : netoutput_nodes) { | ||||
auto op_desc = sub_netoutput_node->GetOpDesc(); | auto op_desc = sub_netoutput_node->GetOpDesc(); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
@@ -340,6 +329,7 @@ graphStatus RefRelations::Impl::ProcessSubgraphNetoutput( | |||||
} | } | ||||
graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | ||||
GELOGD("Start to build ref relations!"); | |||||
/* First Step: Get root graph */ | /* First Step: Get root graph */ | ||||
ge::ComputeGraph &root_graph = graph; | ge::ComputeGraph &root_graph = graph; | ||||
auto status = GetRootGraph(graph, root_graph); | auto status = GetRootGraph(graph, root_graph); | ||||
@@ -349,12 +339,12 @@ graphStatus RefRelations::Impl::BuildRefRelations(ge::ComputeGraph &graph) { | |||||
for (const auto &node : graph.GetAllNodes()) { | for (const auto &node : graph.GetAllNodes()) { | ||||
auto node_type = node->GetType(); | auto node_type = node->GetType(); | ||||
if (function_op.find(node_type) == function_op.end()) { | |||||
continue; | |||||
} | |||||
std::vector<NodePtr> ref_nodes; | std::vector<NodePtr> ref_nodes; | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | auto sub_graph_names = op_desc->GetSubgraphInstanceNames(); | ||||
if (sub_graph_names.empty()) { | |||||
continue; | |||||
} | |||||
vector<NodePtr> data_nodes; | vector<NodePtr> data_nodes; | ||||
vector<NodePtr> netoutput_nodes; | vector<NodePtr> netoutput_nodes; | ||||
// Get data and netoutput of sub_graph | // Get data and netoutput of sub_graph | ||||
@@ -0,0 +1,95 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/runtime_inference_context.h" | |||||
#include <cstdint> | |||||
#include "framework/common/debug/ge_log.h" | |||||
namespace ge { | |||||
std::map<std::string, std::unique_ptr<RuntimeInferenceContext>> RuntimeInferenceContext::contexts_; | |||||
std::mutex RuntimeInferenceContext::ctx_mu_; | |||||
graphStatus RuntimeInferenceContext::CreateContext(const std::string &context_id) {
  // Creates and registers a per-session inference context keyed by context_id.
  // Fails if allocation fails or a context with this id already exists.
  GELOGI("To create context. session id = %s", context_id.c_str());
  auto ctx = std::unique_ptr<RuntimeInferenceContext>(new (std::nothrow) RuntimeInferenceContext());
  if (ctx == nullptr) {
    GELOGE(GRAPH_FAILED, "Failed to create instance of RuntimeInferenceContext. context_id = %s", context_id.c_str());
    return GRAPH_FAILED;
  }
  // contexts_ is shared with GetContext()/DestroyContext(), which both take
  // ctx_mu_; the insertion here must be guarded as well, otherwise concurrent
  // Create/Get/Destroy calls race on the map.
  std::lock_guard<std::mutex> lk(ctx_mu_);
  auto emplace_ret = contexts_.emplace(context_id, std::move(ctx));
  if (!emplace_ret.second) {
    GELOGE(GRAPH_FAILED, "Old context not destroyed");
    return GRAPH_FAILED;
  }
  return GRAPH_SUCCESS;
}
void RuntimeInferenceContext::DestroyContext(const std::string &context_id) { | |||||
GELOGI("To destroy context. session id = %s", context_id.c_str()); | |||||
std::lock_guard<std::mutex> lk(ctx_mu_); | |||||
contexts_.erase(context_id); | |||||
} | |||||
graphStatus RuntimeInferenceContext::GetContext(const std::string &context_id, RuntimeInferenceContext **ctx) {
  // Looks up the context registered under context_id; on success *ctx points
  // at the stored instance (ownership stays with contexts_).
  std::lock_guard<std::mutex> guard(ctx_mu_);
  const auto it = contexts_.find(context_id);
  if (it == contexts_.end()) {
    GELOGD("Runtime inference context not created. session id = %s", context_id.c_str());
    return GRAPH_FAILED;
  }
  *ctx = it->second.get();
  return GRAPH_SUCCESS;
}
graphStatus RuntimeInferenceContext::SetTensor(int64_t node_id, int output_id, Tensor &&tensor) {
  // Stores `tensor` as output `output_id` of node `node_id`, growing the
  // per-node output vector on demand.
  // Reject negative indices up front: the unsigned cast below would turn a
  // negative id into a huge value, and resize(output_id + 1) could request an
  // enormous allocation. GetTensor() performs the same validation.
  if (output_id < 0) {
    GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id);
    return GRAPH_PARAM_INVALID;
  }
  std::lock_guard<std::mutex> lk(mu_);
  auto &output_tensors = tensors_[node_id];
  if (static_cast<uint32_t>(output_id) >= output_tensors.size()) {
    output_tensors.resize(static_cast<size_t>(output_id) + 1);
  }
  GELOGD("Set tensor for node_id = %ld, output_id = %d", node_id, output_id);
  output_tensors[output_id] = std::move(tensor);
  return GRAPH_SUCCESS;
}
graphStatus RuntimeInferenceContext::GetTensor(int64_t node_id, int output_id, Tensor &tensor) {
  // Retrieves the tensor previously stored for (node_id, output_id).
  // Fails when the index is negative, the node is unknown, or the output
  // slot was never registered.
  if (output_id < 0) {
    GELOGE(GRAPH_PARAM_INVALID, "Invalid output index: %d", output_id);
    return GRAPH_PARAM_INVALID;
  }
  std::lock_guard<std::mutex> guard(mu_);
  const auto node_iter = tensors_.find(node_id);
  if (node_iter == tensors_.end()) {
    GELOGE(INTERNAL_ERROR, "Node not register. Id = %ld", node_id);
    return INTERNAL_ERROR;
  }
  const auto &outputs = node_iter->second;
  if (outputs.size() <= static_cast<uint32_t>(output_id)) {
    GELOGE(GRAPH_FAILED, "Node output is not registered. node_id = %ld, output index = %d", node_id, output_id);
    return GRAPH_FAILED;
  }
  GELOGD("Get tensor for node_id = %ld, output_id = %d", node_id, output_id);
  tensor = outputs[output_id];
  return GRAPH_SUCCESS;
}
} // namespace ge |
@@ -273,6 +273,9 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||||
auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType()); | auto data_type = TypeUtils::DataTypeToSerialString(input_desc->GetDataType()); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i), | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "input_desc_dtype:" + std::to_string(i), | ||||
&data_type); | &data_type); | ||||
auto data_type_origin = TypeUtils::DataTypeToSerialString(input_desc->GetOriginDataType()); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | |||||
"input_desc_origin_dtype:" + std::to_string(i), &data_type_origin); | |||||
auto dims = input_desc->GetShape().GetDims(); | auto dims = input_desc->GetShape().GetDims(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i), | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "input_desc_shape:" + std::to_string(i), | ||||
&dims); | &dims); | ||||
@@ -346,6 +349,9 @@ void OnnxUtils::AddAttrProtoForOpInAndOutDesc(onnx::NodeProto *node_proto, const | |||||
auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType()); | auto data_type = TypeUtils::DataTypeToSerialString(output_desc->GetDataType()); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i), | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, "output_desc_dtype:" + std::to_string(i), | ||||
&data_type); | &data_type); | ||||
auto origin_data_type = TypeUtils::DataTypeToSerialString(output_desc->GetOriginDataType()); | |||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_STRING, | |||||
"output_desc_origin_dtype:" + std::to_string(i), &origin_data_type); | |||||
auto dims = output_desc->GetShape().GetDims(); | auto dims = output_desc->GetShape().GetDims(); | ||||
AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i), | AddAttrProto(node_proto, onnx::AttributeProto_AttributeType_INTS, "output_desc_shape:" + std::to_string(i), | ||||
&dims); | &dims); | ||||
@@ -61,6 +61,7 @@ const char *const kDumpGraphLevel = "DUMP_GRAPH_LEVEL"; | |||||
const char *const kDumpStrBuild = "Build"; | const char *const kDumpStrBuild = "Build"; | ||||
const char *const kDumpStrPartition = "partition"; | const char *const kDumpStrPartition = "partition"; | ||||
const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; | const char *const kDumpStrOptimizeSubgraph = "OptimizeSubGraph"; | ||||
const char *const kDumpStrSubgraphFunc = "sub_graph"; | |||||
const char *const kDumpStrAicpu = "Aicpu"; | const char *const kDumpStrAicpu = "Aicpu"; | ||||
}; // namespace | }; // namespace | ||||
@@ -202,6 +203,58 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus GraphUtils::InsertNod | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus
GraphUtils::RemoveSubgraphRecursively(const ComputeGraphPtr &compute_graph, const NodePtr &remove_node) {
  // Removes every subgraph (transitively) attached to remove_node from the
  // root graph. The node itself is not removed here.
  GE_CHECK_NOTNULL(compute_graph);
  if (remove_node == nullptr) {
    GELOGE(GRAPH_FAILED, "The node ptr should not be null.");
    return GRAPH_FAILED;
  }
  // Check that this node belongs to this compute graph; may be a little slow.
  const auto &all_nodes_in_graph = compute_graph->GetDirectNode();
  if (std::find(all_nodes_in_graph.begin(), all_nodes_in_graph.end(), remove_node) == all_nodes_in_graph.end()) {
    GELOGE(GRAPH_FAILED, "Can not find node %s in graph %s.", remove_node->GetName().c_str(),
           compute_graph->GetName().c_str());
    return GRAPH_FAILED;
  }
  // Subgraph instances are registered on the root graph, so resolve them
  // there. Guard against a null root to avoid dereferencing it below.
  const auto &root_graph = GraphUtils::FindRootGraph(compute_graph);
  GE_CHECK_NOTNULL(root_graph);
  // Traverse remove_node and the nodes of every discovered subgraph,
  // collecting the subgraphs to delete. (The original also accumulated an
  // `all_nodes` vector that was never read; that dead work is dropped, along
  // with a redundant copy of remove_node.)
  std::vector<ComputeGraphPtr> subgraphs;
  std::deque<NodePtr> candidates;
  candidates.emplace_back(remove_node);
  while (!candidates.empty()) {
    const NodePtr node = candidates.front();
    candidates.pop_front();
    OpDescPtr op_desc = node->GetOpDesc();
    if (op_desc == nullptr) {
      continue;
    }
    const auto &subgraph_names = op_desc->GetSubgraphInstanceNames();
    for (auto name_iter = subgraph_names.rbegin(); name_iter != subgraph_names.rend(); ++name_iter) {
      auto subgraph = root_graph->GetSubgraph(*name_iter);
      if (subgraph != nullptr) {
        subgraphs.emplace_back(subgraph);
        candidates.insert(candidates.begin(), subgraph->nodes_.begin(), subgraph->nodes_.end());
      }
    }
  }
  // Remove all collected subgraphs from the root graph.
  for (const auto &remove_graph : subgraphs) {
    if (root_graph->RemoveSubGraph(remove_graph) != GRAPH_SUCCESS) {
      GELOGE(GRAPH_FAILED, "Remove subgraph failed, sub graph name is %s, compute graph is %s.",
             remove_node->GetName().c_str(), compute_graph->GetName().c_str());
      return GRAPH_FAILED;
    }
  }
  return GRAPH_SUCCESS;
}
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus | ||||
GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { | GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const NodePtr &node) { | ||||
GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
@@ -217,12 +270,10 @@ GraphUtils::RemoveNodeWithoutRelink(const ComputeGraphPtr &compute_graph, const | |||||
(void)compute_graph->RemoveOutputNode(node); | (void)compute_graph->RemoveOutputNode(node); | ||||
// If the node has sub-graphs, delete them | // If the node has sub-graphs, delete them | ||||
auto sub_graph_names = node->GetOpDesc()->GetSubgraphInstanceNames(); | |||||
if (!sub_graph_names.empty()) { | |||||
auto root_graph = FindRootGraph(compute_graph); | |||||
for (const auto &name : sub_graph_names) { | |||||
root_graph->RemoveSubgraph(name); | |||||
} | |||||
auto ret = RemoveSubgraphRecursively(compute_graph, node); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGE(GRAPH_FAILED, "Remove subgraph recursively failed."); | |||||
return GRAPH_FAILED; | |||||
} | } | ||||
auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); | auto iter = find(compute_graph->nodes_.begin(), compute_graph->nodes_.end(), node); | ||||
@@ -484,9 +535,10 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY bool GraphUtils::MatchDumpStr(con | |||||
return false; | return false; | ||||
} | } | ||||
if (dump_graph_level == kDumpLevel2 && ((suffix.find(kDumpStrPartition) != std::string::npos) || | |||||
(suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || | |||||
(suffix.find(kDumpStrAicpu) != std::string::npos))) { | |||||
if (dump_graph_level == kDumpLevel2 && | |||||
((suffix.find(kDumpStrPartition) != std::string::npos) || | |||||
(suffix.find(kDumpStrOptimizeSubgraph) != std::string::npos) || | |||||
(suffix.find(kDumpStrAicpu) != std::string::npos) || (suffix.find(kDumpStrSubgraphFunc) != std::string::npos))) { | |||||
return true; | return true; | ||||
} | } | ||||
@@ -1026,9 +1078,9 @@ graphStatus ReplaceControlAnchors(const NodePtr &new_node, const NodePtr &old_no | |||||
GE_CHECK_NOTNULL(old_out_control_anchor); | GE_CHECK_NOTNULL(old_out_control_anchor); | ||||
auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); | auto peer_in_anchors = old_out_control_anchor->GetPeerAnchors(); | ||||
auto new_out_control_anchor = new_node->GetOutControlAnchor(); | auto new_out_control_anchor = new_node->GetOutControlAnchor(); | ||||
GE_CHECK_NOTNULL(new_out_control_anchor); | |||||
auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); | auto exists_in_anchors = new_out_control_anchor->GetPeerAnchors(); | ||||
auto exists_in_anchors_set = std::set<AnchorPtr>(exists_in_anchors.begin(), exists_in_anchors.end()); | auto exists_in_anchors_set = std::set<AnchorPtr>(exists_in_anchors.begin(), exists_in_anchors.end()); | ||||
GE_CHECK_NOTNULL(new_out_control_anchor); | |||||
for (const auto &peer_in_anchor : peer_in_anchors) { | for (const auto &peer_in_anchor : peer_in_anchors) { | ||||
if (peer_in_anchor != nullptr) { | if (peer_in_anchor != nullptr) { | ||||
if (exists_in_anchors_set.count(peer_in_anchor) > 0) { | if (exists_in_anchors_set.count(peer_in_anchor) > 0) { | ||||
@@ -1304,6 +1356,26 @@ graphStatus GraphUtils::GetRefMapping(const ComputeGraphPtr &graph, | |||||
return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
} | } | ||||
GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY NodePtr GraphUtils::FindNodeFromAllNodes(ComputeGraphPtr &graph, | |||||
const std::string &name) { | |||||
auto root_graph = FindRootGraph(graph); | |||||
if (root_graph == nullptr) { | |||||
GE_LOGE("Failed find node %s, null root graph", name.c_str()); | |||||
return nullptr; | |||||
} | |||||
for (const auto &node : root_graph->GetAllNodes()) { | |||||
if (node == nullptr) { | |||||
continue; | |||||
} | |||||
if (node->GetName() == name) { | |||||
return node; | |||||
} | |||||
} | |||||
return nullptr; | |||||
} | |||||
/// | /// | ||||
/// Get reference-mapping for in_data_anchors of node | /// Get reference-mapping for in_data_anchors of node | ||||
/// @param [in] node | /// @param [in] node | ||||
@@ -1668,7 +1740,7 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||||
for (const auto &input_name : op_desc->GetAllInputNames()) { | for (const auto &input_name : op_desc->GetAllInputNames()) { | ||||
if (!input_name.empty() && (output_name == input_name)) { | if (!input_name.empty() && (output_name == input_name)) { | ||||
reuse_in_index = op_desc->GetInputIndexByName(input_name); | reuse_in_index = op_desc->GetInputIndexByName(input_name); | ||||
GELOGI("Reference name[%s] output[%s][%u] ref to input[%s][%d].", op_desc->GetName().c_str(), | |||||
GELOGI("Reference name[%s] output[%s][%d] ref to input[%s][%d].", op_desc->GetName().c_str(), | |||||
output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); | output_name.c_str(), output_index, input_name.c_str(), reuse_in_index); | ||||
return true; | return true; | ||||
} | } | ||||
@@ -1693,6 +1765,43 @@ bool GraphUtils::IsRefFromInput(const OutDataAnchorPtr &out_data_anchor, int32_t | |||||
return false; | return false; | ||||
} | } | ||||
/// | |||||
/// Determine if the graph is a UNKNOWN_SHAPE graph based on whether the graph and all subgraphs | |||||
/// of the graph have UNKNOWN_SHAPE operators or not. | |||||
/// Note: This function will only look 'down' from the graph, not 'up'. For example, the following | |||||
/// scenario (K for known shape, U for unknown shape), ROOT graph is UNKNOWN_SHAPE while SUB graph is KNOWN_SHAPE | |||||
/// ROOT graph: A -----> B -----> C | |||||
/// K subgraph U | |||||
/// | | |||||
/// V | |||||
/// SUB graph: D --> E --> F | |||||
/// K K K | |||||
/// @param [in] graph | |||||
/// @return bool | |||||
/// | |||||
bool GraphUtils::IsUnknownShapeGraph(const ComputeGraphPtr &graph) { | |||||
if (graph == nullptr) { | |||||
GELOGW("Input graph is nullptr."); | |||||
return false; | |||||
} | |||||
for (const auto &node : graph->GetDirectNode()) { | |||||
bool is_unknown = false; | |||||
auto ret = NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown); | |||||
if (ret != GRAPH_SUCCESS) { | |||||
GELOGW("Get node unknown status failed, node name:%s, type:%s.", node->GetName().c_str(), | |||||
node->GetType().c_str()); | |||||
continue; | |||||
} | |||||
if (is_unknown) { | |||||
GELOGD("Node %s, type %s is unknown shape in graph %s.", node->GetName().c_str(), node->GetType().c_str(), | |||||
graph->GetName().c_str()); | |||||
return true; | |||||
} | |||||
} | |||||
GELOGD("Graph %s does not have unknown shape node.", graph->GetName().c_str()); | |||||
return false; | |||||
} | |||||
/// | /// | ||||
/// @brief Add node to graph | /// @brief Add node to graph | ||||
/// @param [in] op_desc | /// @param [in] op_desc | ||||
@@ -1868,6 +1977,17 @@ NodePtr ComputeGraphBuilder::GetNode(const std::string &name) { | |||||
return iter->second; | return iter->second; | ||||
} | } | ||||
/// @brief Get all nodes | |||||
/// @return std::vector<NodePtr> | |||||
/// | |||||
std::vector<NodePtr> ComputeGraphBuilder::GetAllNodes() { | |||||
std::vector<NodePtr> nodes; | |||||
for (const auto &iter : node_names_) { | |||||
nodes.emplace_back(iter.second); | |||||
} | |||||
return nodes; | |||||
} | |||||
/// | /// | ||||
/// @brief Add node to graph | /// @brief Add node to graph | ||||
/// @param [in] op_desc | /// @param [in] op_desc | ||||
@@ -1937,6 +2057,16 @@ CompleteGraphBuilder &CompleteGraphBuilder::AddOutput(const std::string &owner_n | |||||
return *this; | return *this; | ||||
} | } | ||||
/// | |||||
/// @brief Add target for graph | |||||
/// @param [in] target_name | |||||
/// @return CompleteGraphBuilder | |||||
/// | |||||
CompleteGraphBuilder &CompleteGraphBuilder::AddTarget(const std::string &target_name) { | |||||
graph_targets_.emplace_back(target_name); | |||||
return *this; | |||||
} | |||||
/// | /// | ||||
/// @brief Set parent-node of graph | /// @brief Set parent-node of graph | ||||
/// @param [in] parent_node | /// @param [in] parent_node | ||||
@@ -2013,6 +2143,11 @@ ComputeGraphPtr CompleteGraphBuilder::Build(graphStatus &error_code, std::string | |||||
return nullptr; | return nullptr; | ||||
} | } | ||||
BuildGraphTargets(error_code, error_msg); | |||||
if (error_code != GRAPH_SUCCESS) { | |||||
return nullptr; | |||||
} | |||||
// ATTR_NAME_SESSION_GRAPH_ID | // ATTR_NAME_SESSION_GRAPH_ID | ||||
std::string graph_id; | std::string graph_id; | ||||
if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | if (!AttrUtils::GetStr(parent_node_->GetOwnerComputeGraph(), ATTR_NAME_SESSION_GRAPH_ID, graph_id)) { | ||||
@@ -2210,6 +2345,27 @@ void CompleteGraphBuilder::AddRetValNodes(graphStatus &error_code, std::string & | |||||
GELOGD("AddRetValNodes succ."); | GELOGD("AddRetValNodes succ."); | ||||
} | } | ||||
/// | |||||
/// @brief Build target-nodes for graph | |||||
/// @param [out] error_code | |||||
/// @param [out] error_msg | |||||
/// @return void | |||||
/// | |||||
void CompleteGraphBuilder::BuildGraphTargets(graphStatus &error_code, std::string &error_msg) { | |||||
std::vector<NodePtr> target_nodes; | |||||
for (const std::string &target_name : graph_targets_) { | |||||
auto target_iter = node_names_.find(target_name); | |||||
if ((target_iter == node_names_.end()) || (target_iter->second == nullptr)) { | |||||
error_code = GRAPH_FAILED; | |||||
error_msg = "BuildGraphTargets failed: target_node " + target_name + " not exist in graph."; | |||||
return; | |||||
} | |||||
target_nodes.emplace_back(target_iter->second); | |||||
} | |||||
owner_graph_->SetGraphTargetNodesInfo(target_nodes); | |||||
return; | |||||
} | |||||
/// | /// | ||||
/// @brief Add node to graph | /// @brief Add node to graph | ||||
/// @param [in] op_desc | /// @param [in] op_desc | ||||
@@ -29,6 +29,13 @@ namespace ge { | |||||
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_send_info_{}; | ||||
std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | std::map<NodePtr, std::vector<uint32_t>> NodeUtils::map_recv_info_{}; | ||||
const std::set<std::string> kConstOpTypes = {"Const", "Constant"}; | |||||
const std::set<std::string> kIfOpTypes = {"If", "_If", "StatelessIf"}; | |||||
const std::set<std::string> kWhileOpTypes = {"While", "_While", "StatelessWhile"}; | |||||
const std::set<std::string> kCaseOpTypes = {"Case"}; | |||||
const std::set<std::string> kForOpTypes = {"For"}; | |||||
bool OpShapeIsUnknown(const OpDescPtr &desc) { | bool OpShapeIsUnknown(const OpDescPtr &desc) { | ||||
for (const auto &ptr : desc->GetAllInputsDescPtr()) { | for (const auto &ptr : desc->GetAllInputsDescPtr()) { | ||||
auto ge_shape = ptr->GetShape(); | auto ge_shape = ptr->GetShape(); | ||||
@@ -315,6 +322,9 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY graphStatus NodeUtils::UpdatePeer | |||||
peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | peer_input_desc->SetOriginShape(output_tensor.GetOriginShape()); | ||||
peer_input_desc->SetDataType(output_tensor.GetDataType()); | peer_input_desc->SetDataType(output_tensor.GetDataType()); | ||||
peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | peer_input_desc->SetOriginDataType(output_tensor.GetOriginDataType()); | ||||
std::vector<std::pair<int64_t, int64_t>> shape_range; | |||||
(void)output_tensor.GetShapeRange(shape_range); | |||||
peer_input_desc->SetShapeRange(shape_range); | |||||
ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | ge::TensorUtils::SetRealDimCnt(*peer_input_desc, | ||||
static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | static_cast<uint32_t>(output_tensor.GetShape().GetDims().size())); | ||||
GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | GELOGI("Peer input opdesc name is %s, shape size is %zu, datatype is %d, original datatype is %d", | ||||
@@ -477,6 +487,14 @@ bool NodeUtils::IsSubgraphInput(const NodePtr &node) { | |||||
return false; | return false; | ||||
} | } | ||||
auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); | |||||
if (parent_op_desc == nullptr) { | |||||
return false; | |||||
} | |||||
if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { | |||||
return false; | |||||
} | |||||
return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | return node->GetOpDesc()->HasAttr(ATTR_NAME_PARENT_NODE_INDEX); | ||||
} | } | ||||
@@ -491,6 +509,14 @@ bool NodeUtils::IsSubgraphOutput(const NodePtr &node) { | |||||
return false; | return false; | ||||
} | } | ||||
auto parent_op_desc = node->GetOwnerComputeGraph()->GetParentNode()->GetOpDesc(); | |||||
if (parent_op_desc == nullptr) { | |||||
return false; | |||||
} | |||||
if (AttrUtils::HasAttr(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE)) { | |||||
return false; | |||||
} | |||||
for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { | for (GeTensorDesc &tensor : node->GetOpDesc()->GetAllInputsDesc()) { | ||||
if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { | if (AttrUtils::HasAttr(tensor, ATTR_NAME_PARENT_NODE_INDEX)) { | ||||
return true; | return true; | ||||
@@ -557,4 +583,58 @@ bool NodeUtils::GetConstOpType(const NodePtr &in_node, std::string &op_type) { | |||||
return false; | return false; | ||||
} | } | ||||
/// | |||||
/// @brief Remove node-related subgraphs, including subgraphs of nodes in the subgraph. | |||||
/// @param [in] node | |||||
/// @return return GRAPH_SUCCESS if remove successfully, other for failed. | |||||
/// | |||||
Status NodeUtils::RemoveSubgraphsOnNode(const NodePtr &node) { | |||||
GE_CHECK_NOTNULL(node); | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
auto subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
if (subgraph_names.empty()) { | |||||
return GRAPH_SUCCESS; | |||||
} else { | |||||
auto owner_graph = node->GetOwnerComputeGraph(); | |||||
GE_CHECK_NOTNULL(owner_graph); | |||||
auto root_graph = GraphUtils::FindRootGraph(owner_graph); | |||||
GE_CHECK_NOTNULL(root_graph); | |||||
std::unordered_set<std::string> subgraph_to_remove; | |||||
for (auto &subgraph_name : subgraph_names) { | |||||
std::deque<std::string> queue; | |||||
queue.push_back(subgraph_name); | |||||
subgraph_to_remove.insert(subgraph_name); | |||||
op_desc->RemoveSubgraphInstanceName(subgraph_name); | |||||
while (!queue.empty()) { | |||||
auto graph_name = queue.front(); | |||||
queue.pop_front(); | |||||
auto subgraph = root_graph->GetSubgraph(graph_name); | |||||
GE_CHECK_NOTNULL(subgraph); | |||||
for (const auto &sub_node : subgraph->GetDirectNode()) { | |||||
auto sub_op_desc = sub_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(sub_op_desc); | |||||
auto sub_names = sub_op_desc->GetSubgraphInstanceNames(); | |||||
// Subgraph and all nodes in it will be removed later, | |||||
// no need to remove 'SubgraphInstanceName' in op desc here. | |||||
for (auto &name : sub_names) { | |||||
if (subgraph_to_remove.insert(name).second) { | |||||
queue.push_back(name); | |||||
} | |||||
} | |||||
} | |||||
} | |||||
} | |||||
// Remove subgraph from root_graph | |||||
for (const auto &name : subgraph_to_remove) { | |||||
GELOGI("Remove subgraph:%s.", name.c_str()); | |||||
root_graph->RemoveSubgraph(name); | |||||
} | |||||
} | |||||
return GRAPH_SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -199,6 +199,23 @@ GE_FUNC_DEV_VISIBILITY GE_FUNC_HOST_VISIBILITY vector<ge::NodePtr> OpDescUtils:: | |||||
auto in_node = out_anchor->GetOwnerNode(); | auto in_node = out_anchor->GetOwnerNode(); | ||||
if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | if ((in_node->GetType() == CONSTANT) || (in_node->GetType() == CONSTANTOP)) { | ||||
ret.push_back(in_node); | ret.push_back(in_node); | ||||
} else if (in_node->GetType() == DATA) { | |||||
const ComputeGraphPtr &graph = node.GetOwnerComputeGraph(); | |||||
GE_CHK_BOOL_EXEC(graph != nullptr, continue, "Owner graph is null"); | |||||
const NodePtr &parent_node = graph->GetParentNode(); | |||||
if (parent_node == nullptr) { | |||||
continue; // Root graph. | |||||
} | |||||
if (kWhileOpTypes.count(parent_node->GetType()) > 0) { | |||||
continue; // Subgraph of While cond or body. | |||||
} | |||||
NodePtr input_node = NodeUtils::GetParentInput(in_node); | |||||
if ((input_node != nullptr) && ((input_node->GetType() == CONSTANT) || (input_node->GetType() == CONSTANTOP))) { | |||||
ret.push_back(input_node); | |||||
} | |||||
} | } | ||||
} | } | ||||
return ret; | return ret; | ||||
@@ -17,6 +17,8 @@ | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "debug/ge_util.h" | #include "debug/ge_util.h" | ||||
using domi::domiTensorFormat_t; | |||||
namespace ge { | namespace ge { | ||||
static const std::map<Format, std::string> kFormatToStringMap = { | static const std::map<Format, std::string> kFormatToStringMap = { | ||||
{FORMAT_NCHW, "NCHW"}, | {FORMAT_NCHW, "NCHW"}, | ||||
@@ -60,6 +62,25 @@ static const std::map<Format, std::string> kFormatToStringMap = { | |||||
{FORMAT_RESERVED, "FORMAT_RESERVED"}, | {FORMAT_RESERVED, "FORMAT_RESERVED"}, | ||||
{FORMAT_ALL, "ALL"}}; | {FORMAT_ALL, "ALL"}}; | ||||
static const std::map<domiTensorFormat_t, Format> kDomiFormatToGeFormat = { | |||||
{domi::DOMI_TENSOR_NCHW, FORMAT_NCHW}, | |||||
{domi::DOMI_TENSOR_NHWC, FORMAT_NHWC}, | |||||
{domi::DOMI_TENSOR_ND, FORMAT_ND}, | |||||
{domi::DOMI_TENSOR_NC1HWC0, FORMAT_NC1HWC0}, | |||||
{domi::DOMI_TENSOR_FRACTAL_Z, FORMAT_FRACTAL_Z}, | |||||
{domi::DOMI_TENSOR_NC1C0HWPAD, FORMAT_NC1C0HWPAD}, | |||||
{domi::DOMI_TENSOR_NHWC1C0, FORMAT_NHWC1C0}, | |||||
{domi::DOMI_TENSOR_FSR_NCHW, FORMAT_FSR_NCHW}, | |||||
{domi::DOMI_TENSOR_FRACTAL_DECONV, FORMAT_FRACTAL_DECONV}, | |||||
{domi::DOMI_TENSOR_BN_WEIGHT, FORMAT_BN_WEIGHT}, | |||||
{domi::DOMI_TENSOR_CHWN, FORMAT_CHWN}, | |||||
{domi::DOMI_TENSOR_FILTER_HWCK, FORMAT_FILTER_HWCK}, | |||||
{domi::DOMI_TENSOR_NDHWC, FORMAT_NDHWC}, | |||||
{domi::DOMI_TENSOR_NCDHW, FORMAT_NCDHW}, | |||||
{domi::DOMI_TENSOR_DHWCN, FORMAT_DHWCN}, | |||||
{domi::DOMI_TENSOR_DHWNC, FORMAT_DHWNC}, | |||||
{domi::DOMI_TENSOR_RESERVED, FORMAT_RESERVED}}; | |||||
static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | static const std::unordered_set<std::string> kInternalFormat = {"NC1HWC0", | ||||
"FRACTAL_Z", | "FRACTAL_Z", | ||||
"NC1C0HWPAD", | "NC1C0HWPAD", | ||||
@@ -282,6 +303,15 @@ Format TypeUtils::DataFormatToFormat(const std::string &str) { | |||||
} | } | ||||
} | } | ||||
Format TypeUtils::DomiFormatToFormat(domi::domiTensorFormat_t domi_format) { | |||||
auto it = kDomiFormatToGeFormat.find(domi_format); | |||||
if (it != kDomiFormatToGeFormat.end()) { | |||||
return it->second; | |||||
} | |||||
GELOGE(GRAPH_FAILED, "do not find domi Format %d from map", domi_format); | |||||
return FORMAT_RESERVED; | |||||
} | |||||
static inline void CopyDataFromBuffer(vector<uint8_t> &data, const Buffer &buffer) { | static inline void CopyDataFromBuffer(vector<uint8_t> &data, const Buffer &buffer) { | ||||
data.clear(); | data.clear(); | ||||
if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { | if (buffer.GetData() != nullptr && buffer.GetSize() != 0) { | ||||
@@ -64,6 +64,7 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"common/helper/model_cache_helper.cc" | "common/helper/model_cache_helper.cc" | ||||
"common/profiling/profiling_manager.cc" | "common/profiling/profiling_manager.cc" | ||||
"engine_manager/dnnengine_manager.cc" | "engine_manager/dnnengine_manager.cc" | ||||
"executor/ge_executor.cc" | |||||
"ge_local_engine/engine/host_cpu_engine.cc" | "ge_local_engine/engine/host_cpu_engine.cc" | ||||
"generator/ge_generator.cc" | "generator/ge_generator.cc" | ||||
"generator/generator_api.cc" | "generator/generator_api.cc" | ||||
@@ -107,47 +108,61 @@ file(GLOB TRAIN_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
"graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
"graph/passes/*.cc" | "graph/passes/*.cc" | ||||
"graph/passes/folding_kernel/add_kernel.cc" | |||||
"graph/passes/folding_kernel/broadcast_args_kernel.cc" | |||||
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | |||||
"graph/passes/folding_kernel/cast_kernel.cc" | |||||
"graph/passes/folding_kernel/concat_offset_kernel.cc" | |||||
"graph/passes/folding_kernel/concat_v2_kernel.cc" | |||||
"graph/passes/folding_kernel/dynamic_stitch_kernel.cc" | |||||
"graph/passes/folding_kernel/empty_kernel.cc" | |||||
"graph/passes/folding_kernel/expanddims_kernel.cc" | |||||
"graph/passes/folding_kernel/fill_kernel.cc" | |||||
"graph/passes/folding_kernel/floordiv_kernel.cc" | |||||
"graph/passes/folding_kernel/floormod_kernel.cc" | |||||
"graph/passes/folding_kernel/gather_v2_kernel.cc" | |||||
"graph/passes/folding_kernel/greater_kernel.cc" | |||||
"graph/passes/folding_kernel/kernel_utils.cc" | |||||
"graph/passes/folding_kernel/maximum_kernel.cc" | |||||
"graph/passes/folding_kernel/mul_kernel.cc" | |||||
"graph/passes/folding_kernel/pack_kernel.cc" | |||||
"graph/passes/folding_kernel/permute_kernel.cc" | |||||
"graph/passes/folding_kernel/range_kernel.cc" | |||||
"graph/passes/folding_kernel/rank_kernel.cc" | |||||
"graph/passes/folding_kernel/reduce_prod_kernel.cc" | |||||
"graph/passes/folding_kernel/reshape_kernel.cc" | |||||
"graph/passes/folding_kernel/rsqrt_kernel.cc" | |||||
"graph/passes/folding_kernel/shape_kernel.cc" | |||||
"graph/passes/folding_kernel/shape_n_kernel.cc" | |||||
"graph/passes/folding_kernel/size_kernel.cc" | |||||
"graph/passes/folding_kernel/slice_d_kernel.cc" | |||||
"graph/passes/folding_kernel/slice_kernel.cc" | |||||
"graph/passes/folding_kernel/squeeze_kernel.cc" | |||||
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | |||||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | |||||
"graph/passes/folding_kernel/sub_kernel.cc" | |||||
"graph/passes/folding_kernel/transdata_kernel.cc" | |||||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||||
"host_kernels/add_kernel.cc" | |||||
"host_kernels/broadcast_args_kernel.cc" | |||||
"host_kernels/broadcast_gradient_args_kernel.cc" | |||||
"host_kernels/cast_kernel.cc" | |||||
"host_kernels/concat_offset_kernel.cc" | |||||
"host_kernels/concat_v2_kernel.cc" | |||||
"host_kernels/dynamic_stitch_kernel.cc" | |||||
"host_kernels/empty_kernel.cc" | |||||
"host_kernels/expanddims_kernel.cc" | |||||
"host_kernels/fill_kernel.cc" | |||||
"host_kernels/floordiv_kernel.cc" | |||||
"host_kernels/floormod_kernel.cc" | |||||
"host_kernels/gather_v2_kernel.cc" | |||||
"host_kernels/greater_kernel.cc" | |||||
"host_kernels/kernel_utils.cc" | |||||
"host_kernels/maximum_kernel.cc" | |||||
"host_kernels/mul_kernel.cc" | |||||
"host_kernels/pack_kernel.cc" | |||||
"host_kernels/permute_kernel.cc" | |||||
"host_kernels/range_kernel.cc" | |||||
"host_kernels/rank_kernel.cc" | |||||
"host_kernels/reduce_prod_kernel.cc" | |||||
"host_kernels/reshape_kernel.cc" | |||||
"host_kernels/rsqrt_kernel.cc" | |||||
"host_kernels/shape_kernel.cc" | |||||
"host_kernels/shape_n_kernel.cc" | |||||
"host_kernels/size_kernel.cc" | |||||
"host_kernels/slice_d_kernel.cc" | |||||
"host_kernels/slice_kernel.cc" | |||||
"host_kernels/squeeze_kernel.cc" | |||||
"host_kernels/ssd_prior_box_kernel.cc" | |||||
"host_kernels/strided_slice_kernel.cc" | |||||
"host_kernels/sub_kernel.cc" | |||||
"host_kernels/transdata_kernel.cc" | |||||
"host_kernels/transpose_kernel.cc" | |||||
"host_kernels/unpack_kernel.cc" | |||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
"graph/preprocess/multi_batch_copy_graph.cc" | "graph/preprocess/multi_batch_copy_graph.cc" | ||||
"hybrid/common/npu_memory_allocator.cc" | |||||
"hybrid/common/tensor_value.cc" | |||||
"hybrid/executor/*.cc" | |||||
"hybrid/executor/worker/*.cc" | |||||
"hybrid/hybrid_davinci_model.cc" | |||||
"hybrid/model/*.cc" | |||||
"hybrid/node_executor/aicore/*.cc" | |||||
"hybrid/node_executor/aicpu/aicpu_node_executor.cc" | |||||
"hybrid/node_executor/compiledsubgraph/known_node_executor.cc" | |||||
"hybrid/node_executor/hostcpu/ge_local_node_executor.cc" | |||||
"hybrid/node_executor/node_executor.cc" | |||||
"hybrid/node_executor/task_context.cc" | |||||
"init/gelib.cc" | "init/gelib.cc" | ||||
"model/ge_model.cc" | "model/ge_model.cc" | ||||
"model/ge_root_model.cc" | |||||
"omm/csa_interact.cc" | "omm/csa_interact.cc" | ||||
"opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
"session/inner_session.cc" | "session/inner_session.cc" | ||||
@@ -231,42 +246,43 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"graph/partition/engine_place.cc" | "graph/partition/engine_place.cc" | ||||
"graph/partition/graph_partition.cc" | "graph/partition/graph_partition.cc" | ||||
"graph/passes/*.cc" | "graph/passes/*.cc" | ||||
"graph/passes/folding_kernel/add_kernel.cc" | |||||
"graph/passes/folding_kernel/broadcast_args_kernel.cc" | |||||
"graph/passes/folding_kernel/broadcast_gradient_args_kernel.cc" | |||||
"graph/passes/folding_kernel/cast_kernel.cc" | |||||
"graph/passes/folding_kernel/concat_offset_kernel.cc" | |||||
"graph/passes/folding_kernel/concat_v2_kernel.cc" | |||||
"graph/passes/folding_kernel/dynamic_stitch_kernel.cc" | |||||
"graph/passes/folding_kernel/empty_kernel.cc" | |||||
"graph/passes/folding_kernel/expanddims_kernel.cc" | |||||
"graph/passes/folding_kernel/fill_kernel.cc" | |||||
"graph/passes/folding_kernel/floordiv_kernel.cc" | |||||
"graph/passes/folding_kernel/floormod_kernel.cc" | |||||
"graph/passes/folding_kernel/gather_v2_kernel.cc" | |||||
"graph/passes/folding_kernel/greater_kernel.cc" | |||||
"graph/passes/folding_kernel/kernel_utils.cc" | |||||
"graph/passes/folding_kernel/maximum_kernel.cc" | |||||
"graph/passes/folding_kernel/mul_kernel.cc" | |||||
"graph/passes/folding_kernel/pack_kernel.cc" | |||||
"graph/passes/folding_kernel/permute_kernel.cc" | |||||
"graph/passes/folding_kernel/range_kernel.cc" | |||||
"graph/passes/folding_kernel/rank_kernel.cc" | |||||
"graph/passes/folding_kernel/reduce_prod_kernel.cc" | |||||
"graph/passes/folding_kernel/reshape_kernel.cc" | |||||
"graph/passes/folding_kernel/rsqrt_kernel.cc" | |||||
"graph/passes/folding_kernel/shape_kernel.cc" | |||||
"graph/passes/folding_kernel/shape_n_kernel.cc" | |||||
"graph/passes/folding_kernel/size_kernel.cc" | |||||
"graph/passes/folding_kernel/slice_d_kernel.cc" | |||||
"graph/passes/folding_kernel/slice_kernel.cc" | |||||
"graph/passes/folding_kernel/squeeze_kernel.cc" | |||||
"graph/passes/folding_kernel/ssd_prior_box_kernel.cc" | |||||
"graph/passes/folding_kernel/strided_slice_kernel.cc" | |||||
"graph/passes/folding_kernel/sub_kernel.cc" | |||||
"graph/passes/folding_kernel/transdata_kernel.cc" | |||||
"graph/passes/folding_kernel/transpose_kernel.cc" | |||||
"graph/passes/folding_kernel/unpack_kernel.cc" | |||||
"host_kernels/add_kernel.cc" | |||||
"host_kernels/broadcast_args_kernel.cc" | |||||
"host_kernels/broadcast_gradient_args_kernel.cc" | |||||
"host_kernels/cast_kernel.cc" | |||||
"host_kernels/concat_offset_kernel.cc" | |||||
"host_kernels/concat_v2_kernel.cc" | |||||
"host_kernels/dynamic_stitch_kernel.cc" | |||||
"host_kernels/empty_kernel.cc" | |||||
"host_kernels/expanddims_kernel.cc" | |||||
"host_kernels/fill_kernel.cc" | |||||
"host_kernels/floordiv_kernel.cc" | |||||
"host_kernels/floormod_kernel.cc" | |||||
"host_kernels/gather_v2_kernel.cc" | |||||
"host_kernels/greater_kernel.cc" | |||||
"host_kernels/kernel_utils.cc" | |||||
"host_kernels/maximum_kernel.cc" | |||||
"host_kernels/mul_kernel.cc" | |||||
"host_kernels/pack_kernel.cc" | |||||
"host_kernels/permute_kernel.cc" | |||||
"host_kernels/range_kernel.cc" | |||||
"host_kernels/rank_kernel.cc" | |||||
"host_kernels/reduce_prod_kernel.cc" | |||||
"host_kernels/reshape_kernel.cc" | |||||
"host_kernels/rsqrt_kernel.cc" | |||||
"host_kernels/shape_kernel.cc" | |||||
"host_kernels/shape_n_kernel.cc" | |||||
"host_kernels/size_kernel.cc" | |||||
"host_kernels/slice_d_kernel.cc" | |||||
"host_kernels/slice_kernel.cc" | |||||
"host_kernels/squeeze_kernel.cc" | |||||
"host_kernels/ssd_prior_box_kernel.cc" | |||||
"host_kernels/strided_slice_kernel.cc" | |||||
"host_kernels/sub_kernel.cc" | |||||
"host_kernels/transdata_kernel.cc" | |||||
"host_kernels/transpose_kernel.cc" | |||||
"host_kernels/unpack_kernel.cc" | |||||
"hybrid/hybrid_davinci_model_stub.cc" | |||||
"graph/preprocess/graph_preprocess.cc" | "graph/preprocess/graph_preprocess.cc" | ||||
"graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
"graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
@@ -275,6 +291,7 @@ file(GLOB INFER_SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
"ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
"model/ge_model.cc" | "model/ge_model.cc" | ||||
"model/ge_root_model.cc" | |||||
"omm/csa_interact.cc" | "omm/csa_interact.cc" | ||||
"opskernel_manager/ops_kernel_manager.cc" | "opskernel_manager/ops_kernel_manager.cc" | ||||
"session/inner_session.cc" | "session/inner_session.cc" | ||||
@@ -49,7 +49,7 @@ void GetOpsProtoPath(std::string &opsproto_path) { | |||||
const char *path_env = std::getenv("ASCEND_OPP_PATH"); | const char *path_env = std::getenv("ASCEND_OPP_PATH"); | ||||
if (path_env != nullptr) { | if (path_env != nullptr) { | ||||
std::string path = path_env; | std::string path = path_env; | ||||
opsproto_path = (path + "/op_proto/built-in/" + ":") + (path + "/op_proto/custom/"); | |||||
opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||||
GELOGI("Get opsproto so path from env: %s", path.c_str()); | GELOGI("Get opsproto so path from env: %s", path.c_str()); | ||||
return; | return; | ||||
} | } | ||||
@@ -57,7 +57,7 @@ void GetOpsProtoPath(std::string &opsproto_path) { | |||||
GELOGI("path_base is %s", path_base.c_str()); | GELOGI("path_base is %s", path_base.c_str()); | ||||
path_base = path_base.substr(0, path_base.rfind('/')); | path_base = path_base.substr(0, path_base.rfind('/')); | ||||
path_base = path_base.substr(0, path_base.rfind('/') + 1); | path_base = path_base.substr(0, path_base.rfind('/') + 1); | ||||
opsproto_path = (path_base + "ops/op_proto/built-in/" + ":") + (path_base + "ops/op_proto/custom/"); | |||||
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||||
} | } | ||||
Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | Status CheckDumpAndReuseMemory(const std::map<string, string> &options) { | ||||
@@ -103,20 +103,6 @@ Status CheckOptionsValid(const std::map<string, string> &options) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void SaveDdkVersion(const std::map<string, string> &options) { | |||||
auto ddk_option = options.find(DDK_VERSION_FLAG); | |||||
if (ddk_option != options.end()) { | |||||
auto ddk_version = ddk_option->second; | |||||
if (!ddk_version.empty()) { | |||||
GELOGI("Input ddk version : %s.", ddk_version.c_str()); | |||||
domi::GetContext().ddk_version = ddk_version; | |||||
} | |||||
} else { | |||||
GELOGW("No ddkVersion!"); | |||||
return; | |||||
} | |||||
} | |||||
// Initialize GE, prepare for execution, call GELib::Initialize | // Initialize GE, prepare for execution, call GELib::Initialize | ||||
Status GEInitialize(const std::map<string, string> &options) { | Status GEInitialize(const std::map<string, string> &options) { | ||||
GELOGT(TRACE_INIT, "GEInitialize start"); | GELOGT(TRACE_INIT, "GEInitialize start"); | ||||
@@ -146,9 +132,6 @@ Status GEInitialize(const std::map<string, string> &options) { | |||||
} | } | ||||
GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | GE_TIMESTAMP_END(CheckOptionsValid, "GEInitialize::CheckOptionsValid"); | ||||
GE_TIMESTAMP_START(InitPreparation); | |||||
SaveDdkVersion(options); | |||||
GE_TIMESTAMP_END(InitPreparation, "GEInitialize::InitPreparation"); | |||||
// call Initialize | // call Initialize | ||||
GELOGT(TRACE_RUNNING, "Initializing environment"); | GELOGT(TRACE_RUNNING, "Initializing environment"); | ||||
GE_TIMESTAMP_START(GELibInitialize); | GE_TIMESTAMP_START(GELibInitialize); | ||||
@@ -22,6 +22,7 @@ | |||||
#include <string> | #include <string> | ||||
#include "securec.h" | #include "securec.h" | ||||
#include "framework/common/fmk_types.h" | #include "framework/common/fmk_types.h" | ||||
#include "framework/common/debug/ge_log.h" | |||||
using std::set; | using std::set; | ||||
using std::string; | using std::string; | ||||
@@ -146,7 +147,10 @@ string Pb2Json::TypeBytes2String(string &field_name, string &type_bytes) { | |||||
uint8_t *value = 0; | uint8_t *value = 0; | ||||
value = reinterpret_cast<uint8_t *>(&temp_value); | value = reinterpret_cast<uint8_t *>(&temp_value); | ||||
char str[kSignificantDigits]; | char str[kSignificantDigits]; | ||||
sprintf_s(str, kSignificantDigits, "%d", *value); | |||||
if (sprintf_s(str, kSignificantDigits, "%d", *value) == -1) { | |||||
GELOGW("Convert bytes to string fail, filed name:%s", field_name.c_str()); | |||||
continue; | |||||
} | |||||
result += str; | result += str; | ||||
} | } | ||||
return result; | return result; | ||||
@@ -21,6 +21,7 @@ | |||||
#include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/debug/log.h" | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
@@ -199,6 +200,23 @@ Status TransposeWithShapeCheck(const uint8_t *data, const std::vector<int64_t> & | |||||
return Transpose(data, src_shape, src_data_type, perm_arg, result); | return Transpose(data, src_shape, src_data_type, perm_arg, result); | ||||
} | } | ||||
Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm) { | |||||
auto dst_iter = perm_args.find(src_format); | |||||
if (dst_iter == perm_args.end()) { | |||||
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
return UNSUPPORTED; | |||||
} | |||||
auto iter = dst_iter->second.find(dst_format); | |||||
if (iter == dst_iter->second.end()) { | |||||
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
return UNSUPPORTED; | |||||
} | |||||
perm = iter->second; | |||||
return SUCCESS; | |||||
} | |||||
Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &result) { | Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult &result) { | ||||
std::vector<int64_t> expected_shape; | std::vector<int64_t> expected_shape; | ||||
auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expected_shape); | auto ret = TransShape(args.src_format, args.src_shape, args.src_data_type, args.dst_format, expected_shape); | ||||
@@ -218,23 +236,12 @@ Status FormatTransferTranspose::TransFormat(const TransArgs &args, TransResult & | |||||
Status FormatTransferTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | Status FormatTransferTranspose::TransShape(Format src_format, const std::vector<int64_t> &src_shape, DataType data_type, | ||||
Format dst_format, std::vector<int64_t> &dst_shape) { | Format dst_format, std::vector<int64_t> &dst_shape) { | ||||
auto dst_iter = perm_args.find(src_format); | |||||
if (dst_iter == perm_args.end()) { | |||||
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
return UNSUPPORTED; | |||||
} | |||||
auto iter = dst_iter->second.find(dst_format); | |||||
if (iter == dst_iter->second.end()) { | |||||
GELOGE(UNSUPPORTED, "Failed to trans shape, do not support transpose from format %s to %s", | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str()); | |||||
return UNSUPPORTED; | |||||
} | |||||
if (!IsShapeArgValid(src_shape, iter->second)) { | |||||
std::vector<int64_t> perm_arg; | |||||
GE_CHK_STATUS_RET_NOLOG(GetPermByForamt(src_format, dst_format, perm_arg)); | |||||
if (!IsShapeArgValid(src_shape, perm_arg)) { | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
dst_shape = TransShapeByPerm(src_shape, iter->second); | |||||
dst_shape = TransShapeByPerm(src_shape, perm_arg); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -31,6 +31,8 @@ Status TransposeWithShapeCheck(const uint8_t *src, const std::vector<int64_t> &s | |||||
const std::vector<int64_t> &dst_shape, DataType src_data_type, | const std::vector<int64_t> &dst_shape, DataType src_data_type, | ||||
const std::vector<int64_t> &perm_arg, TransResult &result); | const std::vector<int64_t> &perm_arg, TransResult &result); | ||||
Status GetPermByForamt(Format src_format, Format dst_format, std::vector<int64_t> &perm); | |||||
class FormatTransferTranspose : public FormatTransfer { | class FormatTransferTranspose : public FormatTransfer { | ||||
public: | public: | ||||
Status TransFormat(const TransArgs &args, TransResult &result) override; | Status TransFormat(const TransArgs &args, TransResult &result) override; | ||||
@@ -180,8 +180,7 @@ ModelHelper::SaveOriginalGraphToOmModel(const ge::Graph &graph, const std::strin | |||||
GELOGE(FAILED, "SaveModel fail for compute_graph null"); | GELOGE(FAILED, "SaveModel fail for compute_graph null"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
ge::GraphUtils::DumpGEGraph(compute_graph, "OriginalGraph"); | |||||
ge::GraphUtils::DumpGEGraphToOnnx(*compute_graph, "OriginalGraph"); | |||||
GE_DUMP(compute_graph, "OriginalGraph"); | |||||
// Model | // Model | ||||
ModelPtr model_ptr = ge::MakeShared<ge::Model>(); | ModelPtr model_ptr = ge::MakeShared<ge::Model>(); | ||||
GE_CHECK_NOTNULL_EXEC(model_ptr, return MEMALLOC_FAILED); | GE_CHECK_NOTNULL_EXEC(model_ptr, return MEMALLOC_FAILED); | ||||
@@ -74,14 +74,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
// profiling startup first time | // profiling startup first time | ||||
GELOGI("Begin to init profiling, device num %zu", device_id_.size()); | |||||
for (size_t i = 0; i < device_id_.size(); ++i) { | for (size_t i = 0; i < device_id_.size(); ++i) { | ||||
ret = StartProfiling(0, device_id_[i]); | ret = StartProfiling(0, device_id_[i]); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Profiling start failed."); | |||||
GELOGE(ret, "Profiling start failed on device %d.", device_id_[i]); | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
GELOGI("Profiling init succ."); | |||||
GELOGI("Profiling init succ on device %d.", device_id_[i]); | |||||
} | } | ||||
} else { | |||||
GELOGI("The profiling is off, skip the initialization"); | |||||
} | } | ||||
#endif | #endif | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -164,7 +167,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::In | |||||
} | } | ||||
is_profiling_ = true; | is_profiling_ = true; | ||||
} catch (Json::parse_error &e) { | |||||
} catch (...) { | |||||
GELOGE(FAILED, "Json conf is not invalid !"); | GELOGE(FAILED, "Json conf is not invalid !"); | ||||
return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
} | } | ||||
@@ -274,7 +277,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ge::Status ProfilingManager::St | |||||
ss << start_cfg; | ss << start_cfg; | ||||
send_profiling_config_ = ss.str(); | send_profiling_config_ = ss.str(); | ||||
GELOGI("Profiling config %s\n", send_profiling_config_.c_str()); | GELOGI("Profiling config %s\n", send_profiling_config_.c_str()); | ||||
} catch (Json::parse_error &e) { | |||||
} catch (...) { | |||||
GELOGE(FAILED, "Op trace json conf is not invalid !"); | GELOGE(FAILED, "Op trace json conf is not invalid !"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
@@ -389,6 +389,7 @@ REGISTER_OPTYPE_DEFINE(STREAMMERGE, "StreamMerge"); | |||||
REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | REGISTER_OPTYPE_DEFINE(ENDGRAPH, "EndGraph"); | ||||
REGISTER_OPTYPE_DEFINE(SEND, "Send"); | REGISTER_OPTYPE_DEFINE(SEND, "Send"); | ||||
REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | REGISTER_OPTYPE_DEFINE(RECV, "Recv"); | ||||
REGISTER_OPTYPE_DEFINE(ENDOFSEQUENCE, "EndOfSequence"); | |||||
REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | REGISTER_OPTYPE_DEFINE(LABELSET, "LabelSet"); | ||||
REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | REGISTER_OPTYPE_DEFINE(LABELGOTO, "LabelGoto"); | ||||
@@ -456,6 +457,12 @@ REGISTER_OPTYPE_DEFINE(SIGMOIDGRAD, "SigmoidGrad"); | |||||
REGISTER_OPTYPE_DEFINE(TRANSSHAPE, "TransShape"); | REGISTER_OPTYPE_DEFINE(TRANSSHAPE, "TransShape"); | ||||
// Horovod operator | |||||
REGISTER_OPTYPE_DEFINE(HVDCALLBACKALLREDUCE, "HorovodAllreduce"); | |||||
REGISTER_OPTYPE_DEFINE(HVDCALLBACKALLGATHER, "HorovodAllgather"); | |||||
REGISTER_OPTYPE_DEFINE(HVDCALLBACKBROADCAST, "HorovodBroadcast"); | |||||
REGISTER_OPTYPE_DEFINE(HVDWAIT, "HorovodWait"); | |||||
const std::string MODEL_ATTR_TASKS = "tasks"; | const std::string MODEL_ATTR_TASKS = "tasks"; | ||||
const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | const std::string MODEL_ATTR_TASK_GEN_BASE_ADDR = "task_gen_base_addr"; | ||||
const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; | const std::string MODEL_ATTR_TASK_GEN_WEIGHT_ADDR = "task_gen_weight_addr"; | ||||
@@ -67,8 +67,9 @@ static bool ReadProtoFromCodedInputStream(CodedInputStream &coded_stream, Messag | |||||
} | } | ||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(const char *file, Message *proto) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), return false, | |||||
"incorrect parameter. nullptr == file || nullptr == proto"); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file == nullptr || proto == nullptr), | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19001"); | |||||
return false, "Input parameter file or proto is nullptr!"); | |||||
std::string real_path = RealPath(file); | std::string real_path = RealPath(file); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return false, "pb file path '%s' not valid", file); | ||||
@@ -77,7 +78,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||||
std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | std::ifstream fs(real_path, std::ifstream::in | std::ifstream::binary); | ||||
if (!fs.is_open()) { | if (!fs.is_open()) { | ||||
GELOGE(ge::FAILED, "Open %s failed.", file); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19004", {"realpath"}, {file}); | |||||
GELOGE(ge::FAILED, "Open real path[%s] failed.", file); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -89,7 +91,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromBinaryFile(co | |||||
fs.close(); | fs.close(); | ||||
if (!ret) { | if (!ret) { | ||||
GELOGE(ge::FAILED, "Parse %s failed.", file); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19005", {"filepath"}, {file}); | |||||
GELOGE(ge::FAILED, "Parse file[%s] failed.", file); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -113,17 +116,17 @@ long GetFileLength(const std::string &input_file) { | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), return -1, "input_file path '%s' not valid", input_file.c_str()); | ||||
unsigned long long file_length = 0; | unsigned long long file_length = 0; | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mmGetFileSize(input_file.c_str(), &file_length) != EN_OK, | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10037"); | |||||
return -1, "open file failed."); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10037", {"filepath"}, {input_file}); | |||||
return -1, "Open file[%s] failed", input_file.c_str()); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), ErrorManager::GetInstance().ATCReportErrMessage("E10038"); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG((file_length == 0), ErrorManager::GetInstance().ATCReportErrMessage("E10038"); | ||||
return -1, "file length is 0, not valid."); | |||||
return -1, "File[%s] length is 0, not valid.", input_file.c_str()); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
file_length > kMaxFileSizeLimit, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10039", {"filesize", "maxlen"}, | |||||
{std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||||
return -1, "file size %lld is out of limit: %d.", file_length, kMaxFileSizeLimit); | |||||
file_length > kMaxFileSizeLimit, ErrorManager::GetInstance().ATCReportErrMessage( | |||||
"E10039", {"filepath", "filesize", "maxlen"}, | |||||
{input_file, std::to_string(file_length), std::to_string(kMaxFileSizeLimit)}); | |||||
return -1, "File[%s] size %lld is out of limit: %d.", input_file.c_str(), file_length, kMaxFileSizeLimit); | |||||
return static_cast<long>(file_length); | return static_cast<long>(file_length); | ||||
} | } | ||||
@@ -202,7 +205,9 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | GE_CHK_BOOL_EXEC(!directory_path.empty(), return -1, "directory path is empty."); | ||||
auto dir_path_len = directory_path.length(); | auto dir_path_len = directory_path.length(); | ||||
if (dir_path_len >= PATH_MAX) { | if (dir_path_len >= PATH_MAX) { | ||||
GELOGW("Directory path is too long."); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, | |||||
{directory_path, std::to_string(PATH_MAX)}); | |||||
GELOGW("Path[%s] len is too long, it must smaller than %d", directory_path.c_str(), PATH_MAX); | |||||
return -1; | return -1; | ||||
} | } | ||||
char tmp_dir_path[PATH_MAX] = {0}; | char tmp_dir_path[PATH_MAX] = {0}; | ||||
@@ -213,8 +218,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(tmp_dir_path, S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
if (ret != 0) { | if (ret != 0) { | ||||
if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", | |||||
directory_path.c_str()); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||||
GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
@@ -224,7 +229,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY int CreateDirectory(const std:: | |||||
int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | int32_t ret = mmMkdir(const_cast<char *>(directory_path.c_str()), S_IRUSR | S_IWUSR | S_IXUSR); // 700 | ||||
if (ret != 0) { | if (ret != 0) { | ||||
if (errno != EEXIST) { | if (errno != EEXIST) { | ||||
GELOGW("Cannot create directory %s. Make sure that the directory exists and writable.", directory_path.c_str()); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19006", {"path"}, {directory_path}); | |||||
GELOGW("Cannot create directory %s. Make sure the directory exists and writable.", directory_path.c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
} | } | ||||
@@ -253,16 +259,17 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool ReadProtoFromText(const ch | |||||
std::string real_path = RealPath(file); | std::string real_path = RealPath(file); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(real_path.empty(), | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"realpath"}, {file}); | |||||
return false, "proto file real path '%s' not valid", file); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10036", {"filepath"}, {file}); | |||||
return false, "Get path[%s]'s real path failed", file); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetFileLength(real_path) == -1, return false, "file size not valid."); | ||||
std::ifstream fs(real_path.c_str(), std::ifstream::in); | std::ifstream fs(real_path.c_str(), std::ifstream::in); | ||||
if (!fs.is_open()) { | if (!fs.is_open()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"protofile"}, {file}); | |||||
GELOGE(ge::FAILED, "Fail to open proto file '%s'.", file); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10040", {"realpth", "protofile"}, {real_path, file}); | |||||
GELOGE(ge::FAILED, "Fail to open proto file real path is '%s' when orginal file path is '%s'.", real_path.c_str(), | |||||
file); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -328,18 +335,21 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInt64MulOverflow(int6 | |||||
FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY std::string RealPath(const char *path) { | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(path == nullptr, return "", "path pointer is NULL."); | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(strlen(path) >= PATH_MAX, return "", "path is invalid"); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
strlen(path) >= PATH_MAX, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19002", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); | |||||
return "", "Path[%s] len is too long, it must smaller than %d", path, PATH_MAX); | |||||
// PATH_MAX is the system's own macro, indicating the maximum file path length supported | // PATH_MAX is the system's own macro, indicating the maximum file path length supported | ||||
std::shared_ptr<char> resolved_path(new (std::nothrow) char[PATH_MAX](), std::default_delete<char[]>()); | std::shared_ptr<char> resolved_path(new (std::nothrow) char[PATH_MAX](), std::default_delete<char[]>()); | ||||
if (resolved_path == nullptr) { | |||||
GELOGW("new an PATH_MAX string object failed."); | |||||
return ""; | |||||
} | |||||
std::string res; | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | |||||
resolved_path == nullptr, | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E19003", {"filepath", "size"}, {path, std::to_string(PATH_MAX)}); | |||||
return "", "Path[%s] new string object len[%d] failed.", path, PATH_MAX); | |||||
// Nullptr is returned when the path does not exist or there is no permission | // Nullptr is returned when the path does not exist or there is no permission | ||||
// Return absolute path when path is accessible | // Return absolute path when path is accessible | ||||
std::string res; | |||||
if (realpath(path, resolved_path.get()) != nullptr) { | if (realpath(path, resolved_path.get()) != nullptr) { | ||||
res = resolved_path.get(); | res = resolved_path.get(); | ||||
} | } | ||||
@@ -360,7 +370,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
// Unable to get absolute path (does not exist or does not have permission to access) | // Unable to get absolute path (does not exist or does not have permission to access) | ||||
if (real_path.empty()) { | if (real_path.empty()) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); | ErrorManager::GetInstance().ATCReportErrMessage("E10002", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); | ||||
GELOGW("Can not get real path for %s, %s", file_path.c_str(), strerror(errno)); | |||||
GELOGW("Path[%s]'s realpath is empty, errmsg[%s]", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -381,7 +391,7 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckInputPathValid(const | |||||
// The absolute path points to a file that is not readable | // The absolute path points to a file that is not readable | ||||
if (access(real_path.c_str(), R_OK) != 0) { | if (access(real_path.c_str(), R_OK) != 0) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); | ErrorManager::GetInstance().ATCReportErrMessage("E10003", {"path", "errmsg"}, {file_path.c_str(), strerror(errno)}); | ||||
GELOGW("Read path[%s] failed, %s", file_path.c_str(), strerror(errno)); | |||||
GELOGW("Read path[%s] failed, errmsg[%s]", file_path.c_str(), strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
@@ -416,9 +426,10 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY bool CheckOutputPathValid(const | |||||
// File is not readable or writable | // File is not readable or writable | ||||
if (access(real_path.c_str(), W_OK | F_OK) != 0) { | if (access(real_path.c_str(), W_OK | F_OK) != 0) { | ||||
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"path", "errmsg"}, | |||||
{real_path.c_str(), strerror(errno)}); | |||||
GELOGW("Write file failed, path[%s], %s", real_path.c_str(), strerror(errno)); | |||||
ErrorManager::GetInstance().ATCReportErrMessage("E10004", {"realpath", "path", "errmsg"}, | |||||
{real_path, file_path, strerror(errno)}); | |||||
GELOGW("Write file[%s] failed, input path is %s, errmsg[%s]", real_path.c_str(), file_path.c_str(), | |||||
strerror(errno)); | |||||
return false; | return false; | ||||
} | } | ||||
} else { | } else { | ||||
@@ -59,12 +59,15 @@ file(GLOB SRC_LIST RELATIVE ${CMAKE_CURRENT_LIST_DIR} | |||||
"../graph/load/new_model_manager/tbe_handle_store.cc" | "../graph/load/new_model_manager/tbe_handle_store.cc" | ||||
"../graph/load/new_model_manager/zero_copy_task.cc" | "../graph/load/new_model_manager/zero_copy_task.cc" | ||||
"../graph/load/output/output.cc" | "../graph/load/output/output.cc" | ||||
"../graph/manager/graph_caching_allocator.cc" | |||||
"../graph/manager/graph_manager_utils.cc" | "../graph/manager/graph_manager_utils.cc" | ||||
"../graph/manager/graph_mem_allocator.cc" | "../graph/manager/graph_mem_allocator.cc" | ||||
"../graph/manager/graph_var_manager.cc" | "../graph/manager/graph_var_manager.cc" | ||||
"../graph/manager/trans_var_data_utils.cc" | "../graph/manager/trans_var_data_utils.cc" | ||||
"../graph/manager/util/debug.cc" | "../graph/manager/util/debug.cc" | ||||
"../hybrid/hybrid_davinci_model_stub.cc" | |||||
"../model/ge_model.cc" | "../model/ge_model.cc" | ||||
"../model/ge_root_model.cc" | |||||
"../omm/csa_interact.cc" | "../omm/csa_interact.cc" | ||||
"../single_op/single_op.cc" | "../single_op/single_op.cc" | ||||
"../single_op/single_op_manager.cc" | "../single_op/single_op_manager.cc" | ||||
@@ -20,6 +20,7 @@ | |||||
#include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/utils/node_utils.h" | |||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "op/op_factory.h" | #include "op/op_factory.h" | ||||
@@ -68,6 +69,15 @@ Status GeLocalOpsKernelInfoStore::CalcOpRunningParam(Node &ge_node) { | |||||
GELOGE(FAILED, "CalcOpRunningParam failed, as op desc is null"); | GELOGE(FAILED, "CalcOpRunningParam failed, as op desc is null"); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
bool is_shape_unknown = false; | |||||
if (NodeUtils::GetNodeUnknownShapeStatus(ge_node, is_shape_unknown) == GRAPH_SUCCESS) { | |||||
if (is_shape_unknown) { | |||||
GELOGI("op:%s is unknown shape, does not need to calc output size.", ge_node.GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
const string node_name = ge_node.GetName(); | const string node_name = ge_node.GetName(); | ||||
const string node_type = ge_node.GetType(); | const string node_type = ge_node.GetType(); | ||||
size_t output_size = op_desc->GetOutputsSize(); | size_t output_size = op_desc->GetOutputsSize(); | ||||
@@ -157,6 +167,13 @@ Status GeLocalOpsKernelInfoStore::CalcConstantStrMemSize(const OpDescPtr &op_des | |||||
void GeLocalOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; } | void GeLocalOpsKernelInfoStore::GetAllOpsKernelInfo(map<string, OpInfo> &infos) const { infos = op_info_map_; } | ||||
Status GeLocalOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector<TaskDef> &tasks) { | Status GeLocalOpsKernelInfoStore::GenerateTask(const Node &node, RunContext &context, vector<TaskDef> &tasks) { | ||||
bool is_shape_unknown = false; | |||||
if (NodeUtils::GetNodeUnknownShapeStatus(node, is_shape_unknown) == GRAPH_SUCCESS) { | |||||
if (is_shape_unknown) { | |||||
GELOGI("op:%s is unknown shape, does not need to generate task", node.GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
} | |||||
string name = node.GetName(); | string name = node.GetName(); | ||||
string type = node.GetType(); | string type = node.GetType(); | ||||
GELOGD("Ge local generate task for node:%s(%s) begin, tasks.size()=%zu.", name.c_str(), type.c_str(), tasks.size()); | GELOGD("Ge local generate task for node:%s(%s) begin, tasks.size()=%zu.", name.c_str(), type.c_str(), tasks.size()); | ||||
@@ -128,17 +128,17 @@ bool HcclTask::Distribute() { | |||||
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | ||||
ge_task.stream = stream_; | ge_task.stream = stream_; | ||||
ge_task.kernelHcclInfo.hccl_type = task_info_->hccl_type(); | |||||
ge_task.kernelHcclInfo.inputDataAddr = task_info_->input_data_addr(); | |||||
ge_task.kernelHcclInfo.outputDataAddr = task_info_->output_data_addr(); | |||||
ge_task.kernelHcclInfo.workSpaceAddr = task_info_->workspace_addr(); | |||||
ge_task.kernelHcclInfo.workSpaceMemSize = task_info_->workspace_size(); | |||||
ge_task.kernelHcclInfo.count = task_info_->count(); | |||||
ge_task.kernelHcclInfo.dataType = static_cast<int32_t>(task_info_->data_type()); | |||||
ge_task.kernelHcclInfo.opType = static_cast<int32_t>(task_info_->op_type()); | |||||
ge_task.kernelHcclInfo.rootId = task_info_->root_id(); | |||||
ge_task.kernelHcclInfo.hcclStreamList = slave_stream_list_; | |||||
ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); | |||||
ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); | |||||
ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); | |||||
ge_task.kernelHcclInfo[0].workSpaceAddr = task_info_->workspace_addr(); | |||||
ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); | |||||
ge_task.kernelHcclInfo[0].count = task_info_->count(); | |||||
ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type()); | |||||
ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type()); | |||||
ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); | |||||
ge_task.kernelHcclInfo[0].hcclStreamList = slave_stream_list_; | |||||
ge_task.privateDef = private_def; | ge_task.privateDef = private_def; | ||||
ge_task.privateDefLen = private_def_len; | ge_task.privateDefLen = private_def_len; | ||||
@@ -27,6 +27,7 @@ | |||||
#include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
#include "init/gelib.h" | |||||
using std::map; | using std::map; | ||||
using std::string; | using std::string; | ||||
@@ -34,9 +35,79 @@ using std::vector; | |||||
namespace { | namespace { | ||||
const char *const kAttrOpType = "op_type"; | const char *const kAttrOpType = "op_type"; | ||||
} | |||||
const char *const kEngineNameDefault = "default"; | |||||
const char *const kVectorEngine = "VectorEngine"; | |||||
const char *const kAIcoreEngine = "AIcoreEngine"; | |||||
const char *const kFileNameSuffix = "online"; | |||||
std::map<ge::OpEngineType, std::string> engine_type_map{ | |||||
{ge::ENGINE_SYS, kEngineNameDefault}, {ge::ENGINE_AICORE, kAIcoreEngine}, {ge::ENGINE_VECTOR, kVectorEngine}}; | |||||
} // namespace | |||||
namespace ge { | namespace ge { | ||||
static Status CheckEngineTypeSupport(const OpDescPtr &op_desc, OpEngineType engine_type) { | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | |||||
if (engine_type == ENGINE_SYS) { | |||||
GELOGI("CheckEngineType: use default engine."); | |||||
return SUCCESS; | |||||
} | |||||
// get op engine name | |||||
string op_engine_name; | |||||
auto iter = engine_type_map.find(engine_type); | |||||
if (iter != engine_type_map.end()) { | |||||
op_engine_name = iter->second; | |||||
GELOGI("CheckEngineType: engine type: %d", static_cast<int>(engine_type)); | |||||
} else { | |||||
GELOGE(FAILED, "CheckEngineType: engine type: %d not support", static_cast<int>(engine_type)); | |||||
return FAILED; | |||||
} | |||||
// set op engine name and opkernelLib. when engine support | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if ((instance_ptr == nullptr) || (!instance_ptr->InitFlag())) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "CheckEngineType failed."); | |||||
return FAILED; | |||||
} | |||||
OpsKernelManager &ops_kernel_manager = instance_ptr->OpsKernelManagerObj(); | |||||
std::vector<OpInfo> op_infos = ops_kernel_manager.GetOpsKernelInfo(op_desc->GetType()); | |||||
if (op_infos.empty()) { | |||||
GELOGE(FAILED, "CheckEngineType: Can not get op info by op type %s", op_desc->GetType().c_str()); | |||||
return FAILED; | |||||
} | |||||
string kernel_name; | |||||
for (const auto &it : op_infos) { | |||||
if (it.engine == op_engine_name) { | |||||
kernel_name = it.opKernelLib; | |||||
break; | |||||
} | |||||
} | |||||
if (kernel_name.empty()) { | |||||
GELOGE(FAILED, "CheckEngineType:Can not find ops kernel,engine name: %s.", op_engine_name.c_str()); | |||||
return FAILED; | |||||
} | |||||
auto &kernel_map = ops_kernel_manager.GetAllOpsKernelInfoStores(); | |||||
auto kernel_info_store = kernel_map.find(kernel_name); | |||||
if (kernel_info_store != kernel_map.end()) { | |||||
std::string unsupported_reason; | |||||
if (kernel_info_store->second->CheckSupported(op_desc, unsupported_reason)) { | |||||
op_desc->SetOpEngineName(op_engine_name); | |||||
op_desc->SetOpKernelLibName(kernel_name); | |||||
GELOGI("CheckEngineType:Set OpKernelLibName %s and engine name %s into op_desc %s", kernel_name.c_str(), | |||||
op_engine_name.c_str(), op_desc->GetName().c_str()); | |||||
return SUCCESS; | |||||
} else { | |||||
GELOGE(FAILED, "CheckEngineType: check support failed, Op type %s of ops kernel %s is unsupported, reason:%s", | |||||
op_desc->GetType().c_str(), kernel_name.c_str(), unsupported_reason.c_str()); | |||||
return FAILED; | |||||
} | |||||
} else { | |||||
GELOGE(FAILED, | |||||
"CheckEngineType:Can not find any supported ops kernel info store by kernel_name %s," | |||||
"op type is %s, op name is %s", | |||||
kernel_name.c_str(), op_desc->GetType().c_str(), op_desc->GetName().c_str()); | |||||
} | |||||
return FAILED; | |||||
} | |||||
static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, | static Status AddInputs(const ComputeGraphPtr &graph, const NodePtr &node, const GeTensorDesc &tensor, int32_t index, | ||||
bool attr) { | bool attr) { | ||||
GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(graph, return PARAM_INVALID); | ||||
@@ -96,7 +167,7 @@ static Status AddOutputs(const ComputeGraphPtr &graph, const NodePtr &node, cons | |||||
} | } | ||||
static void GetOpsProtoPath(string &opsproto_path) { | static void GetOpsProtoPath(string &opsproto_path) { | ||||
GELOGI("Start to get ops proto path schedule"); | |||||
GELOGI("Start to get ops proto path schedule."); | |||||
const char *path_env = std::getenv("ASCEND_OPP_PATH"); | const char *path_env = std::getenv("ASCEND_OPP_PATH"); | ||||
if (path_env != nullptr) { | if (path_env != nullptr) { | ||||
string path = path_env; | string path = path_env; | ||||
@@ -105,7 +176,7 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||||
GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | ||||
return; | return; | ||||
} | } | ||||
opsproto_path = (path + "/op_proto/built-in/" + ":") + (path + "/op_proto/custom/"); | |||||
opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||||
GELOGI("Get opsproto so path from env : %s", path.c_str()); | GELOGI("Get opsproto so path from env : %s", path.c_str()); | ||||
return; | return; | ||||
} | } | ||||
@@ -113,15 +184,14 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||||
GELOGI("path_base is %s", path_base.c_str()); | GELOGI("path_base is %s", path_base.c_str()); | ||||
path_base = path_base.substr(0, path_base.rfind('/')); | path_base = path_base.substr(0, path_base.rfind('/')); | ||||
path_base = path_base.substr(0, path_base.rfind('/') + 1); | path_base = path_base.substr(0, path_base.rfind('/') + 1); | ||||
opsproto_path = (path_base + "ops/op_proto/built-in/" + ":") + (path_base + "ops/op_proto/custom/"); | |||||
opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||||
} | } | ||||
class GeGenerator::Impl { | class GeGenerator::Impl { | ||||
public: | public: | ||||
Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GraphId &graph_id, | |||||
vector<GeModelPtr> &ge_models); | |||||
Status BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GraphId &graph_id, GeRootModelPtr &ge_models); | |||||
Status SaveModel(const string &file_name_prefix, vector<GeModelPtr> &models, ModelBufferData &model); | |||||
Status SaveModel(const string &file_name_prefix, GeModelPtr &models, ModelBufferData &model); | |||||
Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | Status SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | ||||
const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | const vector<GeTensor> &inputs, const vector<GeTensor> &outputs); | ||||
@@ -141,7 +211,7 @@ Status GeGenerator::Initialize(const map<string, string> &options) { | |||||
} | } | ||||
string opsproto_path; | string opsproto_path; | ||||
GetOpsProtoPath(opsproto_path); | GetOpsProtoPath(opsproto_path); | ||||
GELOGI("opsproto_path is %s", opsproto_path.c_str()); | |||||
GELOGI("Get opsproto path is %s", opsproto_path.c_str()); | |||||
OpsProtoManager *manager = OpsProtoManager::Instance(); | OpsProtoManager *manager = OpsProtoManager::Instance(); | ||||
map<string, string> option_tmp; | map<string, string> option_tmp; | ||||
option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | ||||
@@ -149,7 +219,7 @@ Status GeGenerator::Initialize(const map<string, string> &options) { | |||||
Status ret = impl_->graph_manager_.Initialize(options); | Status ret = impl_->graph_manager_.Initialize(options); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed"); | |||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED, "Graph manager initialize failed."); | |||||
return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_INIT_FAILED; | ||||
} | } | ||||
// get ek file | // get ek file | ||||
@@ -179,7 +249,7 @@ Status GeGenerator::Finalize() { | |||||
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | ||||
Status ret = impl_->graph_manager_.Finalize(); | Status ret = impl_->graph_manager_.Finalize(); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed"); | |||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED, "Graph manager finalize failed."); | |||||
return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_FINALIZE_FAILED; | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -187,7 +257,7 @@ Status GeGenerator::Finalize() { | |||||
Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix, | Status GeGenerator::GenerateOfflineModel(const Graph &graph, const string &file_name_prefix, | ||||
const vector<GeTensor> &inputs) { | const vector<GeTensor> &inputs) { | ||||
GELOGI("Start to GenerateOfflineModel."); | |||||
GELOGI("Start to generate offline model."); | |||||
ModelBufferData model; | ModelBufferData model; | ||||
return GenerateModel(graph, file_name_prefix, inputs, model, true); | return GenerateModel(graph, file_name_prefix, inputs, model, true); | ||||
} | } | ||||
@@ -208,25 +278,22 @@ Status GeGenerator::GenerateInfershapeGraph(const Graph &graph) { | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
GELOGI("GenerateInfershapeJson success."); | |||||
GELOGI("GenerateInfershapeGraph success."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_prefix, const vector<GeTensor> &inputs, | ||||
ModelBufferData &model, bool is_offline) { | ModelBufferData &model, bool is_offline) { | ||||
GraphId graph_id; | GraphId graph_id; | ||||
vector<GeModelPtr> ge_models; | |||||
GeRootModelPtr ge_root_model = nullptr; | |||||
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | ||||
string model_name; | |||||
auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
if (compute_graph == nullptr) { | |||||
GELOGW("Get compute graph fail."); | |||||
} else { | |||||
model_name = compute_graph->GetName(); | |||||
} | |||||
// using output as model_name (ignore ".om") | |||||
int start_position = file_name_prefix.find_last_of('/') + 1; | |||||
int end_position = file_name_prefix.length() - 3; | |||||
const string model_name = file_name_prefix.substr(start_position, end_position - start_position); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(model_name.empty(), return PARAM_INVALID, "om name is not valid!"); | |||||
impl_->is_offline_ = is_offline; | impl_->is_offline_ = is_offline; | ||||
Status ret = impl_->BuildModel(graph, inputs, graph_id, ge_models); | |||||
Status ret = impl_->BuildModel(graph, inputs, graph_id, ge_root_model); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Build model failed"); | GELOGE(ret, "Build model failed"); | ||||
if (impl_->graph_manager_.Finalize() != SUCCESS) { | if (impl_->graph_manager_.Finalize() != SUCCESS) { | ||||
@@ -234,11 +301,14 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
} | } | ||||
return ret; | return ret; | ||||
} | } | ||||
if (!model_name.empty() && !ge_models.empty()) { | |||||
ge_models[0]->SetName(model_name); | |||||
} | |||||
ret = impl_->SaveModel(file_name_prefix, ge_models, model); | |||||
GE_CHECK_NOTNULL(ge_root_model); | |||||
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | |||||
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
GE_RETURN_WITH_LOG_IF_FALSE(ge_model != nullptr, "ge_model can not be null"); | |||||
ge_model->SetName(model_name); | |||||
ret = impl_->SaveModel(file_name_prefix, ge_model, model); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Save model failed"); | GELOGE(ret, "Save model failed"); | ||||
if (impl_->graph_manager_.Finalize() != SUCCESS) { | if (impl_->graph_manager_.Finalize() != SUCCESS) { | ||||
@@ -250,17 +320,9 @@ Status GeGenerator::GenerateModel(const Graph &graph, const string &file_name_pr | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
/** | |||||
* @ingroup ge | |||||
* @brief Compiling a single operator into an offline model | |||||
* @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file | |||||
* @param [in] vector<GeTensor> &inputs: Operator input data description information. | |||||
* @param [in] vector<GeTensor> &outputs: Operator output data description information. | |||||
* @param [in] const string &model_file_name: Offline model filename. | |||||
* @return SUCCESS handle successfully / others handle failed | |||||
*/ | |||||
Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, | |||||
const vector<GeTensor> &outputs, const string &model_file_name) { | |||||
Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | |||||
const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | |||||
bool is_offline) { | |||||
GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | ||||
if (!inputs.empty() && (inputs.size() != op_desc->GetInputsSize())) { | if (!inputs.empty() && (inputs.size() != op_desc->GetInputsSize())) { | ||||
GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size:%zu", inputs.size(), op_desc->GetInputsSize()); | GELOGE(PARAM_INVALID, "Tensor size: %zu, Inputs size:%zu", inputs.size(), op_desc->GetInputsSize()); | ||||
@@ -275,7 +337,16 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor | |||||
OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); | OpDescPtr op_desc_tmp = AttrUtils::CloneOpDesc(op_desc); | ||||
GE_CHECK_NOTNULL(op_desc_tmp); | GE_CHECK_NOTNULL(op_desc_tmp); | ||||
// 1. Create ComputeGraph. | |||||
// 1. check engine type when compile online | |||||
if (model_file_name == kFileNameSuffix) { | |||||
Status ret = CheckEngineTypeSupport(op_desc, engine_type); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "check engine type failed."); | |||||
return ret; | |||||
} | |||||
} | |||||
// 2. Create ComputeGraph. | |||||
string name = ge::CurrentTimeInStr() + "_" + model_file_name; | string name = ge::CurrentTimeInStr() + "_" + model_file_name; | ||||
ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name); | ge::ComputeGraphPtr compute_graph = MakeShared<ComputeGraph>(name); | ||||
if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
@@ -283,9 +354,11 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor | |||||
} | } | ||||
GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); | GE_CHECK_NOTNULL_EXEC(compute_graph, return INTERNAL_ERROR); | ||||
// 2. Add Node to ComputeGraph. | |||||
// 3. Add Node to ComputeGraph. | |||||
NodePtr op_node = compute_graph->AddNode(op_desc); | NodePtr op_node = compute_graph->AddNode(op_desc); | ||||
GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); | GE_CHECK_NOTNULL_EXEC(op_node, return INTERNAL_ERROR); | ||||
// 4. Create InputData node. | |||||
int32_t arg_index = 0; | int32_t arg_index = 0; | ||||
if (inputs.empty()) { | if (inputs.empty()) { | ||||
for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | for (const auto &input_desc : op_desc->GetAllInputsDescPtr()) { | ||||
@@ -301,7 +374,7 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor | |||||
} | } | ||||
} | } | ||||
// 4. Create Output node. | |||||
// 5. Create Output node. | |||||
if (!outputs.empty()) { | if (!outputs.empty()) { | ||||
GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs)); | GE_CHK_STATUS_RET_NOLOG(AddOutputs(compute_graph, op_node, outputs)); | ||||
} | } | ||||
@@ -312,41 +385,69 @@ Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor | |||||
GELOGI("ATC parser success in single op schedule."); | GELOGI("ATC parser success in single op schedule."); | ||||
GraphId graph_id; | GraphId graph_id; | ||||
vector<GeModelPtr> ge_models; | |||||
GeRootModelPtr ge_root_model = nullptr; | |||||
GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(impl_, return PARAM_INVALID); | ||||
GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, graph_id, ge_models)); | |||||
impl_->is_offline_ = is_offline; | |||||
GE_CHK_STATUS_RET_NOLOG(impl_->BuildModel(graph, inputs, graph_id, ge_root_model)); | |||||
map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs(); | |||||
GE_CHECK_NOTNULL(ge_root_model); | |||||
GE_CHECK_NOTNULL(ge_root_model->GetRootGraph()); | |||||
map<string, GeModelPtr> name_to_ge_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
GeModelPtr &ge_model = name_to_ge_model[ge_root_model->GetRootGraph()->GetName()]; | |||||
GELOGD("The opType in op_desc_tmp is: %s", op_desc_tmp->GetType().c_str()); | |||||
GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_model, op_desc_tmp->GetType(), op_attrs, inputs, outputs)); | |||||
GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_model, model_buff)); | |||||
return SUCCESS; | |||||
} | |||||
if (!ge_models.empty()) { | |||||
map<string, GeAttrValue> op_attrs = op_desc_tmp->GetAllAttrs(); | |||||
GELOGI("The opType in op_desc_tmp is: %s", op_desc_tmp->GetType().c_str()); | |||||
GE_CHK_STATUS_RET_NOLOG(impl_->SaveParams(ge_models[0], op_desc_tmp->GetType(), op_attrs, inputs, outputs)); | |||||
} | |||||
/** | |||||
* @ingroup ge | |||||
* @brief Compiling a single operator into an offline model | |||||
* @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file | |||||
* @param [in] vector<GeTensor> &inputs: Operator input data description information. | |||||
* @param [in] vector<GeTensor> &outputs: Operator output data description information. | |||||
* @param [in] const string &model_file_name: Offline model filename. | |||||
* @return SUCCESS handle successfully / others handle failed | |||||
*/ | |||||
Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, | |||||
const vector<GeTensor> &outputs, const string &model_file_name) { | |||||
GELOGI("Start to Build Single Op Offline Model."); | |||||
ModelBufferData model_buff; | ModelBufferData model_buff; | ||||
GE_CHK_STATUS_RET_NOLOG(impl_->SaveModel(model_file_name, ge_models, model_buff)); | |||||
return SUCCESS; | |||||
OpEngineType engine_type = ENGINE_SYS; | |||||
return BuildSingleOp(op_desc, inputs, outputs, model_file_name, engine_type, model_buff, true); | |||||
} | |||||
/** | |||||
* @ingroup ge | |||||
* @brief Compiling a single operator into online buffer | |||||
* @param [in] OpDescPtr &op_desc: Operator description info that needs to be compiled into an offline model file | |||||
* @param [in] vector<GeTensor> &inputs: Operator input data description information. | |||||
* @param [in] vector<GeTensor> &outputs: Operator output data description information. | |||||
* @param [in] engine_type: specific engine. | |||||
* @param [out] ModelBufferData &Model_buff: Model_buff: model buffer of the op. | |||||
* @return SUCCESS handle successfully / others handle failed | |||||
*/ | |||||
Status GeGenerator::BuildSingleOpModel(OpDescPtr &op_desc, const vector<GeTensor> &inputs, | |||||
const vector<GeTensor> &outputs, OpEngineType engine_type, | |||||
ModelBufferData &model_buff) { | |||||
GELOGI("Start to Build Single Op Online"); | |||||
return BuildSingleOp(op_desc, inputs, outputs, kFileNameSuffix, engine_type, model_buff, false); | |||||
} | } | ||||
Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | Status GeGenerator::Impl::SaveParams(GeModelPtr &ge_model, const string &type, const map<string, GeAttrValue> &attrs, | ||||
const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) { | const vector<GeTensor> &inputs, const vector<GeTensor> &outputs) { | ||||
GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); | GE_CHECK_NOTNULL_EXEC(ge_model, return PARAM_INVALID); | ||||
GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS, | GE_CHK_BOOL_EXEC_NOLOG(graph_manager_.SaveParams(*ge_model, type, attrs, inputs, outputs) == SUCCESS, | ||||
graph_manager_.Finalize(); | |||||
(void)graph_manager_.Finalize(); | |||||
return FAILED); | return FAILED); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, vector<GeModelPtr> &models, | |||||
ModelBufferData &model_buff) { | |||||
// to be change to ModelHelper interface | |||||
if (models.empty()) { | |||||
GELOGE(FAILED, "models are empty."); | |||||
return FAILED; | |||||
} | |||||
Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, GeModelPtr &model, ModelBufferData &model_buff) { | |||||
ModelHelper model_helper; | ModelHelper model_helper; | ||||
model_helper.SetSaveMode(is_offline_); | model_helper.SetSaveMode(is_offline_); | ||||
Status ret = model_helper.SaveToOmModel(models[0], save_param_, file_name_prefix, model_buff); | |||||
Status ret = model_helper.SaveToOmModel(model, save_param_, file_name_prefix, model_buff); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Save to Om model failed"); | GELOGE(ret, "Save to Om model failed"); | ||||
return ret; | return ret; | ||||
@@ -355,20 +456,21 @@ Status GeGenerator::Impl::SaveModel(const string &file_name_prefix, vector<GeMod | |||||
} | } | ||||
Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GraphId &graph_id, | Status GeGenerator::Impl::BuildModel(const Graph &graph, const vector<GeTensor> &inputs, GraphId &graph_id, | ||||
vector<GeModelPtr> &ge_models) { | |||||
GeRootModelPtr &ge_root_model) { | |||||
static GraphId id = 0; | static GraphId id = 0; | ||||
const std::map<std::string, std::string> options; | const std::map<std::string, std::string> options; | ||||
Status ret = graph_manager_.AddGraph(id, graph, options); | Status ret = graph_manager_.AddGraph(id, graph, options); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager AddGraph failed, id: %u", id); | |||||
graph_manager_.Finalize(); | |||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "GraphManager add graph failed, id: %u", id); | |||||
(void)graph_manager_.Finalize(); | |||||
return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | ||||
} | } | ||||
GELOGI("models' inputs.size()=%zu", inputs.size()); | |||||
ret = graph_manager_.BuildGraph(id, inputs, ge_models); | |||||
GELOGI("models inputs.size()=%zu", inputs.size()); | |||||
graph_manager_.SetOptionsRunGraphFlag(false); | |||||
ret = graph_manager_.BuildGraph(id, inputs, ge_root_model); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "graphManager BuildGraph failed, id: %u", id); | |||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager build graph failed, id: %u", id); | |||||
return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | ||||
} | } | ||||
@@ -383,14 +485,14 @@ Status GeGenerator::Impl::GenerateInfershapeGraph(const Graph &graph, GraphId &g | |||||
const std::map<std::string, std::string> options; | const std::map<std::string, std::string> options; | ||||
Status ret = graph_manager_.AddGraph(id, graph, options); | Status ret = graph_manager_.AddGraph(id, graph, options); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager AddGraph failed, id: %u", id); | |||||
graph_manager_.Finalize(); | |||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED, "graphManager add graph failed, id: %u", id); | |||||
(void)graph_manager_.Finalize(); | |||||
return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_ADD_GRAPH_FAILED; | ||||
} | } | ||||
ret = graph_manager_.GenerateInfershapeGraph(id); | ret = graph_manager_.GenerateInfershapeGraph(id); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "graphManager BuildGraph failed, id: %u", id); | |||||
GELOGE(GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED, "GraphManager BuildGraph failed, id: %u", id); | |||||
return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | return GE_GENERATOR_GRAPH_MANAGER_BUILD_GRAPH_FAILED; | ||||
} | } | ||||
@@ -53,6 +53,7 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GraphBuilder: GE is not initialized"); | GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GraphBuilder: GE is not initialized"); | ||||
return GE_CLI_GE_NOT_INITIALIZED; | return GE_CLI_GE_NOT_INITIALIZED; | ||||
} | } | ||||
for (const auto &node_ptr : graph->GetAllNodes()) { | for (const auto &node_ptr : graph->GetAllNodes()) { | ||||
GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); | GE_CHECK_NOTNULL(node_ptr->GetOpDesc()); | ||||
std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); | std::string kernel_lib_name = node_ptr->GetOpDesc()->GetOpKernelLibName(); | ||||
@@ -84,76 +85,229 @@ Status GraphBuilder::CalcOpParam(const ge::ComputeGraphPtr &graph) { | |||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
} | } | ||||
auto parent_node = graph->GetParentNode(); | |||||
if (parent_node == nullptr) { | |||||
GELOGI("Graph[%s] do not have parent node, no need update parent node output size.", graph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
GE_CHK_STATUS_RET(UpdateParentNodeOutputSize(graph, parent_node)); | |||||
GELOGI("Success to calculate op running param."); | GELOGI("Success to calculate op running param."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphBuilder::UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph, ge::NodePtr &parent_node_ptr) { | |||||
GELOGI("Begin to update parent node[%s] of graph[%s] output size.", parent_node_ptr->GetName().c_str(), | |||||
graph->GetName().c_str()); | |||||
auto parent_op_desc = parent_node_ptr->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(parent_op_desc); | |||||
bool is_unknown_shape = false; | |||||
if (!AttrUtils::GetBool(parent_op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { | |||||
GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", parent_op_desc->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (is_unknown_shape) { | |||||
GELOGI("Current graph[%s] is unknown, no need to update parent node[%s] output size.", graph->GetName().c_str(), | |||||
parent_node_ptr->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
for (const auto &node_ptr : graph->GetDirectNode()) { | |||||
if (node_ptr->GetType() != NETOUTPUT) { | |||||
continue; | |||||
} | |||||
auto op_desc = node_ptr->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
for (const auto &in_data_anchor : node_ptr->GetAllInDataAnchors()) { | |||||
auto index = in_data_anchor->GetIdx(); | |||||
ge::GeTensorDesc desc_temp = op_desc->GetInputDesc(index); | |||||
int64_t size = 0; | |||||
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(desc_temp, size) != SUCCESS, GELOGI("Get size failed!")); | |||||
uint32_t parent_index = 0; | |||||
if (!AttrUtils::GetInt(desc_temp, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | |||||
GELOGE(INTERNAL_ERROR, "NetOutput input tensor %d, attr %s not found.", index, | |||||
ATTR_NAME_PARENT_NODE_INDEX.c_str()); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
ge::GeTensorDesc parent_desc_temp = parent_op_desc->GetOutputDesc(parent_index); | |||||
ge::TensorUtils::SetSize(parent_desc_temp, size); | |||||
GE_CHK_STATUS_RET(parent_op_desc->UpdateOutputDesc(parent_index, parent_desc_temp)); | |||||
GELOGI("Update parent node[%s] output index[%u] to size[%ld].", parent_node_ptr->GetName().c_str(), parent_index, | |||||
size); | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | Status GraphBuilder::Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | ||||
GeModelPtr &ge_model_ptr, uint64_t session_id) { | |||||
GeRootModelPtr &ge_root_model_ptr, uint64_t session_id) { | |||||
GELOGI("Start to build model."); | GELOGI("Start to build model."); | ||||
if (comp_graph == nullptr) { | if (comp_graph == nullptr) { | ||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "Graph build comp_graph is null."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "Graph build comp_graph is null."); | ||||
return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
} | } | ||||
ge_root_model_ptr = MakeShared<ge::GeRootModel>(comp_graph); | |||||
if (ge_root_model_ptr == nullptr) { | |||||
return MEMALLOC_FAILED; | |||||
} | |||||
GeModelPtr ge_model_ptr = nullptr; | |||||
bool is_dynamic_shape = false; | |||||
// To be compatible with the old process, do not verify the return value temporarily. | |||||
(void)AttrUtils::GetBool(comp_graph, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, is_dynamic_shape); | |||||
if (is_dynamic_shape) { | |||||
GE_CHK_STATUS_RET( | |||||
BuildForDynamicShapeGraph(comp_graph, subgraph_ptr_list, ge_root_model_ptr, ge_model_ptr, session_id), | |||||
"Build for dynamic shape graph failed."); | |||||
return SUCCESS; | |||||
} | |||||
GE_CHK_STATUS_RET(BuildForKnownShapeGraph(comp_graph, subgraph_ptr_list, ge_model_ptr, session_id), | |||||
"Build for known shape graph failed."); | |||||
ge_root_model_ptr->SetSubgraphInstanceNameToModel(comp_graph->GetName(), ge_model_ptr); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphBuilder::BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, | |||||
std::vector<SubGraphInfoPtr> &subgraph_ptr_list, GeModelPtr &ge_model_ptr, | |||||
uint64_t session_id) { | |||||
GELOGI("Begin to build known shape graph[%s].", comp_graph->GetName().c_str()); | |||||
Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | Status ret = SecondPartition(comp_graph, subgraph_ptr_list); | ||||
GE_CHK_STATUS_RET(ret, "Graph second partition Failed."); | |||||
GE_CHK_STATUS_RET(ret, "Graph[%s] second partition Failed.", comp_graph->GetName().c_str()); | |||||
auto subgraph_map = graph_partitioner_.GetSubGraphMap(); | auto subgraph_map = graph_partitioner_.GetSubGraphMap(); | ||||
GE_TIMESTAMP_START(BuildSubgraph); | GE_TIMESTAMP_START(BuildSubgraph); | ||||
ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); | ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); | ||||
GELOGI("[Build] invoke the other opskernel to generate task."); | |||||
GraphUtils::DumpGEGraph(comp_graph, "BeforePreBuildModel"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "BeforePreBuildModel"); | |||||
GE_DUMP(comp_graph, "BeforePreBuildModel"); | |||||
GE_TIMESTAMP_START(PreBuildModel); | GE_TIMESTAMP_START(PreBuildModel); | ||||
GE_CHK_STATUS_RET(builder.PreBuildModel(), "Builder PreBuildModel() return fail."); | |||||
GE_CHK_STATUS_RET(builder.PreBuildModel(), "Graph[%s] builder PreBuildModel() return fail.", | |||||
comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(PreBuildModel, "GraphBuilder::PreBuildModel"); | GE_TIMESTAMP_END(PreBuildModel, "GraphBuilder::PreBuildModel"); | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterPrebuildmodel"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterPrebuildmodel"); | |||||
GE_DUMP(comp_graph, "AfterPreBuildModel"); | |||||
GE_TIMESTAMP_START(CalcOpParam); | GE_TIMESTAMP_START(CalcOpParam); | ||||
GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Builder CalcOpParam() return fail."); | |||||
GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Graph[%s] builder CalcOpParam() return fail.", | |||||
comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); | GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterCalcOpParam"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterCalcOpParam"); | |||||
GE_DUMP(comp_graph, "AfterCalcOpParam"); | |||||
ModelPtr model_ptr = MakeShared<ge::Model>(); | ModelPtr model_ptr = MakeShared<ge::Model>(); | ||||
if (model_ptr == nullptr) { | if (model_ptr == nullptr) { | ||||
return MEMALLOC_FAILED; | return MEMALLOC_FAILED; | ||||
} | } | ||||
GE_TIMESTAMP_START(BuildModelForGetTask); | GE_TIMESTAMP_START(BuildModelForGetTask); | ||||
GE_CHK_STATUS_RET(builder.BuildModelForGetTask(*model_ptr), "Builder BuildModelForGetTask() return fail."); | |||||
GE_CHK_STATUS_RET(builder.BuildModelForGetTask(*model_ptr), "Graph[%s] builder BuildModelForGetTask() return fail.", | |||||
comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(BuildModelForGetTask, "GraphBuilder::BuildModelForGetTask"); | GE_TIMESTAMP_END(BuildModelForGetTask, "GraphBuilder::BuildModelForGetTask"); | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterBuildModel"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterBuildModel"); | |||||
GE_DUMP(comp_graph, "AfterBuildModel"); | |||||
GE_TIMESTAMP_START(GetTaskInfo); | GE_TIMESTAMP_START(GetTaskInfo); | ||||
ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); | ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); | ||||
GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterGetTask"); | |||||
GE_DUMP(comp_graph, "AfterGetTask"); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Builder GetTaskInfo() return fail."); | |||||
GELOGE(ret, "Graph[%s] builder GetTaskInfo() return fail.", comp_graph->GetName().c_str()); | |||||
return ret; | return ret; | ||||
} | } | ||||
for (auto graph : comp_graph->GetAllSubgraphs()) { | |||||
GraphUtils::DumpGEGraphToOnnx(*graph, "SubgraphGetTask"); | |||||
ge_model_ptr = MakeShared<ge::GeModel>(); | |||||
if (ge_model_ptr == nullptr) { | |||||
return MEMALLOC_FAILED; | |||||
} | } | ||||
GE_CHK_STATUS_RET(builder.SaveDataToModel(*model_ptr, *ge_model_ptr), | |||||
"Graph[%s] builder SaveDataToModel() return fail.", comp_graph->GetName().c_str()); | |||||
GELOGI("Success to build graph[%s] model.", comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(BuildSubgraph, "GraphBuilder::Build"); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphBuilder::BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, | |||||
uint64_t session_id) { | |||||
GELOGI("Begin to build unknown shape graph[%s].", comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_START(CalcOpParam); | |||||
GE_CHK_STATUS_RET(CalcOpParam(comp_graph), "Graph[%s] builder CalcOpParam() return fail.", | |||||
comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(CalcOpParam, "GraphBuilder::CalcOpParam"); | |||||
GE_DUMP(comp_graph, "AfterCalcOpParam"); | |||||
Graph2SubGraphInfoList subgraph_map; | |||||
ge::ModelBuilder builder(comp_graph, subgraph_map, stream_max_parallel_num_, hcom_parallel_, build_mode_); | |||||
ModelPtr model_ptr = MakeShared<ge::Model>(); | |||||
if (model_ptr == nullptr) { | |||||
return MEMALLOC_FAILED; | |||||
} | |||||
GE_TIMESTAMP_START(BuildModelForGetDynShapeTask); | |||||
GE_CHK_STATUS_RET(builder.BuildModelForGetDynShapeTask(*model_ptr), | |||||
"Graph[%s] builder BuildModelForGetDynShapeTask() return fail.", comp_graph->GetName().c_str()); | |||||
GE_TIMESTAMP_END(BuildModelForGetDynShapeTask, "GraphBuilder::BuildModelForGetDynShapeTask"); | |||||
GE_TIMESTAMP_START(GetTaskInfo); | |||||
Status ret = GetTaskInfo(builder, model_ptr, comp_graph, subgraph_map, session_id); | |||||
GE_TIMESTAMP_END(GetTaskInfo, "GraphBuilder::GetTaskInfo"); | |||||
GraphUtils::DumpGEGraph(comp_graph, "AfterGetTask"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterGetTask"); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Graph[%s] builder GetTaskInfo() return fail.", comp_graph->GetName().c_str()); | |||||
return ret; | |||||
} | |||||
ge_model_ptr = MakeShared<ge::GeModel>(); | ge_model_ptr = MakeShared<ge::GeModel>(); | ||||
if (ge_model_ptr == nullptr) { | if (ge_model_ptr == nullptr) { | ||||
return MEMALLOC_FAILED; | return MEMALLOC_FAILED; | ||||
} | } | ||||
GE_CHK_STATUS_RET(builder.SaveDataToModel(*model_ptr, *ge_model_ptr), "model builder SaveDataToModel() return fail."); | |||||
GELOGI("Success to build model."); | |||||
GE_TIMESTAMP_END(BuildSubgraph, "GraphBuilder::Build"); | |||||
GE_CHK_STATUS_RET(builder.SaveDataToModel(*model_ptr, *ge_model_ptr), | |||||
"Graph[%s] builder SaveDataToModel() return fail.", comp_graph->GetName().c_str()); | |||||
GELOGI("Success to build graph[%s] model.", comp_graph->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status GraphBuilder::BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, | |||||
std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, | |||||
uint64_t session_id) { | |||||
GELOGI("Start to build BuildForDynamicShape for dynamic shape."); | |||||
for (const auto &node : comp_graph->GetDirectNode()) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (node->GetType() == DATA) { | |||||
GE_CHK_STATUS_RET(CalcDynShapeRootGraphDataSize(op_desc), "Calc dynamic shape root graph data[%s] size failed.", | |||||
op_desc->GetName().c_str()); | |||||
} | |||||
// ATTR_NAME_IS_UNKNOWN_SHAPE is set on "graph partion" stage, but afer fusion , the graph may | |||||
// be changed so here need to renew. For example , the scene followed: | |||||
// (known)partioncall(known) (known)partioncall(known) | |||||
// After fusion | |||||
// | --> | |||||
// (known)Unique(unknown)--->(unknow)Shape(unknown) (known)FuncDef(known) | |||||
// if scene like this , it should be process as known shape graph | |||||
bool is_unknown_shape = false; | |||||
GE_CHK_STATUS_RET(ge::NodeUtils::GetNodeUnknownShapeStatus(*node, is_unknown_shape), | |||||
"Get node[%s] shape status failed!", node->GetName().c_str()); | |||||
if (!is_unknown_shape) { | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape), return FAILED, | |||||
"Renew node [%s] attr[%s] failed!", node->GetName().c_str(), ATTR_NAME_IS_UNKNOWN_SHAPE.c_str()); | |||||
GELOGD("renew node [%s] attr[%s] success! value is %d", node->GetName().c_str(), | |||||
ATTR_NAME_IS_UNKNOWN_SHAPE.c_str(), is_unknown_shape); | |||||
} | |||||
vector<string> subgraph_names = op_desc->GetSubgraphInstanceNames(); | |||||
for (auto subgraph_name : subgraph_names) { | |||||
ComputeGraphPtr subgraph = comp_graph->GetSubgraph(subgraph_name); | |||||
bool is_unknown_shape = false; | |||||
if (!AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape)) { | |||||
GELOGE(PARAM_INVALID, "Get op %s unknown shape attr failed.", op_desc->GetName().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (is_unknown_shape) { | |||||
// unknown shape build flow | |||||
GE_CHK_STATUS_RET(BuildForUnknownShapeGraph(subgraph, ge_model_ptr, session_id), | |||||
"Build for unknown shape graph failed."); | |||||
} else { | |||||
// known shape build flow | |||||
GE_CHK_STATUS_RET(BuildForKnownShapeGraph(subgraph, subgraph_ptr_list, ge_model_ptr, session_id), | |||||
"Build for known shape graph failed."); | |||||
} | |||||
ge_root_model_ptr->SetSubgraphInstanceNameToModel(subgraph_name, ge_model_ptr); | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -199,10 +353,7 @@ Status GraphBuilder::GetTaskInfo(const ge::ModelBuilder &builder, const ModelPtr | |||||
GELOGE(ret, "Optimize streamed subGraph fail."); | GELOGE(ret, "Optimize streamed subGraph fail."); | ||||
return ret; | return ret; | ||||
} | } | ||||
GraphUtils::DumpGEGraph(comp_graph, "AfterOptimizeStreamedSubGraph"); | |||||
GraphUtils::DumpGEGraphToOnnx(*comp_graph, "AfterOptimizeStreamedSubGraph"); | |||||
GE_DUMP(comp_graph, "AfterOptimizeStreamedSubGraph"); | |||||
auto *get_var_mem_base = | auto *get_var_mem_base = | ||||
reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(ge::VarManager::Instance(0)->GetVarMemLogicBase())); | reinterpret_cast<uint8_t *>(reinterpret_cast<uintptr_t>(ge::VarManager::Instance(0)->GetVarMemLogicBase())); | ||||
uint64_t var_size = (ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) > 0) | uint64_t var_size = (ge::VarManager::Instance(session_id)->GetVarMemSize(RT_MEMORY_HBM) > 0) | ||||
@@ -289,6 +440,36 @@ Status GraphBuilder::UpdateDataInputSize(const ge::NodePtr &node_ptr) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphBuilder::CalcDynShapeRootGraphDataSize(const ge::OpDescPtr &op_desc) { | |||||
GELOGI("Begin to calc dynamic shape graph data[%s] size.", op_desc->GetName().c_str()); | |||||
// data op only has one output anchor | |||||
ge::GeTensorDesc output_desc = op_desc->GetOutputDesc(0); | |||||
int64_t output_size = 0; | |||||
if (ge::TensorUtils::GetSize(output_desc, output_size) != SUCCESS) { | |||||
GELOGW("Get size failed!"); | |||||
} | |||||
if (output_size > 0) { | |||||
GELOGI("No need to update dynamic shape graph data output size[%ld].", output_size); | |||||
return SUCCESS; | |||||
} else { | |||||
int64_t real_dim_size = 0; | |||||
ge::graphStatus graph_status = TensorUtils::GetTensorSizeInBytes(output_desc, real_dim_size); | |||||
if (graph_status != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Get tensor size in bytes failed."); | |||||
return FAILED; | |||||
} | |||||
ge::TensorUtils::SetSize(output_desc, real_dim_size); | |||||
GELOGI("Update dynamic shape graph data output size to [%ld].", real_dim_size); | |||||
if (op_desc->UpdateOutputDesc(0, output_desc) != GRAPH_SUCCESS) { | |||||
GELOGE(FAILED, "Update dynamic shape graph data output desc size failed."); | |||||
return FAILED; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status GraphBuilder::SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list) { | Status GraphBuilder::SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list) { | ||||
GELOGI("[SecondPartition] second partition."); | GELOGI("[SecondPartition] second partition."); | ||||
GE_TIMESTAMP_START(GraphPartition2); | GE_TIMESTAMP_START(GraphPartition2); | ||||
@@ -38,6 +38,7 @@ | |||||
#include "graph/partition/graph_partition.h" | #include "graph/partition/graph_partition.h" | ||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "model/ge_root_model.h" | |||||
namespace ge { | namespace ge { | ||||
class GraphBuilder { | class GraphBuilder { | ||||
@@ -46,8 +47,8 @@ class GraphBuilder { | |||||
GraphBuilder(const GraphBuilder &in) = delete; | GraphBuilder(const GraphBuilder &in) = delete; | ||||
GraphBuilder &operator=(const GraphBuilder &in) = delete; | GraphBuilder &operator=(const GraphBuilder &in) = delete; | ||||
virtual ~GraphBuilder() = default; | virtual ~GraphBuilder() = default; | ||||
Status Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, GeModelPtr &ge_model_ptr, | |||||
uint64_t session_id = INVALID_SESSION_ID); | |||||
Status Build(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
GeRootModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); | |||||
void SetOptions(const GraphManagerOptions &options); | void SetOptions(const GraphManagerOptions &options); | ||||
private: | private: | ||||
@@ -56,8 +57,16 @@ class GraphBuilder { | |||||
Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); | Graph2SubGraphInfoList &subgraph_map, uint64_t session_id = INVALID_SESSION_ID); | ||||
Status SetInputSize(const ge::NodePtr &node_ptr); | Status SetInputSize(const ge::NodePtr &node_ptr); | ||||
Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | Status UpdateDataInputSize(const ge::NodePtr &node_ptr); | ||||
Status UpdateParentNodeOutputSize(const ge::ComputeGraphPtr &graph, ge::NodePtr &parent_node_ptr); | |||||
Status CalcDynShapeRootGraphDataSize(const ge::OpDescPtr &op_desc); | |||||
Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | Status SecondPartition(ge::ComputeGraphPtr &comp_graph, vector<ge::SubGraphInfoPtr> &subgraph_ptr_list); | ||||
Status BuildForDynamicShapeGraph(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
GeRootModelPtr &ge_root_model_ptr, GeModelPtr &ge_model_ptr, | |||||
uint64_t session_id = INVALID_SESSION_ID); | |||||
Status BuildForKnownShapeGraph(ComputeGraphPtr &comp_graph, std::vector<SubGraphInfoPtr> &subgraph_ptr_list, | |||||
GeModelPtr &ge_model_ptr, uint64_t session_id = INVALID_SESSION_ID); | |||||
Status BuildForUnknownShapeGraph(ComputeGraphPtr &comp_graph, GeModelPtr &ge_model_ptr, | |||||
uint64_t session_id = INVALID_SESSION_ID); | |||||
int build_mode_; | int build_mode_; | ||||
std::map<std::string, int> stream_max_parallel_num_; | std::map<std::string, int> stream_max_parallel_num_; | ||||
@@ -512,13 +512,14 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt | |||||
} | } | ||||
GELOGI("AllReduceParallelPass is enabled."); | GELOGI("AllReduceParallelPass is enabled."); | ||||
GraphUtils::DumpGEGraph(graph, "BeforeAllReduceParallel"); | |||||
GE_DUMP(graph, "BeforeAllReduceParallel"); | |||||
// All successors of HcomAllReduce. | // All successors of HcomAllReduce. | ||||
set<NodePtr> all_reduce_succs; | set<NodePtr> all_reduce_succs; | ||||
for (const NodePtr &node : graph->GetDirectNode()) { | for (const NodePtr &node : graph->GetDirectNode()) { | ||||
if (node->GetType() != HCOMALLREDUCE || node->GetInDataNodes().size() <= 1) { | |||||
if ((node->GetType() != HCOMALLREDUCE && node->GetType() != HVDCALLBACKALLREDUCE) || | |||||
node->GetInDataNodes().size() <= 1) { | |||||
continue; | continue; | ||||
} | } | ||||
@@ -534,7 +535,10 @@ Status AllReduceParallelPass::Run(ComputeGraphPtr graph, const vector<SubgraphPt | |||||
string out_stream_label; | string out_stream_label; | ||||
GE_CHECK_NOTNULL(out_node->GetOpDesc()); | GE_CHECK_NOTNULL(out_node->GetOpDesc()); | ||||
(void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | (void)AttrUtils::GetStr(out_node->GetOpDesc(), ATTR_NAME_STREAM_LABEL, out_stream_label); | ||||
if (out_stream_label == reduce_stream_label) { | |||||
// normally, Allreduce do not have streamLabel. when in horovod scenario Allreduce will have streamLabel | |||||
bool isSuccessorParallel = | |||||
(out_stream_label == reduce_stream_label) || (!reduce_stream_label.empty() && out_stream_label.empty()); | |||||
if (isSuccessorParallel) { | |||||
all_reduce_succs.emplace(out_node); | all_reduce_succs.emplace(out_node); | ||||
all_out_data_nodes.emplace(out_node); | all_out_data_nodes.emplace(out_node); | ||||
} | } | ||||
@@ -54,13 +54,42 @@ using std::unordered_map; | |||||
using std::unordered_set; | using std::unordered_set; | ||||
using std::vector; | using std::vector; | ||||
void MemoryBlock::SetHeadOffset(size_t offset) { | |||||
head_offset_ = offset; | |||||
size_t child_offset = head_offset_; | |||||
for (auto block : child_blocks_) { | |||||
if (block != nullptr) { | |||||
block->SetHeadOffset(child_offset); | |||||
child_offset += block->Size(); | |||||
} | |||||
} | |||||
} | |||||
void MemoryBlock::SetTailOffset(size_t offset) { | |||||
tail_offset_ = offset; | |||||
size_t child_offset = head_offset_; | |||||
for (auto block : child_blocks_) { | |||||
if (block != nullptr) { | |||||
child_offset += block->Size(); | |||||
block->SetTailOffset(child_offset - 1); | |||||
} | |||||
} | |||||
} | |||||
void MemoryBlock::Resize() { | void MemoryBlock::Resize() { | ||||
size_t child_block_size = 0; | |||||
for (auto block : child_blocks_) { | |||||
if (block != nullptr) { | |||||
block->Resize(); | |||||
child_block_size += block->Size(); | |||||
} | |||||
} | |||||
auto iter = std::max_element(real_size_list_.begin(), real_size_list_.end()); | auto iter = std::max_element(real_size_list_.begin(), real_size_list_.end()); | ||||
if (iter == real_size_list_.end()) { | if (iter == real_size_list_.end()) { | ||||
GELOGW("real_size_list_ is empty"); | GELOGW("real_size_list_ is empty"); | ||||
return; | return; | ||||
} else { | } else { | ||||
size_t block_size = *iter; | |||||
size_t block_size = (child_block_size > *iter) ? child_block_size : *iter; | |||||
if ((block_size > 0) && (block_size % MEM_ALIGN_SIZE != 0)) { | if ((block_size > 0) && (block_size % MEM_ALIGN_SIZE != 0)) { | ||||
block_size = (block_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; | block_size = (block_size + MEM_ALIGN_SIZE - 1) / MEM_ALIGN_SIZE * MEM_ALIGN_SIZE; | ||||
} | } | ||||
@@ -102,6 +131,68 @@ bool MemoryBlock::IsSameLabel(std::string &first_batch_label) { | |||||
return all_same_label; | return all_same_label; | ||||
} | } | ||||
bool CanNotLifeReuse(MemoryBlock *block) { | |||||
if (block == nullptr || !block->reuse_mem_ || block->deleted_block_ || block->continuous_block_ || | |||||
block->GetLifeEnd() == kMaxLifeTime) { | |||||
return true; | |||||
} | |||||
return false; | |||||
} | |||||
void MemoryBlock::AddLifeReuseBlock(MemoryBlock *block) { | |||||
if (CanNotLifeReuse(this) || CanNotLifeReuse(block)) { | |||||
return; | |||||
} | |||||
MemoryBlock *parent = nullptr; | |||||
MemoryBlock *child = nullptr; | |||||
// merge small block to large block | |||||
if ((block->GetLifeBegin() > GetLifeEnd()) && (block->stream_id_ == stream_id_)) { | |||||
if ((child_offset_ + block->block_size_) <= block_size_) { | |||||
parent = this; | |||||
child = block; | |||||
} else if ((block->child_offset_ + block_size_) <= block->block_size_) { | |||||
parent = block; | |||||
child = this; | |||||
} | |||||
} | |||||
if ((parent != nullptr) && (child != nullptr) && child->child_blocks_.empty()) { | |||||
parent->child_blocks_.emplace_back(child); | |||||
parent->child_offset_ += child->block_size_; | |||||
child->deleted_block_ = true; | |||||
GELOGI( | |||||
"Add block stream id:%ld [size:%zu, life time[begin:%zu, end:%zu]] to" | |||||
" block[size:%zu, life time[begin:%zu, end:%zu]]", | |||||
stream_id_, child->block_size_, child->GetLifeBegin(), child->GetLifeEnd(), parent->block_size_, | |||||
parent->GetLifeBegin(), parent->GetLifeEnd()); | |||||
} | |||||
} | |||||
size_t MemoryBlock::GetLifeBegin() { | |||||
size_t life_time = 0; | |||||
if (!node_type_index_list_.empty()) { | |||||
if (node_type_index_list_.front().node != nullptr) { | |||||
auto node_op_desc = node_type_index_list_.front().node->GetOpDesc(); | |||||
if (node_op_desc != nullptr) { | |||||
life_time = node_op_desc->GetId(); | |||||
} | |||||
} | |||||
} | |||||
return life_time; | |||||
} | |||||
size_t MemoryBlock::GetLifeEnd() { | |||||
if (!node_type_index_list_.empty()) { | |||||
return node_type_index_list_.back().life_time_end; | |||||
} | |||||
return kMaxLifeTime; | |||||
} | |||||
void MemoryBlock::SetLifeTimeEnd(size_t time) { | |||||
if (!node_type_index_list_.empty()) { | |||||
node_type_index_list_.back().life_time_end = time; | |||||
} | |||||
} | |||||
void SetLastUsedInputMemAttr(NodePtr &node, int input_index) { | void SetLastUsedInputMemAttr(NodePtr &node, int input_index) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
return; | return; | ||||
@@ -122,6 +213,27 @@ void SetLastUsedInputMemAttr(NodePtr &node, int input_index) { | |||||
} | } | ||||
} | } | ||||
Status GetNoAlignSize(const ge::OpDesc &desc, uint32_t index, size_t &size) { | |||||
// calculate tensor real size | |||||
auto output_op_desc = desc.GetOutputDescPtr(index); | |||||
if (output_op_desc == nullptr) { | |||||
GELOGI("GetNoAlignSize failed. OpName: %s, OpType: %s, index: %d", desc.GetName().c_str(), desc.GetType().c_str(), | |||||
index); | |||||
return FAILED; | |||||
} | |||||
int64_t tensor_size = 0; | |||||
GeShape shape = output_op_desc->GetShape(); | |||||
Format format = output_op_desc->GetFormat(); | |||||
DataType data_type = output_op_desc->GetDataType(); | |||||
graphStatus graph_status = TensorUtils::CalcTensorMemSize(shape, format, data_type, tensor_size); | |||||
if (graph_status != GRAPH_SUCCESS) { | |||||
GELOGE(graph_status, "CalcTensorMemSize failed!"); | |||||
return FAILED; | |||||
} | |||||
size = static_cast<size_t>(tensor_size); | |||||
return SUCCESS; | |||||
} | |||||
string ToString(ge::NodeTypeIndex &x) { | string ToString(ge::NodeTypeIndex &x) { | ||||
stringstream ss; | stringstream ss; | ||||
ss << "[" << x.node->GetName() << "(" << x.node->GetType() << "), "; | ss << "[" << x.node->GetName() << "(" << x.node->GetType() << "), "; | ||||
@@ -150,7 +262,7 @@ string MemoryBlock::String() { | |||||
} | } | ||||
BlockMemAssigner::BlockMemAssigner(ge::ComputeGraphPtr compute_graph) | BlockMemAssigner::BlockMemAssigner(ge::ComputeGraphPtr compute_graph) | ||||
: mem_offset_(0), compute_graph_(std::move(compute_graph)) {} | |||||
: mem_offset_(0), compute_graph_(std::move(compute_graph)), life_time_(0) {} | |||||
BlockMemAssigner::~BlockMemAssigner() { | BlockMemAssigner::~BlockMemAssigner() { | ||||
for (MemoryBlock *memory_block : memory_blocks_) { | for (MemoryBlock *memory_block : memory_blocks_) { | ||||
@@ -290,8 +402,9 @@ bool CanReuseBySize(const map<string, uint64_t> &reusable_block_counts, const Me | |||||
// continuous memory case:only real_size is maximum can be reused and only one continuous memory in one block | // continuous memory case:only real_size is maximum can be reused and only one continuous memory in one block | ||||
if (continuous || reusable_block.continuous_block_) { | if (continuous || reusable_block.continuous_block_) { | ||||
auto it = std::max_element(std::begin(reusable_block.RealSizeList()), std::end(reusable_block.RealSizeList())); | |||||
if (it != std::end(reusable_block.RealSizeList())) { | |||||
auto it = | |||||
std::max_element(std::begin(reusable_block.NoAlignSizeList()), std::end(reusable_block.NoAlignSizeList())); | |||||
if (it != std::end(reusable_block.NoAlignSizeList())) { | |||||
GE_IF_BOOL_EXEC((continuous && reusable_block.continuous_block_) || (continuous && (real_size < *it)) || | GE_IF_BOOL_EXEC((continuous && reusable_block.continuous_block_) || (continuous && (real_size < *it)) || | ||||
(reusable_block.continuous_block_ && (real_size > *it)), | (reusable_block.continuous_block_ && (real_size > *it)), | ||||
GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d", | GELOGD("Conflict current block size:%zu continuous:%d, reuse block max size:%zu continuous:%d", | ||||
@@ -498,25 +611,29 @@ void BlockMemAssigner::PrintSymbolMap() { | |||||
} | } | ||||
} | } | ||||
MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, MemoryType mem_type, const NodePtr &n, | |||||
uint32_t out_index, const vector<bool> &workspace_reuse_flag, | |||||
const bool is_op_reuse_mem, const bool continuous) { | |||||
MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, | |||||
MemoryType mem_type, const NodePtr &n, uint32_t out_index, | |||||
const vector<bool> &workspace_reuse_flag, const bool is_op_reuse_mem, | |||||
const bool continuous) { | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(n == nullptr, return nullptr, "Input parameter n is null."); | ||||
auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, return nullptr); | ||||
bool is_reuse_memory = false; | |||||
string ge_disable_reuse_mem_env = "0"; | string ge_disable_reuse_mem_env = "0"; | ||||
(void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env); | (void)ge::GetContext().GetOption(kDisableReuseMemory, ge_disable_reuse_mem_env); | ||||
if (ge_disable_reuse_mem_env != "1") { | if (ge_disable_reuse_mem_env != "1") { | ||||
bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | bool reuse_mem_flag = !((workspace_reuse_flag.size() > out_index) && !workspace_reuse_flag[out_index]); | ||||
bool is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && reuse_mem_flag && is_op_reuse_mem && | |||||
(IsPreReuse(n, out_index)); | |||||
is_reuse_memory = !node_op_desc->HasAttr(kL2FusionDynamicConvergeOp) && reuse_mem_flag && is_op_reuse_mem && | |||||
(IsPreReuse(n, out_index)); | |||||
auto stream_id = node_op_desc->GetStreamId(); | auto stream_id = node_op_desc->GetStreamId(); | ||||
auto map_iter = reusable_streams_map_.find(stream_id); | auto map_iter = reusable_streams_map_.find(stream_id); | ||||
if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { | if (is_reuse_memory && map_iter != reusable_streams_map_.end()) { | ||||
for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { | for (auto it = reusable_blocks_.begin(); it != reusable_blocks_.end(); ++it) { | ||||
MemoryBlock *reusable_block = *it; | MemoryBlock *reusable_block = *it; | ||||
if (!IsPostReuse(reusable_block)) { | if (!IsPostReuse(reusable_block)) { | ||||
reusable_block->reuse_mem_ = false; | |||||
GELOGI("Unreusable block."); | |||||
continue; | continue; | ||||
} | } | ||||
@@ -526,7 +643,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
CanReuseByStream(map_iter->second, *reusable_block)) { | CanReuseByStream(map_iter->second, *reusable_block)) { | ||||
GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, | GELOGD("Cross stream mem reuse, target stream:%ld, current stream:%ld", reusable_block->stream_id_, | ||||
stream_id); | stream_id); | ||||
reusable_block->AddNodeTypeIndex({n, mem_type, out_index}, real_size); | |||||
reusable_block->AddNodeTypeIndex({n, mem_type, out_index}, real_size, no_align_size); | |||||
if (mem_type == kOutput) { | if (mem_type == kOutput) { | ||||
auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); | auto iter = anchor_to_symbol_.find(NodeIndexIO(n, out_index, kOut).ToString()); | ||||
if (iter != anchor_to_symbol_.end()) { | if (iter != anchor_to_symbol_.end()) { | ||||
@@ -543,7 +660,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
} | } | ||||
} | } | ||||
auto block = new (std::nothrow) MemoryBlock(block_size, is_op_reuse_mem); | |||||
auto block = new (std::nothrow) MemoryBlock(block_size, is_reuse_memory); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "new an object failed."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "new an object failed."); | ||||
// Data and netoutput need zero copy block | // Data and netoutput need zero copy block | ||||
@@ -551,7 +668,7 @@ MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, | |||||
block->is_zero_copy_ = true; | block->is_zero_copy_ = true; | ||||
} | } | ||||
block->Init(real_size, mem_type, n, out_index); | |||||
block->Init(real_size, mem_type, n, out_index, no_align_size); | |||||
block->stream_id_ = node_op_desc->GetStreamId(); | block->stream_id_ = node_op_desc->GetStreamId(); | ||||
block->ref_count_++; | block->ref_count_++; | ||||
block->continuous_block_ = continuous; | block->continuous_block_ = continuous; | ||||
@@ -577,11 +694,14 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, | |||||
if (output_op_desc != nullptr) { | if (output_op_desc != nullptr) { | ||||
GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed")); | GE_IF_BOOL_EXEC(ge::TensorUtils::GetSize(*output_op_desc, size) != SUCCESS, GELOGI("Get size failed")); | ||||
} | } | ||||
size_t no_align_size = 0; | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(GetNoAlignSize(*node_op_desc, index, no_align_size) != SUCCESS, return nullptr, | |||||
"Get no align size failed"); | |||||
if (IsSymbolExist(node_index_io)) { | if (IsSymbolExist(node_index_io)) { | ||||
std::string symbol = anchor_to_symbol_[node_index_io.ToString()]; | std::string symbol = anchor_to_symbol_[node_index_io.ToString()]; | ||||
block = symbol_blocks_[symbol]; | block = symbol_blocks_[symbol]; | ||||
block->AddNodeTypeIndex({n, kOutput, index}, size); | |||||
block->AddNodeTypeIndex({n, kOutput, index}, size, no_align_size); | |||||
block->ref_count_++; | block->ref_count_++; | ||||
} else { | } else { | ||||
int64_t max_size = size; | int64_t max_size = size; | ||||
@@ -594,7 +714,8 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, | |||||
} | } | ||||
auto block_size = GetBlockSize(max_size, ranges); | auto block_size = GetBlockSize(max_size, ranges); | ||||
vector<bool> workspace_reuse_flag; | vector<bool> workspace_reuse_flag; | ||||
block = ApplyMemory(block_size, size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem, continuous); | |||||
block = ApplyMemory(block_size, size, no_align_size, kOutput, n, index, workspace_reuse_flag, is_op_reuse_mem, | |||||
continuous); | |||||
} | } | ||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "Block is nullptr."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(block == nullptr, return nullptr, "Block is nullptr."); | ||||
int out_count_reuse_input = block->ref_count_; | int out_count_reuse_input = block->ref_count_; | ||||
@@ -628,7 +749,7 @@ MemoryBlock *BlockMemAssigner::ApplyOutMemory(const NodePtr &n, uint32_t index, | |||||
GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInputIndex(*owner_node_op_desc, dst_reuse_input_index) != SUCCESS, | GE_IF_BOOL_EXEC(ge::TensorUtils::GetReuseInputIndex(*owner_node_op_desc, dst_reuse_input_index) != SUCCESS, | ||||
GELOGI("Get dst_reuse_input_index failed")); | GELOGI("Get dst_reuse_input_index failed")); | ||||
if (dst_reuse_input && (dst_reuse_input_index == static_cast<uint32_t>(in_anchor->GetIdx()))) { | if (dst_reuse_input && (dst_reuse_input_index == static_cast<uint32_t>(in_anchor->GetIdx()))) { | ||||
block->AddNodeTypeIndex({owner_node, kOutput, i}, block->Size()); | |||||
block->AddNodeTypeIndex({owner_node, kOutput, i}, block->Size(), block->Size()); | |||||
out_count_reuse_input += 1; | out_count_reuse_input += 1; | ||||
reuse_input = true; | reuse_input = true; | ||||
} | } | ||||
@@ -710,6 +831,7 @@ void BlockMemAssigner::ReleaseMemory(MemoryBlock *to_release, vector<MemoryBlock | |||||
GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); | GE_CHK_TRUE_EXEC_INFO(!to_release->reuse_mem_, return, "doesn't reuse memory"); | ||||
--to_release->ref_count_; | --to_release->ref_count_; | ||||
if (to_release->ref_count_ == 0) { | if (to_release->ref_count_ == 0) { | ||||
to_release->SetLifeTimeEnd(life_time_); | |||||
reusable_memory.emplace_back(to_release); | reusable_memory.emplace_back(to_release); | ||||
AddReusableBlockCount(*to_release, reusable_block_counts_); | AddReusableBlockCount(*to_release, reusable_block_counts_); | ||||
} | } | ||||
@@ -852,12 +974,11 @@ Status BlockMemAssigner::AssignOutputMemoryWithReuse(const NodePtr &node, vector | |||||
zero_memory_list_.emplace_back(node, kOutput, i); | zero_memory_list_.emplace_back(node, kOutput, i); | ||||
continue; | continue; | ||||
} | } | ||||
bool reuse_mem = is_op_reuse_mem_; | |||||
// atomic can't be reused | // atomic can't be reused | ||||
if (is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic) { | if (is_op_reuse_mem_ && out_node_set_continuous_input && is_atomic) { | ||||
reuse_mem = false; | |||||
is_op_reuse_mem_ = false; | |||||
} | } | ||||
MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, reuse_mem, out_node_set_continuous_input); | |||||
MemoryBlock *mem_block = ApplyOutMemory(node, i, ranges, is_op_reuse_mem_, out_node_set_continuous_input); | |||||
if (mem_block != nullptr) { | if (mem_block != nullptr) { | ||||
node_out_blocks_[node->GetName()].emplace_back(mem_block); | node_out_blocks_[node->GetName()].emplace_back(mem_block); | ||||
if (out_node_set_continuous_input) { | if (out_node_set_continuous_input) { | ||||
@@ -894,6 +1015,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
for (NodePtr &n : compute_graph_->GetAllNodes()) { | for (NodePtr &n : compute_graph_->GetAllNodes()) { | ||||
auto node_op_desc = n->GetOpDesc(); | auto node_op_desc = n->GetOpDesc(); | ||||
GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | GE_IF_BOOL_EXEC(node_op_desc == nullptr, continue); | ||||
life_time_ = node_op_desc->GetId(); | |||||
int64_t stream_id = node_op_desc->GetStreamId(); | int64_t stream_id = node_op_desc->GetStreamId(); | ||||
if (AssignOutputMemoryWithReuse(n, ranges) != SUCCESS) { | if (AssignOutputMemoryWithReuse(n, ranges) != SUCCESS) { | ||||
return; | return; | ||||
@@ -930,9 +1052,9 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
zero_memory_list_.emplace_back(n, kWorkspace, static_cast<uint32_t>(i)); | zero_memory_list_.emplace_back(n, kWorkspace, static_cast<uint32_t>(i)); | ||||
continue; | continue; | ||||
} | } | ||||
MemoryBlock *mem_block = | |||||
ApplyMemory(GetBlockSize(static_cast<size_t>(temp[i]), ranges), static_cast<size_t>(temp[i]), kWorkspace, n, | |||||
static_cast<uint32_t>(i), workspace_reuse_flag, is_op_reuse_mem_, false); | |||||
MemoryBlock *mem_block = ApplyMemory(GetBlockSize(static_cast<size_t>(temp[i]), ranges), | |||||
static_cast<size_t>(temp[i]), static_cast<size_t>(temp[i]), kWorkspace, n, | |||||
static_cast<uint32_t>(i), workspace_reuse_flag, is_op_reuse_mem_, false); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(mem_block == nullptr, continue, "failed to apply memory block."); | ||||
CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block); | CheckWorkspaceReuse(workspace_reuse_flag, i, stream_id, mem_block); | ||||
} | } | ||||
@@ -1001,7 +1123,8 @@ void MergeBlocks(std::vector<MemoryBlock *> &dest, std::vector<MemoryBlock *> &s | |||||
dest[i]->AddSymbol(symbol); | dest[i]->AddSymbol(symbol); | ||||
} | } | ||||
for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { | for (size_t j = 0; j < src[i]->NodeTypeIndexList().size(); ++j) { | ||||
dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], src[i]->RealSizeList()[j]); | |||||
dest[i]->AddNodeTypeIndex(src[i]->NodeTypeIndexList()[j], src[i]->RealSizeList()[j], | |||||
src[i]->NoAlignSizeList()[j]); | |||||
src[i]->deleted_block_ = true; | src[i]->deleted_block_ = true; | ||||
} | } | ||||
} | } | ||||
@@ -1115,6 +1238,21 @@ void BlockMemAssigner::AssignContinuousBlocks() { | |||||
} | } | ||||
} | } | ||||
void BlockMemAssigner::ReuseBlocksByLifeTime() { | |||||
for (size_t i = 0; i < memory_blocks_.size(); ++i) { | |||||
auto parent = memory_blocks_[i]; | |||||
if (parent == nullptr || parent->deleted_block_) { | |||||
continue; | |||||
} | |||||
if (parent->reuse_mem_ && !IsPostReuse(parent)) { | |||||
parent->reuse_mem_ = false; | |||||
} | |||||
for (size_t j = i + 1; j < memory_blocks_.size(); ++j) { | |||||
parent->AddLifeReuseBlock(memory_blocks_[j]); | |||||
} | |||||
} | |||||
} | |||||
/// | /// | ||||
/// @ingroup domi_omg | /// @ingroup domi_omg | ||||
/// @brief traverse memory size, resize, calculate offset | /// @brief traverse memory size, resize, calculate offset | ||||
@@ -1129,8 +1267,8 @@ void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
memory_block->SetHeadOffset(mem_offset_); | memory_block->SetHeadOffset(mem_offset_); | ||||
mem_offset_ += memory_block->Size(); | mem_offset_ += memory_block->Size(); | ||||
memory_block->SetTailOffset(mem_offset_ - 1); | memory_block->SetTailOffset(mem_offset_ - 1); | ||||
GELOGI("mem_offset_ exclude zero_copy_memory is %zu.", mem_offset_); | |||||
} | } | ||||
GELOGI("mem_offset_ exclude zero_copy_memory is %zu.", mem_offset_); | |||||
} | } | ||||
/// | /// | ||||
@@ -1142,15 +1280,18 @@ void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
/// @param [in] real_size memory size in need | /// @param [in] real_size memory size in need | ||||
/// @return Status result | /// @return Status result | ||||
/// | /// | ||||
void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t size, size_t real_size) { | |||||
ge::OpDescPtr op_desc = node_type_index.node->GetOpDesc(); | |||||
void SetOffsetSize(const NodeTypeIndex &node_type, const MemoryBlock *block, size_t real_size, size_t no_align_size, | |||||
bool child_block) { | |||||
ge::OpDescPtr op_desc = node_type.node->GetOpDesc(); | |||||
GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG(op_desc == nullptr, return, "op_desc is null."); | ||||
string graph_name = node_type_index.node->GetOwnerComputeGraph()->GetName(); | |||||
string graph_name = node_type.node->GetOwnerComputeGraph()->GetName(); | |||||
vector<int64_t> memorys_type; | vector<int64_t> memorys_type; | ||||
int64_t offset = block->HeadOffset(); | |||||
size_t end = node_type.life_time_end; | |||||
bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memorys_type); | bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_OUTPUT_MEM_TYPE_LIST, memorys_type); | ||||
if (node_type_index.mem_type == kOutput) { | |||||
if (node_type.mem_type == kOutput) { | |||||
vector<int64_t> output_list = op_desc->GetOutputOffset(); | vector<int64_t> output_list = op_desc->GetOutputOffset(); | ||||
for (auto i = static_cast<uint32_t>(output_list.size()); i < node_type_index.index + 1; i++) { | |||||
for (auto i = static_cast<uint32_t>(output_list.size()); i < node_type.index + 1; i++) { | |||||
output_list.emplace_back(kInvalidOffset); | output_list.emplace_back(kInvalidOffset); | ||||
} | } | ||||
if (output_list.empty()) { | if (output_list.empty()) { | ||||
@@ -1160,39 +1301,56 @@ void SetOffsetSize(const NodeTypeIndex &node_type_index, int64_t offset, size_t | |||||
if ((op_desc->GetType() == DATA) || (op_desc->GetType() == AIPP_DATA_TYPE) || (op_desc->GetType() == MULTISHAPE) || | if ((op_desc->GetType() == DATA) || (op_desc->GetType() == AIPP_DATA_TYPE) || (op_desc->GetType() == MULTISHAPE) || | ||||
(op_desc->GetType() == NETOUTPUT)) { | (op_desc->GetType() == NETOUTPUT)) { | ||||
if ((output_list[node_type_index.index] == kInvalidOffset) || (output_list[node_type_index.index] < offset)) { | |||||
output_list.at(node_type_index.index) = offset; | |||||
if ((output_list[node_type.index] == kInvalidOffset) || (output_list[node_type.index] < offset)) { | |||||
output_list.at(node_type.index) = offset; | |||||
} | } | ||||
} else { | } else { | ||||
// fusion: keep the original other type offset value from op_desc | // fusion: keep the original other type offset value from op_desc | ||||
bool set_out_offset = (!has_mem_type_attr) || (memorys_type[node_type_index.index] != RT_MEMORY_L1); | |||||
bool set_out_offset = (!has_mem_type_attr) || | |||||
(memorys_type.size() > node_type.index && memorys_type[node_type.index] != RT_MEMORY_L1); | |||||
if (set_out_offset) { | if (set_out_offset) { | ||||
output_list.at(node_type_index.index) = offset; | |||||
output_list.at(node_type.index) = offset; | |||||
} | } | ||||
} | } | ||||
op_desc->SetOutputOffset(output_list); | op_desc->SetOutputOffset(output_list); | ||||
GELOGI("[IMAS]Set %s name[%s] output[%d] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", | |||||
graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, | |||||
real_size); | |||||
} else if (node_type_index.mem_type == kWorkspace) { | |||||
} else if (node_type.mem_type == kWorkspace) { | |||||
vector<int64_t> workspace_list; | vector<int64_t> workspace_list; | ||||
workspace_list = op_desc->GetWorkspace(); | workspace_list = op_desc->GetWorkspace(); | ||||
for (auto i = static_cast<uint32_t>(workspace_list.size()); i < node_type_index.index + 1; i++) { | |||||
for (auto i = static_cast<uint32_t>(workspace_list.size()); i < node_type.index + 1; i++) { | |||||
workspace_list.emplace_back(kInvalidOffset); | workspace_list.emplace_back(kInvalidOffset); | ||||
} | } | ||||
vector<int64_t> workspace_memory_type; | |||||
bool has_workspace_mem_type_attr = | |||||
ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, workspace_memory_type); | |||||
vector<int64_t> workspace_mem_type; | |||||
bool has_workspace_mem_type = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, workspace_mem_type); | |||||
// fusion: keep the original other type offset value from op_desc | // fusion: keep the original other type offset value from op_desc | ||||
bool set_workspace_offset = | |||||
(!has_workspace_mem_type_attr) || (workspace_memory_type[node_type_index.index] != RT_MEMORY_L1); | |||||
bool set_workspace_offset = (!has_workspace_mem_type) || (workspace_mem_type.size() > node_type.index && | |||||
workspace_mem_type[node_type.index] != RT_MEMORY_L1); | |||||
if (set_workspace_offset) { | if (set_workspace_offset) { | ||||
workspace_list.at(node_type_index.index) = offset; | |||||
workspace_list.at(node_type.index) = offset; | |||||
} | } | ||||
op_desc->SetWorkspace(workspace_list); | op_desc->SetWorkspace(workspace_list); | ||||
GELOGI("[IMAS]Set %s name[%s] workspace[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu].", | |||||
graph_name.c_str(), op_desc->GetName().c_str(), node_type_index.index, offset, op_desc->GetStreamId(), size, | |||||
real_size); | |||||
} | |||||
GELOGI( | |||||
"[IMAS]Set %s name[%s] %s[%u] offset to [%ld] streamid[%ld] size[%zu] realsize[%zu]" | |||||
" noalignsize[%zu] life time begin[%zu] life time end[%zu] child[%d].", | |||||
graph_name.c_str(), op_desc->GetName().c_str(), node_type.GetMemType().c_str(), node_type.index, offset, | |||||
op_desc->GetStreamId(), block->Size(), real_size, no_align_size, op_desc->GetId(), end, child_block); | |||||
} | |||||
void SetBlockOpMemOffset(MemoryBlock *block, bool child_block) { | |||||
if (block == nullptr) { | |||||
return; | |||||
} | |||||
size_t index = 0; | |||||
size_t real_size = 0; | |||||
size_t no_align_size = 0; | |||||
auto real_size_list_size = block->RealSizeList().size(); | |||||
for (const NodeTypeIndex &node_type_index : block->NodeTypeIndexList()) { | |||||
if (index < real_size_list_size) { | |||||
real_size = block->RealSizeList()[index]; | |||||
no_align_size = block->NoAlignSizeList()[index]; | |||||
} | |||||
SetOffsetSize(node_type_index, block, real_size, no_align_size, child_block); | |||||
index++; | |||||
} | } | ||||
} | } | ||||
@@ -1206,21 +1364,16 @@ void BlockMemAssigner::SetOpMemOffset(bool is_zero_copy) { | |||||
continue; | continue; | ||||
} | } | ||||
size_t index = 0; | |||||
size_t real_size = 0; | |||||
auto real_size_list_size = memory_block->RealSizeList().size(); | |||||
for (const NodeTypeIndex &node_type_index : memory_block->NodeTypeIndexList()) { | |||||
if (index < real_size_list_size) { | |||||
real_size = memory_block->RealSizeList()[index]; | |||||
} | |||||
SetOffsetSize(node_type_index, memory_block->HeadOffset(), memory_block->Size(), real_size); | |||||
index++; | |||||
SetBlockOpMemOffset(memory_block, false); | |||||
for (MemoryBlock *child_block : memory_block->ChildBlockList()) { | |||||
SetBlockOpMemOffset(child_block, true); | |||||
} | } | ||||
} | } | ||||
if (!is_zero_copy) { | if (!is_zero_copy) { | ||||
for (const NodeTypeIndex &node_type_index : zero_memory_list_) { | for (const NodeTypeIndex &node_type_index : zero_memory_list_) { | ||||
SetOffsetSize(node_type_index, 0, 0, 0); | |||||
MemoryBlock block(0, 0); | |||||
SetOffsetSize(node_type_index, &block, 0, 0, false); | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -1290,7 +1443,7 @@ void BlockMemAssigner::FindHeadAndTailNodesForStream(map<int64_t, pair<NodePtr, | |||||
for (size_t i = 0; i < n->GetOpDesc()->GetOutputsSize(); i++) { | for (size_t i = 0; i < n->GetOpDesc()->GetOutputsSize(); i++) { | ||||
int64_t size = 0; | int64_t size = 0; | ||||
if (ge::TensorUtils::GetSize(*n->GetOpDesc()->GetOutputDescPtr(static_cast<uint32_t>(i)), size) != SUCCESS) { | if (ge::TensorUtils::GetSize(*n->GetOpDesc()->GetOutputDescPtr(static_cast<uint32_t>(i)), size) != SUCCESS) { | ||||
GELOGW("Get output size failed!"); | |||||
GELOGW("Get output size failed!"); | |||||
continue; | continue; | ||||
} | } | ||||
stream_mem_map[stream_id] += size; | stream_mem_map[stream_id] += size; | ||||
@@ -1375,6 +1528,6 @@ void BlockMemAssigner::FindDependentStreamBetweenGraphs(const NodePtr &pre_node, | |||||
bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { | bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { | ||||
return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || | return (node_type == VARIABLE) || (node_type == CONSTANT) || (node_type == MULTISHAPE) || | ||||
(node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) || | (node_type == HCOMBROADCAST) || (node_type == HCOMALLREDUCE) || (node_type == CONSTANTOP) || | ||||
(node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN); | |||||
(node_type == ASSIGNADD) || (node_type == ASSIGNSUB) || (node_type == ASSIGN) || (node_type == HVDWAIT); | |||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -31,6 +31,8 @@ | |||||
#include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
namespace ge { | namespace ge { | ||||
const size_t kMaxLifeTime = 0xffffffff; | |||||
enum MemoryType { kOutput, kWorkspace }; | enum MemoryType { kOutput, kWorkspace }; | ||||
struct NodeTypeIndex { | struct NodeTypeIndex { | ||||
@@ -40,6 +42,15 @@ struct NodeTypeIndex { | |||||
ge::NodePtr node = nullptr; | ge::NodePtr node = nullptr; | ||||
MemoryType mem_type = kOutput; | MemoryType mem_type = kOutput; | ||||
uint32_t index = 0; | uint32_t index = 0; | ||||
size_t life_time_end = kMaxLifeTime; | |||||
const string GetMemType() const { | |||||
if (mem_type == kOutput) { | |||||
return "output"; | |||||
} else if (mem_type == kWorkspace) { | |||||
return "workspace"; | |||||
} | |||||
return "unknown"; | |||||
} | |||||
}; | }; | ||||
class MemoryBlock { | class MemoryBlock { | ||||
@@ -55,7 +66,8 @@ class MemoryBlock { | |||||
is_zero_copy_(false), | is_zero_copy_(false), | ||||
block_size_(block_size), | block_size_(block_size), | ||||
head_offset_(0), | head_offset_(0), | ||||
tail_offset_(0) {} | |||||
tail_offset_(0), | |||||
child_offset_(0) {} | |||||
MemoryBlock(const MemoryBlock &) = delete; | MemoryBlock(const MemoryBlock &) = delete; | ||||
@@ -66,23 +78,25 @@ class MemoryBlock { | |||||
symbol_list_.clear(); | symbol_list_.clear(); | ||||
} | } | ||||
void Init(size_t real_size, MemoryType type, const ge::NodePtr &node, uint32_t out_index) { | |||||
void Init(size_t real_size, MemoryType type, const ge::NodePtr &node, uint32_t out_index, size_t no_align_size) { | |||||
real_size_list_.emplace_back(real_size); | real_size_list_.emplace_back(real_size); | ||||
no_align_size_list_.emplace_back(no_align_size); | |||||
node_type_index_list_.emplace_back(node, type, out_index); | node_type_index_list_.emplace_back(node, type, out_index); | ||||
} | } | ||||
size_t Size() const { return block_size_; } | size_t Size() const { return block_size_; } | ||||
void SetHeadOffset(size_t offset) { head_offset_ = offset; } | |||||
void SetHeadOffset(size_t offset); | |||||
void SetTailOffset(size_t offset) { tail_offset_ = offset; } | |||||
void SetTailOffset(size_t offset); | |||||
size_t HeadOffset() const { return head_offset_; } | size_t HeadOffset() const { return head_offset_; } | ||||
size_t TailOffset() const { return tail_offset_; } | size_t TailOffset() const { return tail_offset_; } | ||||
void AddNodeTypeIndex(const NodeTypeIndex &node_type_index, size_t real_size) { | |||||
void AddNodeTypeIndex(const NodeTypeIndex &node_type_index, size_t real_size, size_t no_align_size) { | |||||
node_type_index_list_.emplace_back(node_type_index); | node_type_index_list_.emplace_back(node_type_index); | ||||
real_size_list_.emplace_back(real_size); | real_size_list_.emplace_back(real_size); | ||||
no_align_size_list_.emplace_back(no_align_size); | |||||
} | } | ||||
void AddSymbol(const std::string &symbol) { symbol_list_.emplace_back(symbol); } | void AddSymbol(const std::string &symbol) { symbol_list_.emplace_back(symbol); } | ||||
@@ -90,6 +104,8 @@ class MemoryBlock { | |||||
const std::vector<NodeTypeIndex> &NodeTypeIndexList() const { return node_type_index_list_; } | const std::vector<NodeTypeIndex> &NodeTypeIndexList() const { return node_type_index_list_; } | ||||
const std::vector<std::string> &SymbolList() const { return symbol_list_; } | const std::vector<std::string> &SymbolList() const { return symbol_list_; } | ||||
const std::vector<size_t> &RealSizeList() const { return real_size_list_; } | const std::vector<size_t> &RealSizeList() const { return real_size_list_; } | ||||
const std::vector<MemoryBlock *> &ChildBlockList() const { return child_blocks_; } | |||||
const std::vector<size_t> &NoAlignSizeList() const { return no_align_size_list_; } | |||||
void Resize(); | void Resize(); | ||||
@@ -97,6 +113,14 @@ class MemoryBlock { | |||||
bool IsSameLabel(std::string &first_batch_label); | bool IsSameLabel(std::string &first_batch_label); | ||||
void AddLifeReuseBlock(MemoryBlock *block); | |||||
void SetLifeTimeEnd(size_t time); | |||||
size_t GetLifeBegin(); | |||||
size_t GetLifeEnd(); | |||||
int ref_count_; | int ref_count_; | ||||
int64_t stream_id_; | int64_t stream_id_; | ||||
bool deleted_block_; | bool deleted_block_; | ||||
@@ -109,10 +133,13 @@ class MemoryBlock { | |||||
private: | private: | ||||
size_t block_size_; | size_t block_size_; | ||||
std::vector<size_t> real_size_list_; | std::vector<size_t> real_size_list_; | ||||
std::vector<size_t> no_align_size_list_; | |||||
size_t head_offset_; | size_t head_offset_; | ||||
size_t tail_offset_; | size_t tail_offset_; | ||||
size_t child_offset_; | |||||
std::vector<NodeTypeIndex> node_type_index_list_; | std::vector<NodeTypeIndex> node_type_index_list_; | ||||
std::vector<std::string> symbol_list_; | std::vector<std::string> symbol_list_; | ||||
std::vector<MemoryBlock *> child_blocks_; | |||||
}; | }; | ||||
class BlockMemAssigner : public MemAssigner { | class BlockMemAssigner : public MemAssigner { | ||||
@@ -292,8 +319,8 @@ class BlockMemAssigner : public MemAssigner { | |||||
/// @return MemoryBlock* | /// @return MemoryBlock* | ||||
/// @author | /// @author | ||||
/// | /// | ||||
MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, MemoryType mem_type, const ge::NodePtr &n, | |||||
uint32_t out_index, const std::vector<bool> &workspace_reuse_flag, | |||||
MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, MemoryType mem_type, | |||||
const ge::NodePtr &n, uint32_t out_index, const std::vector<bool> &workspace_reuse_flag, | |||||
const bool is_op_reuse_mem, const bool continuous); | const bool is_op_reuse_mem, const bool continuous); | ||||
/// | /// | ||||
@@ -354,6 +381,17 @@ class BlockMemAssigner : public MemAssigner { | |||||
bool IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, | bool IsOutNodeSetContinuousInput(const NodePtr &n, uint32_t out_index, std::string &peer_name, | ||||
uint32_t &peer_input_index); | uint32_t &peer_input_index); | ||||
/// | |||||
/// @ingroup GE | |||||
/// @|+++++++++block1++++++++| |+++++++++block1++++++++| | |||||
/// @|+++++++++block1++++++++||++block2++| |+++++++++block1++++++++||++block2++| | |||||
/// @ |++block2++||++block3++| ==> |++block3++| |++block2++| | |||||
/// @ |++block3++| |++block3++| | |||||
/// @return void | |||||
/// @author | |||||
/// | |||||
void ReuseBlocksByLifeTime(); | |||||
std::vector<MemoryBlock *> reusable_blocks_; | std::vector<MemoryBlock *> reusable_blocks_; | ||||
std::map<std::string, uint64_t> reusable_block_counts_; | std::map<std::string, uint64_t> reusable_block_counts_; | ||||
@@ -380,6 +418,8 @@ class BlockMemAssigner : public MemAssigner { | |||||
bool is_op_reuse_mem_ = true; | bool is_op_reuse_mem_ = true; | ||||
size_t life_time_; | |||||
int64_t atomic_addr_clean_id_ = 0; | int64_t atomic_addr_clean_id_ = 0; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -245,13 +245,14 @@ Status GraphMemoryAssigner::AssignZeroCopyMemory(size_t &mem_offset, size_t &zer | |||||
memory_block->SetHeadOffset(mem_offset); | memory_block->SetHeadOffset(mem_offset); | ||||
mem_offset += memory_block->Size(); | mem_offset += memory_block->Size(); | ||||
memory_block->SetTailOffset(mem_offset - 1); | memory_block->SetTailOffset(mem_offset - 1); | ||||
GELOGI("mem_offset_ include zero_copy_memory is %zu.", mem_offset); | |||||
} | } | ||||
GELOGI("mem_offset_ include zero_copy_memory is %zu.", mem_offset); | |||||
// set offset for zero copy nodes | // set offset for zero copy nodes | ||||
priority_assigner->SetOpMemOffset(true); | priority_assigner->SetOpMemOffset(true); | ||||
zero_mem_copy_size = mem_offset - mem_offset_tmp; | zero_mem_copy_size = mem_offset - mem_offset_tmp; | ||||
memory_offset_[0].mem_offset_ = mem_offset; | |||||
GELOGI("max_mem_offset:%zu, mem_offset:%zu, zero_mem_copy_size:%zu.", mem_offset, mem_offset_tmp, zero_mem_copy_size); | GELOGI("max_mem_offset:%zu, mem_offset:%zu, zero_mem_copy_size:%zu.", mem_offset, mem_offset_tmp, zero_mem_copy_size); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -360,8 +361,11 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
return PARAM_INVALID;); | return PARAM_INVALID;); | ||||
vector<int64_t> output_list = peer_op_desc->GetOutputOffset(); | vector<int64_t> output_list = peer_op_desc->GetOutputOffset(); | ||||
std::vector<int64_t> offsets_for_fusion = {}; | |||||
bool has_offset_attr = | |||||
AttrUtils::GetListInt(peer_op_desc, ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION, offsets_for_fusion); | |||||
if (peer_out_data_anchor->GetIdx() < static_cast<int>(output_list.size())) { | if (peer_out_data_anchor->GetIdx() < static_cast<int>(output_list.size())) { | ||||
if (continuous_input_alloc) { | |||||
if (continuous_input_alloc && !has_offset_attr) { | |||||
if (in_data_anchor->GetIdx() == 0) { | if (in_data_anchor->GetIdx() == 0) { | ||||
continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); | continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); | ||||
} | } | ||||
@@ -391,9 +395,7 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
} | } | ||||
peer_op_desc->SetOutputOffset(output_list); | peer_op_desc->SetOutputOffset(output_list); | ||||
size_t pre_mem_offset = memory_offset_[0].mem_offset_; | size_t pre_mem_offset = memory_offset_[0].mem_offset_; | ||||
std::vector<int64_t> offsets_for_fusion = {}; | |||||
bool has_offset_attr = | |||||
AttrUtils::GetListInt(peer_op_desc, ATTR_NAME_OUTPUT_OFFSET_FOR_BUFFER_FUSION, offsets_for_fusion); | |||||
int64_t tensor_desc_size = 0; | int64_t tensor_desc_size = 0; | ||||
if (has_offset_attr) { | if (has_offset_attr) { | ||||
if (peer_out_data_anchor->GetIdx() < static_cast<int>(offsets_for_fusion.size())) { | if (peer_out_data_anchor->GetIdx() < static_cast<int>(offsets_for_fusion.size())) { | ||||
@@ -1232,7 +1234,7 @@ ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node, vector< | |||||
ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { | ge::Status GraphMemoryAssigner::UpdateOpInputOffset(const NodePtr &node) const { | ||||
GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
vector<int64_t> input_list; | vector<int64_t> input_list; | ||||
if (node->GetType() == HCOMBROADCAST) { | |||||
if (node->GetType() == HCOMBROADCAST || node->GetType() == HVDCALLBACKBROADCAST) { | |||||
for (const auto &anchor : node->GetAllInDataAnchors()) { | for (const auto &anchor : node->GetAllInDataAnchors()) { | ||||
vector<int64_t> output_list; | vector<int64_t> output_list; | ||||
auto peer_out_anchor = anchor->GetPeerOutAnchor(); | auto peer_out_anchor = anchor->GetPeerOutAnchor(); | ||||
@@ -208,7 +208,7 @@ Status VarMemAssignUtil::DealVariableNode(uint32_t graph_id, const ge::NodePtr & | |||||
for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { | for (const ge::OutDataAnchorPtr &var_out_data_anchor : node->GetAllOutDataAnchors()) { | ||||
for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { | for (const ge::InDataAnchorPtr &dst_in_data_anchor : var_out_data_anchor->GetPeerInDataAnchors()) { | ||||
ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); | ge::NodePtr dst_node = dst_in_data_anchor->GetOwnerNode(); | ||||
if (dst_node->GetType() == HCOMBROADCAST) { | |||||
if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { | |||||
GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); | GE_CHK_STATUS_RET(DealBroadCastNode(graph_id, dst_node, dst_in_data_anchor, node, session_id)); | ||||
continue; | continue; | ||||
} | } | ||||
@@ -412,7 +412,9 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, zero_copy_mem_size_), | GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_ZERO_COPY_MEMORY_SIZE, zero_copy_mem_size_), | ||||
GELOGE(FAILED, "SetInt of ATTR_MODEL_ZERO_COPY_MEMORY_SIZE failed."); | GELOGE(FAILED, "SetInt of ATTR_MODEL_ZERO_COPY_MEMORY_SIZE failed."); | ||||
return FAILED); | return FAILED); | ||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::SetListStr(&model, ATTR_MODEL_OUT_NODES_NAME, domi::GetContext().net_out_nodes), | |||||
GELOGE(FAILED, "SetListStr of ATTR_MODEL_OUT_NODES_NAME failed."); | |||||
return FAILED); | |||||
GELOGI("For model, max_mem_offset_: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, zero_copy_mem_size_); | GELOGI("For model, max_mem_offset_: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, zero_copy_mem_size_); | ||||
string ge_core_type; | string ge_core_type; | ||||
@@ -651,6 +653,14 @@ Status ModelBuilder::BuildModelForGetTask(ge::Model &model) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status ModelBuilder::BuildModelForGetDynShapeTask(ge::Model &model_def) { | |||||
GE_TIMESTAMP_START(BuildModelDef); | |||||
GE_CHK_STATUS_RET(BuildModelDef(model_def), "BuildModelDef failed!"); | |||||
GE_TIMESTAMP_END(BuildModelDef, "GraphBuilder::BuildModelDef"); | |||||
SetModelVersion(model_def); | |||||
return SUCCESS; | |||||
} | |||||
ge::Buffer ModelBuilder::GetWeightBuffer() const { return weight_buffer_; } | ge::Buffer ModelBuilder::GetWeightBuffer() const { return weight_buffer_; } | ||||
Status ModelBuilder::CompileSingleOp() { | Status ModelBuilder::CompileSingleOp() { | ||||
GELOGD("Begin to compile single op."); | GELOGD("Begin to compile single op."); | ||||
@@ -50,6 +50,7 @@ class ModelBuilder { | |||||
Status SaveDataToModel(ge::Model &model, ge::GeModel &ge_model); | Status SaveDataToModel(ge::Model &model, ge::GeModel &ge_model); | ||||
Status PreBuildModel(); | Status PreBuildModel(); | ||||
Status BuildModelForGetTask(ge::Model &model_def); | Status BuildModelForGetTask(ge::Model &model_def); | ||||
ge::Status BuildModelForGetDynShapeTask(ge::Model &model_def); | |||||
ge::Buffer GetWeightBuffer() const; | ge::Buffer GetWeightBuffer() const; | ||||
@@ -40,41 +40,47 @@ class StreamAllocator { | |||||
const vector<int64_t> &GetHugeStreams() const { return huge_streams_; } | const vector<int64_t> &GetHugeStreams() const { return huge_streams_; } | ||||
private: | private: | ||||
Status SplitStreams(std::vector<std::set<int64_t>> &split_streams); | |||||
Status AssignSingleStream(); | Status AssignSingleStream(); | ||||
Status SetActiveStreamsByLabel(); | Status SetActiveStreamsByLabel(); | ||||
Status UpdateActiveStreams(const std::vector<std::set<int64_t>> &splited_streams); | |||||
void UpdateLabelStreams(const std::vector<std::set<int64_t>> &split_streams); | |||||
Status SetActiveStreamsForSubgraph(); | |||||
Status SetActiveStreamsForLoop(); | |||||
Status CheckStreamActived() const; | |||||
Status GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count); | |||||
int64_t GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_node_num_one_stream); | |||||
Status SetActiveStreamsForSubgraphs(); | |||||
Status InsertSyncEvents(); | Status InsertSyncEvents(); | ||||
Status InsertOneEventInTwoNodes(const NodePtr &cur_node_ptr, const NodePtr &next_node_ptr); | Status InsertOneEventInTwoNodes(const NodePtr &cur_node_ptr, const NodePtr &next_node_ptr); | ||||
Status InsertEventsForSubgraph(); | |||||
Status OptimizeSyncEvents(); | Status OptimizeSyncEvents(); | ||||
Status OptimizeBySendEvents(const std::map<int64_t, std::vector<NodePtr>> &stream_nodes); | Status OptimizeBySendEvents(const std::map<int64_t, std::vector<NodePtr>> &stream_nodes); | ||||
Status OptimizeByRecvEvents(const std::map<int64_t, std::vector<NodePtr>> &stream_nodes); | Status OptimizeByRecvEvents(const std::map<int64_t, std::vector<NodePtr>> &stream_nodes); | ||||
Status OptimizeByStreamActivate(); | Status OptimizeByStreamActivate(); | ||||
// Determine if the successor node of RecvNode is directly or indirectly activated by the SendNode precursor node | |||||
bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; | |||||
Status RefreshContinuousEvents(); | |||||
Status InsertSyncEventNodes(); | |||||
Status ReorderEventNodes() const; | |||||
Status SplitStreams(std::vector<std::set<int64_t>> &split_streams); | |||||
bool NeedSpiltNewStream(int64_t stream_node_num, int64_t max_node_num_one_stream, const OpDescPtr &op_desc) const; | |||||
Status UpdateActiveStreams(const std::vector<std::set<int64_t>> &splited_streams); | |||||
void UpdateLabelStreams(const std::vector<std::set<int64_t>> &split_streams); | |||||
Status InsertActiveNodesAfterSwitch(NodePtr &switch_node); | Status InsertActiveNodesAfterSwitch(NodePtr &switch_node); | ||||
Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector<NodePtr> &switch_active_nodes); | Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector<NodePtr> &switch_active_nodes); | ||||
Status SetActiveStreamList(NodePtr &active_node, const std::string &active_label); | |||||
Status AddActiveNodes(NodePtr &switch_node, const std::vector<std::string> &ori_active_label_list, | |||||
std::vector<std::string> &active_label_list, std::vector<NodePtr> &added_active_nodes); | |||||
Status UpdateActiveStreamsForSubgraphs() const; | |||||
Status SetActiveStreamsForLoop(); | |||||
Status CheckStreamActived() const; | |||||
Status AddActiveEntryStream(); | Status AddActiveEntryStream(); | ||||
Status CollectDeactiveStream(const OpDescPtr &op_desc, std::set<uint32_t> &deactive_streams) const; | Status CollectDeactiveStream(const OpDescPtr &op_desc, std::set<uint32_t> &deactive_streams) const; | ||||
Status InsertActiveEntryStream(const std::vector<uint32_t> &active_streams, int64_t stream_id); | Status InsertActiveEntryStream(const std::vector<uint32_t> &active_streams, int64_t stream_id); | ||||
Status AddEventId(const NodePtr &pre_node, const NodePtr ¬_cur, const NodePtr &cur_node, bool not_use_cur); | |||||
Status RefreshContinuousEvents(); | |||||
Status InsertSyncEventNodes(); | |||||
Status ReorderEventNodes() const; | |||||
void DumpEvents(); | |||||
Status GetMaxStreamAndTask(bool huge_stream, uint32_t &max_stream_count, uint32_t &max_task_count); | |||||
int64_t GetMaxNodeNumPerStream(const NodePtr &node, uint32_t max_node_num_one_stream); | |||||
void AddSendEventId(const NodePtr &node, uint32_t event_id); | void AddSendEventId(const NodePtr &node, uint32_t event_id); | ||||
void AddRecvEventId(const NodePtr &node, uint32_t event_id); | void AddRecvEventId(const NodePtr &node, uint32_t event_id); | ||||
void RmvSendEventId(const NodePtr &node, uint32_t event_id); | void RmvSendEventId(const NodePtr &node, uint32_t event_id); | ||||
@@ -83,10 +89,11 @@ class StreamAllocator { | |||||
void GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &recv_list) const; | void GetRecvEventIdList(const NodePtr &node, std::vector<uint32_t> &recv_list) const; | ||||
NodePtr GetNodeFromSendEventId(uint32_t send_event_id) const; | NodePtr GetNodeFromSendEventId(uint32_t send_event_id) const; | ||||
NodePtr GetNodeFromRecvEventId(uint32_t recv_event_id) const; | NodePtr GetNodeFromRecvEventId(uint32_t recv_event_id) const; | ||||
Status AddEventId(const NodePtr &pre_node, const NodePtr ¬_cur, const NodePtr &cur_node, bool not_use_cur); | |||||
void DumpEvents(); | |||||
// Determine if the successor node of RecvNode is directly or indirectly activated by the SendNode precursor node | |||||
bool IsRecvNodeActivatedBySendNode(const NodePtr &send_node_ptr, const NodePtr &recv_node_ptr) const; | |||||
Status AddActiveNodes(NodePtr &switch_node, const std::vector<std::string> &ori_active_label_list, | |||||
std::vector<std::string> &active_label_list, std::vector<NodePtr> &added_active_nodes); | |||||
Status SetActiveStreamList(NodePtr &active_node, const std::string &active_label); | |||||
ComputeGraphPtr whole_graph_; | ComputeGraphPtr whole_graph_; | ||||
const Graph2SubGraphInfoList &subgraphs_; | const Graph2SubGraphInfoList &subgraphs_; | ||||
@@ -102,6 +109,9 @@ class StreamAllocator { | |||||
std::set<int64_t> specific_activated_streams_; | std::set<int64_t> specific_activated_streams_; | ||||
std::map<int64_t, std::set<NodePtr>> specific_activated_streams_nodes_map_; | std::map<int64_t, std::set<NodePtr>> specific_activated_streams_nodes_map_; | ||||
std::map<NodePtr, int64_t> node_split_stream_map_; | |||||
std::map<ComputeGraphPtr, NodePtr> subgraph_first_active_node_map_; | |||||
// send events corresponding to the node | // send events corresponding to the node | ||||
std::map<NodePtr, std::vector<uint32_t>> node_to_send_events_; | std::map<NodePtr, std::vector<uint32_t>> node_to_send_events_; | ||||
@@ -109,4 +119,4 @@ class StreamAllocator { | |||||
std::map<NodePtr, std::vector<uint32_t>> node_to_recv_events_; | std::map<NodePtr, std::vector<uint32_t>> node_to_recv_events_; | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_BUILD_STREAM_ALLOCATOR_H_ | |||||
#endif // GE_GRAPH_BUILD_STREAM_ALLOCATOR_H_ |
@@ -30,7 +30,7 @@ namespace ge { | |||||
StreamGraphOptimizer::~StreamGraphOptimizer() {} | StreamGraphOptimizer::~StreamGraphOptimizer() {} | ||||
void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { | void StreamGraphOptimizer::RefreshNodeId(const ComputeGraphPtr &comp_graph, Graph2SubGraphInfoList &subgraph_map) { | ||||
size_t node_size = comp_graph->GetDirectNodesSize(); | |||||
size_t node_size = comp_graph->GetAllNodesSize(); | |||||
GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); | GELOGI("Refresh placeholder and end nodeId start from node num: %zu", node_size); | ||||
for (const auto &subgraph_pair : subgraph_map) { | for (const auto &subgraph_pair : subgraph_map) { | ||||
for (const auto &subgraph_info : subgraph_pair.second) { | for (const auto &subgraph_info : subgraph_pair.second) { | ||||
@@ -17,9 +17,9 @@ | |||||
#include "graph/build/task_generator.h" | #include "graph/build/task_generator.h" | ||||
#include <string> | #include <string> | ||||
#include <utility> | #include <utility> | ||||
#include "common/profiling/profiling_manager.h" | |||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "common/profiling/profiling_manager.h" | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
@@ -73,12 +73,24 @@ Status TaskGenerator::GetTaskInfo(Model &model, ComputeGraphPtr &graph, uint64_t | |||||
std::vector<TaskDef> task_def_list; | std::vector<TaskDef> task_def_list; | ||||
std::map<uint32_t, string> op_name_map; | std::map<uint32_t, string> op_name_map; | ||||
GE_DUMP(graph, "GenerateTaskBefore"); | |||||
bool is_unknown_shape = false; | |||||
NodePtr parent_node = graph->GetParentNode(); | |||||
if (parent_node != nullptr) { | |||||
auto op_desc = parent_node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
(void)AttrUtils::GetBool(op_desc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_unknown_shape); | |||||
} | |||||
Status ret = SUCCESS; | |||||
if (is_unknown_shape) { | |||||
GELOGI("Beign to generate unknown shape task."); | |||||
ret = GenerateUnknownShapeTask(run_context, graph, task_def_list, op_name_map); | |||||
} else { | |||||
GELOGI("Beign to generate known shape task."); | |||||
ret = GenerateTask(run_context, graph, task_def_list, op_name_map); | |||||
} | |||||
GE_DUMP(graph, "GenerateTaskAfter"); | |||||
GraphUtils::DumpGEGraph(graph, "GenerateTaskBefore"); | |||||
GraphUtils::DumpGEGraphToOnnx(*graph, "GenerateTaskBefore"); | |||||
Status ret = GenerateTask(run_context, graph, task_def_list, op_name_map); | |||||
GraphUtils::DumpGEGraph(graph, "GenerateTaskAfter"); | |||||
GraphUtils::DumpGEGraphToOnnx(*graph, "GenerateTaskAfter"); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "GenerateTask failed. session_id=%lu", session_id); | GELOGE(ret, "GenerateTask failed. session_id=%lu", session_id); | ||||
return ret; | return ret; | ||||
@@ -251,8 +263,9 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||||
GE_TIMESTAMP_CALLNUM_START(GenerateTask); | GE_TIMESTAMP_CALLNUM_START(GenerateTask); | ||||
// map store fusion nodes | // map store fusion nodes | ||||
map<int64_t, std::vector<NodePtr>> fusion_nodes; | map<int64_t, std::vector<NodePtr>> fusion_nodes; | ||||
const char *buffer_optimize_on = std::getenv("BUFFER_OPTIMIZE_ON"); | |||||
if (buffer_optimize_on != nullptr) { | |||||
string buffer_optimize = "off_optimize"; | |||||
(void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize); | |||||
if (buffer_optimize != "off_optimize") { | |||||
GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); | GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph)); | ||||
} | } | ||||
std::unordered_set<Node *> fusion_nodes_seen; | std::unordered_set<Node *> fusion_nodes_seen; | ||||
@@ -342,10 +355,125 @@ Status TaskGenerator::GenerateTask(RunContext &run_context, ComputeGraphPtr &gra | |||||
task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr)); | task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr)); | ||||
} | } | ||||
GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %lu task(s).", | |||||
GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).", | |||||
op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, | |||||
task_list_size_after - task_list_size_before); | |||||
} | |||||
GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask"); | |||||
return SUCCESS; | |||||
} | |||||
// Generate task definitions for an unknown-shape graph: create a dedicated
// runtime stream, bind it to the model, traverse every node of the graph and
// let each node's ops-kernel store generate its task(s) into task_def_list,
// then unbind and destroy the stream.
// @param run_context   run context; its stream member is replaced by the newly created stream
// @param graph         compute graph whose nodes are traversed via GetAllNodes()
// @param task_def_list output: task defs appended by each ops-kernel store
// @param op_name_map   output: maps task index -> op name for later task distribution
// @return SUCCESS on success, otherwise an error code
Status TaskGenerator::GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph,
                                               vector<domi::TaskDef> &task_def_list,
                                               map<uint32_t, string> &op_name_map) {
  std::shared_ptr<GELib> ge_lib = GELib::GetInstance();
  if ((ge_lib == nullptr) || !ge_lib->InitFlag()) {
    GELOGE(GE_CLI_GE_NOT_INITIALIZED, "GenerateTask failed.");
    return GE_CLI_GE_NOT_INITIALIZED;
  }
  GE_CHK_STATUS_RET(MarkNodeAndSetIndex(graph), "MarkNodeAndSetIndex failed.");
  ProfilingPoint profiling_point;
  vector<uint32_t> all_reduce_nodes;
  GE_CHK_STATUS_RET(FindProfilingTaskIndex(graph, profiling_point, all_reduce_nodes));
  const OpsKernelManager &ops_kernel_manager = ge_lib->OpsKernelManagerObj();
  GE_TIMESTAMP_CALLNUM_START(GenerateTask);
  // map store fusion nodes, filled only when the BUFFER_OPTIMIZE option is enabled
  map<int64_t, std::vector<NodePtr>> fusion_nodes;
  string buffer_optimize = "off_optimize";
  (void)ge::GetContext().GetOption(BUFFER_OPTIMIZE, buffer_optimize);
  if (buffer_optimize != "off_optimize") {
    GE_CHK_STATUS_RET(SaveFusionNodes(fusion_nodes, graph));
  }
  std::unordered_set<Node *> fusion_nodes_seen;
  int64_t group_key;
  uint32_t node_index = 0;
  // All tasks of the unknown-shape graph are generated on one freshly created
  // stream, bound to the model for the duration of the loop below.
  rtStream_t stream = nullptr;
  GE_CHK_RT_RET(rtStreamCreate(&stream, 0));
  run_context.stream = stream;
  GE_CHK_RT_RET(rtModelBindStream(run_context.model, stream, 0));
  for (auto &node : graph->GetAllNodes()) {
    OpDescPtr op_desc = node->GetOpDesc();
    GE_CHECK_NOTNULL(op_desc);
    node_index++;
    string name = node->GetName();
    string type = node->GetType();
    // Nodes explicitly flagged with ATTR_NAME_NOTASK are skipped entirely.
    bool attr_notask = false;
    bool get_attr_notask_flag = ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask);
    GE_IF_BOOL_EXEC(get_attr_notask_flag && attr_notask,
                    GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
                    continue);
    GE_CHK_STATUS_RET(UpdateOpIsVarAttr(op_desc, graph->GetSessionID()));
    string op_kernel_lib_name = op_desc->GetOpKernelLibName();
    // For fusion ddb pass, task def must be continuous.
    // Part2: Call
    auto fusion_task_info =
        FusionTaskInfo{run_context, graph, node, op_desc, node_index, ge_lib,
                       ops_kernel_manager, task_def_list, op_name_map, profiling_point, all_reduce_nodes};
    GE_CHK_STATUS_RET(GenerateTaskForFusionNode(fusion_task_info, fusion_nodes, fusion_nodes_seen),
                      "Call GenerateTaskForFusionNode node:%s(%s) failed", name.c_str(), type.c_str());
    // Fusion-group members were already handled above; continue directly.
    if (ge::AttrUtils::GetInt(op_desc, ATTR_NAME_FUSION_GROUP_KEY, group_key)) {
      GELOGI("Fusion node[name:%s, type:%s] do not need generate task again.", name.c_str(), type.c_str());
      continue;
    }
    if (op_kernel_lib_name.empty()) {
      GELOGI("Node[name:%s, type:%s] does not need to generate task.", name.c_str(), type.c_str());
      continue;
    }
    OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name);
    if (kernel_info_store == nullptr) {
      GELOGE(INTERNAL_ERROR, "No ops kernel store found. node:%s(%s), op_kernel_lib_name=%s.", name.c_str(),
             type.c_str(), op_kernel_lib_name.c_str());
      return INTERNAL_ERROR;
    }
    GE_CHK_STATUS_RET(UpdateAnchorStatus(node), "Call UpdateAnchorStatus node:%s(%s) failed", name.c_str(),
                      type.c_str());
    int64_t op_id = op_desc->GetId();
    int64_t stream_id = op_desc->GetStreamId();
    // Profiling task inserted before the node's own tasks.
    size_t task_list_size_before = task_def_list.size();
    GE_CHK_STATUS_RET(InsertProfilingTaskBefore(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
    GELOGI("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task.", op_kernel_lib_name.c_str(),
           name.c_str(), type.c_str(), op_id, stream_id);
    GE_TIMESTAMP_RESTART(GenerateTask);
    auto ret = kernel_info_store->GenerateTask(*node, run_context, task_def_list);
    GE_TIMESTAMP_ADD(GenerateTask);
    if (ret != SUCCESS) {
      GELOGE(ret, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task failed.",
             op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id);
      return ret;
    }
    // Profiling task inserted after the node's own tasks.
    GE_CHK_STATUS_RET(InsertProfilingTaskAfter(op_desc, profiling_point, all_reduce_nodes, node_index, task_def_list));
    size_t task_list_size_after = task_def_list.size();
    // If the task count shrank, the kernel store misbehaved — treat as failure.
    if (task_list_size_after < task_list_size_before) {
      GELOGE(FAILED, "Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task. but task num from %zu to %zu.",
             op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id, task_list_size_before,
             task_list_size_after);
      return FAILED;
    }
    // Reset stream id to ge stream id, as graph load must use ge stream to reassign stream
    void *ops_kernel_info_store_ptr = kernel_info_store.get();
    for (size_t idx = task_list_size_before; idx < task_list_size_after; ++idx) {
      op_name_map[idx] = name;
      // Set opsKernelInfoStorePtr and op_index; the two fields are used in DistributeTask and InitTaskInfo
      TaskDef *task_def_ptr = &task_def_list[idx];
      GE_CHECK_NOTNULL(task_def_ptr);
      task_def_ptr->set_ops_kernel_store_ptr(reinterpret_cast<uintptr_t>(ops_kernel_info_store_ptr));
    }
    GELOGD("Call %s to generate node[name:%s(%s), id:%ld, stream_id:%ld] task finished, generate %zu task(s).",
           op_kernel_lib_name.c_str(), name.c_str(), type.c_str(), op_id, stream_id,
           task_list_size_after - task_list_size_before);
  }
  GE_CHK_RT(rtModelUnbindStream(run_context.model, stream));
  GE_CHK_RT(rtStreamDestroy(stream));
  GE_TIMESTAMP_CALLNUM_END(GenerateTask, "GraphBuild::GenerateTask");
  return SUCCESS;
}
@@ -381,6 +509,11 @@ Status TaskGenerator::GenerateTaskForFusionNode(FusionTaskInfo &fusion_task_info | |||||
fusion_node_type.c_str()); | fusion_node_type.c_str()); | ||||
continue; | continue; | ||||
} | } | ||||
bool attr_notask = false; | |||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask) && attr_notask, | |||||
GELOGI("Fusion: fusion_node[name:%s, type:%s] does not need to generate task.", | |||||
fusion_node_name.c_str(), fusion_node_type.c_str()); | |||||
continue); | |||||
size_t task_list_size_before = task_def_list.size(); | size_t task_list_size_before = task_def_list.size(); | ||||
OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); | OpsKernelInfoStorePtr kernel_info_store = ops_kernel_manager.GetOpsKernelInfoStore(op_kernel_lib_name); | ||||
@@ -528,6 +661,10 @@ Status TaskGenerator::MarkFirstAndLastOps(const vector<OpDescPtr> &ops, bool is_ | |||||
vector<vector<OpDescPtr>> continuous_op_lists(1); | vector<vector<OpDescPtr>> continuous_op_lists(1); | ||||
const set<string> label_op_types({LABELSET, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX}); | const set<string> label_op_types({LABELSET, LABELGOTO, LABELGOTOEX, LABELSWITCH, LABELSWITCHBYINDEX}); | ||||
for (auto &op_desc : ops) { | for (auto &op_desc : ops) { | ||||
bool attr_notask = false; | |||||
if (ge::AttrUtils::GetBool(op_desc, ATTR_NAME_NOTASK, attr_notask) && attr_notask) { | |||||
continue; | |||||
} | |||||
string op_type = op_desc->GetType(); | string op_type = op_desc->GetType(); | ||||
if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || label_op_types.count(op_type) != 0)) { | if (!is_single_stream && (!op_desc->GetSubgraphInstanceNames().empty() || label_op_types.count(op_type) != 0)) { | ||||
continuous_op_lists.emplace_back(vector<OpDescPtr>()); | continuous_op_lists.emplace_back(vector<OpDescPtr>()); | ||||
@@ -629,7 +766,7 @@ Status TaskGenerator::AutoFindBpOpIndex(const ComputeGraphPtr &graph, ProfilingP | |||||
continue; | continue; | ||||
} | } | ||||
if (op_desc->GetType() == HCOMALLREDUCE) { | |||||
if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE) { | |||||
bp_node = node; | bp_node = node; | ||||
all_reduce_nodes.emplace_back(current_idx); | all_reduce_nodes.emplace_back(current_idx); | ||||
GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); | GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); | ||||
@@ -721,7 +858,7 @@ Status TaskGenerator::FindBpOfEnv(const ComputeGraphPtr &graph, const std::strin | |||||
iter_end = current_idx; | iter_end = current_idx; | ||||
GELOGI("Iter end name %s, idx %u", op_desc->GetName().c_str(), iter_end); | GELOGI("Iter end name %s, idx %u", op_desc->GetName().c_str(), iter_end); | ||||
} | } | ||||
if (op_desc->GetType() == HCOMALLREDUCE) { | |||||
if (op_desc->GetType() == HCOMALLREDUCE || op_desc->GetType() == HVDCALLBACKALLREDUCE) { | |||||
all_reduce_nodes.emplace_back(current_idx); | all_reduce_nodes.emplace_back(current_idx); | ||||
GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); | GELOGI("Allreduce name %s, idx %u", op_desc->GetName().c_str(), current_idx); | ||||
} | } | ||||
@@ -82,7 +82,7 @@ class TaskGenerator { | |||||
Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); | Status UpdateOpIsVarAttr(const OpDescPtr &op_desc, uint64_t session_id); | ||||
/// | /// | ||||
/// call engine to generate task. | |||||
/// call engine to generate known shape task. | |||||
/// @param run_context run context | /// @param run_context run context | ||||
/// @param graph compute graph | /// @param graph compute graph | ||||
/// @param task_def_list task def list generate by engine | /// @param task_def_list task def list generate by engine | ||||
@@ -93,6 +93,18 @@ class TaskGenerator { | |||||
Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector<domi::TaskDef> &task_def_list, | Status GenerateTask(RunContext &run_context, ComputeGraphPtr &graph, std::vector<domi::TaskDef> &task_def_list, | ||||
std::map<uint32_t, string> &op_name_map); | std::map<uint32_t, string> &op_name_map); | ||||
/// | |||||
/// call engine to generate unknown shape task. | |||||
/// @param run_context run context | |||||
/// @param graph compute graph | |||||
/// @param task_def_list task def list generate by engine | |||||
/// @param op_name_map relation of task index and op | |||||
/// @return SUCCESS: success | |||||
/// Other: failed | |||||
/// | |||||
Status GenerateUnknownShapeTask(RunContext &run_context, ComputeGraphPtr &graph, | |||||
std::vector<domi::TaskDef> &task_def_list, std::map<uint32_t, string> &op_name_map); | |||||
/// | /// | ||||
/// AddModelTaskToModel | /// AddModelTaskToModel | ||||
/// @param model_task_def model task | /// @param model_task_def model task | ||||
@@ -258,7 +258,7 @@ Status GraphExecutor::SyncExecuteModel(uint32_t model_id, const std::vector<GeTe | |||||
// Run graph return | // Run graph return | ||||
uint32_t result_code = graph_run_listener_->GetResultCode(); | uint32_t result_code = graph_run_listener_->GetResultCode(); | ||||
if (result_code != SUCCESS) { | |||||
if (result_code != SUCCESS && result_code != END_OF_SEQUENCE) { | |||||
GELOGE(GE_GRAPH_EXECUTE_FAILED, "[GraphExecutor] execute model failed, ret=%u, modelId=%u.", result_code, | GELOGE(GE_GRAPH_EXECUTE_FAILED, "[GraphExecutor] execute model failed, ret=%u, modelId=%u.", result_code, | ||||
model_id); | model_id); | ||||
return GE_GRAPH_EXECUTE_FAILED; | return GE_GRAPH_EXECUTE_FAILED; | ||||
@@ -319,7 +319,7 @@ Status GraphExecutor::FreeExecuteMemory() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, | |||||
Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||||
const std::vector<GeTensor> &input_tensor, std::vector<GeTensor> &output_tensor) { | const std::vector<GeTensor> &input_tensor, std::vector<GeTensor> &output_tensor) { | ||||
if (graph_id != last_graph_id_) { | if (graph_id != last_graph_id_) { | ||||
auto ret = FreeExecuteMemory(); | auto ret = FreeExecuteMemory(); | ||||
@@ -333,8 +333,8 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, | |||||
GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!"); | GELOGE(GE_GRAPH_EXECUTE_NOT_INIT, "[GraphExecutor] AI Core Engine without calling SetCondition!"); | ||||
return GE_GRAPH_EXECUTE_NOT_INIT; | return GE_GRAPH_EXECUTE_NOT_INIT; | ||||
} | } | ||||
GE_CHECK_NOTNULL_EXEC(ge_model, return FAILED); | |||||
Status ret = SyncExecuteModel(ge_model->GetModelId(), input_tensor, output_tensor); | |||||
GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | |||||
Status ret = SyncExecuteModel(ge_root_model->GetModelId(), input_tensor, output_tensor); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] SyncExecuteModel Error!"); | GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] SyncExecuteModel Error!"); | ||||
return GE_GRAPH_SYNC_MODEL_FAILED; | return GE_GRAPH_SYNC_MODEL_FAILED; | ||||
@@ -343,7 +343,7 @@ Status GraphExecutor::ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_model, | |||||
Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||||
const std::vector<InputTensorInfo> &input_tensor) { | const std::vector<InputTensorInfo> &input_tensor) { | ||||
GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); | GELOGI("[GraphExecutor] Start to async execute graph, graph_id=%u", graph_id); | ||||
if (graph_id != last_graph_id_) { | if (graph_id != last_graph_id_) { | ||||
@@ -353,8 +353,8 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_m | |||||
} | } | ||||
} | } | ||||
last_graph_id_ = graph_id; | last_graph_id_ = graph_id; | ||||
GE_CHECK_NOTNULL_EXEC(ge_model, return FAILED); | |||||
Status ret = AsyncExecuteModel(ge_model->GetModelId(), input_tensor); | |||||
GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | |||||
Status ret = AsyncExecuteModel(ge_root_model->GetModelId(), input_tensor); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); | GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[GraphExecutor] AsyncExecuteModel Error!"); | ||||
return GE_GRAPH_SYNC_MODEL_FAILED; | return GE_GRAPH_SYNC_MODEL_FAILED; | ||||
@@ -46,11 +46,11 @@ class GraphExecutor { | |||||
virtual ~GraphExecutor(); | virtual ~GraphExecutor(); | ||||
Status ExecuteGraph(GraphId graph_id, const GeModelPtr &ge_model, const std::vector<GeTensor> &input_tensor, | |||||
Status ExecuteGraph(GraphId graph_id, const GeRootModelPtr &ge_root_model, const std::vector<GeTensor> &input_tensor, | |||||
std::vector<GeTensor> &output_tensor); | std::vector<GeTensor> &output_tensor); | ||||
Status ExecuteGraphAsync(GraphId graph_id, const GeModelPtr &ge_model, | |||||
const std::vector<InputTensorInfo> &input_tensor); | |||||
ge::Status ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr &ge_root_model, | |||||
const std::vector<InputTensorInfo> &input_tensor); | |||||
Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener); | Status SetCondition(std::mutex *mutex, std::condition_variable *cond, std::shared_ptr<GraphModelListener> listener); | ||||
@@ -94,42 +94,6 @@ void LabelMaker::SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr | |||||
op_desc->SetStreamId(stream_id); | op_desc->SetStreamId(stream_id); | ||||
} | } | ||||
/** | |||||
* @ingroup ge | |||||
* @brief Link Node to Graph head. | |||||
* @param [in] graph: graph for add node. | |||||
* @param [in] lb_node: Node for set link to head. | |||||
* @return: SUCCESS / FAILED | |||||
*/ | |||||
Status LabelMaker::AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node) { | |||||
GE_CHECK_NOTNULL(graph); | |||||
GE_CHECK_NOTNULL(node); | |||||
std::set<NodePtr> linked_nodes; | |||||
for (const NodePtr &n : graph->GetDirectNode()) { | |||||
GE_CHECK_NOTNULL(n); | |||||
if (n->GetType() != DATA) { | |||||
continue; | |||||
} | |||||
// Link control edge to graph head. | |||||
for (const NodePtr &out_node : n->GetOutAllNodes()) { | |||||
if (linked_nodes.count(out_node) > 0) { | |||||
continue; | |||||
} | |||||
(void)linked_nodes.insert(out_node); | |||||
if (GraphUtils::AddEdge(node->GetOutControlAnchor(), out_node->GetInControlAnchor()) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Add ctrl edge from %s to %s failed.", node->GetName().c_str(), | |||||
out_node->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
/** | /** | ||||
* @ingroup ge | * @ingroup ge | ||||
* @brief Add StreamActive node at graph front. | * @brief Add StreamActive node at graph front. | ||||
@@ -154,15 +118,10 @@ NodePtr LabelMaker::AddStreamActive(const ComputeGraphPtr &graph, const std::str | |||||
vector<uint32_t> active_streams; | vector<uint32_t> active_streams; | ||||
(void)AttrUtils::SetStr(op_desc, ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, op_desc->GetName()); | (void)AttrUtils::SetStr(op_desc, ATTR_NAME_SWITCH_BRANCH_NODE_LABEL, op_desc->GetName()); | ||||
(void)AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); | (void)AttrUtils::SetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams); | ||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_FIRST_ACTIVE, true); | |||||
NodePtr stream_active = graph->AddNodeFront(op_desc); | NodePtr stream_active = graph->AddNodeFront(op_desc); | ||||
GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); | GE_CHECK_NOTNULL_EXEC(stream_active, return nullptr); | ||||
// Link control edge to graph head. | |||||
if (AddCtrlLink2Data(graph, stream_active) != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "Add ctrl edge for graph %s failed.", graph->GetName().c_str()); | |||||
return nullptr; | |||||
} | |||||
return stream_active; | return stream_active; | ||||
} | } | ||||
@@ -230,6 +189,7 @@ NodePtr LabelMaker::AddLabelSetLeave(const ComputeGraphPtr &graph, const std::st | |||||
GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | GELOGI("LabelSet: Create node %s.", op_desc->GetName().c_str()); | ||||
(void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | (void)AttrUtils::SetInt(op_desc, ATTR_NAME_LABEL_SWITCH_INDEX, index); | ||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, true); | |||||
NodePtr label_set = graph->AddNode(op_desc); | NodePtr label_set = graph->AddNode(op_desc); | ||||
GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | GE_CHECK_NOTNULL_EXEC(label_set, return nullptr); | ||||
@@ -60,7 +60,6 @@ class LabelMaker { | |||||
ComputeGraphPtr parent_graph_; | ComputeGraphPtr parent_graph_; | ||||
private: | private: | ||||
Status AddCtrlLink2Data(const ComputeGraphPtr &graph, const NodePtr &node); | |||||
void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | void SetStreamIdEnter(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | ||||
void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | void SetStreamIdLeave(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | ||||
void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | void SetStreamIdOwner(const ComputeGraphPtr &graph, const OpDescPtr &op_desc); | ||||
@@ -50,6 +50,21 @@ Status PartitionedCallLabelMaker::Run(uint32_t &label_index) { | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
const std::string stream_active_name = parent_node_->GetName() + "/StreamActive"; // rtStreamActive | |||||
NodePtr stream_active = AddStreamActive(sub_graph, stream_active_name); | |||||
if (stream_active == nullptr) { | |||||
GELOGE(INTERNAL_ERROR, "Subgraph: %s add stream active node failed.", sub_graph->GetName().c_str()); | |||||
return FAILED; | |||||
} | |||||
for (auto &node : sub_graph->GetDirectNode()) { | |||||
if (node->GetType() == NETOUTPUT) { | |||||
auto op_desc = node->GetOpDesc(); | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
(void)AttrUtils::SetBool(op_desc, ATTR_NAME_SUBGRAPH_END_NODE, true); | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -29,6 +29,8 @@ | |||||
+---------------+ | +---------------+ | ||||
+---------------+ | +---------------+ | ||||
| Node | +---------------+ | | Node | +---------------+ | ||||
+---------------+ | StreamActive | | |||||
| Node | +---------------+ | |||||
+---------------+ | f | | +---------------+ | f | | ||||
| Node | +---------------+ | | Node | +---------------+ | ||||
+---------------+ | u | | +---------------+ | u | | ||||
@@ -53,7 +53,7 @@ Status GraphLoader::UnloadModel(uint32_t model_id) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeModel> &ge_model_ptr, | |||||
Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr, | |||||
const std::shared_ptr<ModelListener> &listener) { | const std::shared_ptr<ModelListener> &listener) { | ||||
GELOGI("Load model online begin."); | GELOGI("Load model online begin."); | ||||
rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | ||||
@@ -62,15 +62,15 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||||
CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_LOAD); | CsaInteract::GetInstance().WriteErrorCode(rt_ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_LOAD); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
if (ge_model_ptr == nullptr) { | |||||
if (ge_root_model_ptr == nullptr) { | |||||
GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | ||||
return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
} | } | ||||
model_id = ge_model_ptr->GetModelId(); | |||||
model_id = ge_root_model_ptr->GetModelId(); | |||||
auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
Status ret = model_manager->LoadModelOnline(model_id, ge_model_ptr, listener); | |||||
Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); | GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); | ||||
CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); | CsaInteract::GetInstance().WriteErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_LOAD); | ||||
@@ -71,7 +71,7 @@ class GraphLoader { | |||||
static Status DestroyAicpuSessionForInfer(uint32_t model_id); | static Status DestroyAicpuSessionForInfer(uint32_t model_id); | ||||
static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeModel> &model, | |||||
static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | |||||
const std::shared_ptr<ModelListener> &listener); | const std::shared_ptr<ModelListener> &listener); | ||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
@@ -15,9 +15,12 @@ | |||||
*/ | */ | ||||
#include "graph/load/new_model_manager/data_dumper.h" | #include "graph/load/new_model_manager/data_dumper.h" | ||||
#include <ctime> | |||||
#include <map> | #include <map> | ||||
#include <utility> | #include <utility> | ||||
#include <vector> | #include <vector> | ||||
#include "common/properties_manager.h" | #include "common/properties_manager.h" | ||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "framework/common/util.h" | #include "framework/common/util.h" | ||||
@@ -32,6 +35,7 @@ | |||||
namespace { | namespace { | ||||
const uint32_t kAicpuLoadFlag = 1; | const uint32_t kAicpuLoadFlag = 1; | ||||
const uint32_t kAicpuUnloadFlag = 0; | const uint32_t kAicpuUnloadFlag = 0; | ||||
const uint32_t kTimeBufferLen = 80; | |||||
const char *const kDumpOutput = "output"; | const char *const kDumpOutput = "output"; | ||||
const char *const kDumpInput = "input"; | const char *const kDumpInput = "input"; | ||||
const char *const kDumpAll = "all"; | const char *const kDumpAll = "all"; | ||||
@@ -156,10 +160,8 @@ void DataDumper::SaveDumpTask(uint32_t task_id, uint32_t stream_id, const std::s | |||||
return; | return; | ||||
} | } | ||||
uintptr_t data_addr = args - sizeof(void *) * op_desc->GetInputOffset().size() + | |||||
sizeof(void *) * static_cast<uint32_t>(inner_input_mapping.input_anchor_index); | |||||
GELOGI("Save input dump task %s, id: %u.", data_op->GetName().c_str(), task_id); | GELOGI("Save input dump task %s, id: %u.", data_op->GetName().c_str(), task_id); | ||||
op_list_.push_back({task_id, stream_id, data_op, data_addr, false, inner_input_mapping.input_anchor_index, | |||||
op_list_.push_back({task_id, stream_id, data_op, args, false, inner_input_mapping.input_anchor_index, | |||||
inner_input_mapping.output_anchor_index, input_tensor->GetShape().GetDims()}); | inner_input_mapping.output_anchor_index, input_tensor->GetShape().GetDims()}); | ||||
} | } | ||||
} | } | ||||
@@ -188,11 +190,24 @@ static void SetOpMappingLoopAddr(uintptr_t step_id, uintptr_t loop_per_iter, uin | |||||
} | } | ||||
} | } | ||||
// Build a timestamp string of the current local time in %Y%m%d%H%M%S form,
// e.g. "20171122042550". Returns an empty string if the local time cannot be
// resolved or formatting fails.
// NOTE(review): std::localtime returns a pointer to shared static storage and
// is not thread-safe — confirm callers are single-threaded here, or switch to
// localtime_r where available.
static std::string GetCurrentTime() {
  std::time_t now = std::time(nullptr);
  std::tm *ptm = std::localtime(&now);
  if (ptm == nullptr) {
    return "";
  }
  char buffer[kTimeBufferLen] = {0};
  // format: 20171122042550
  size_t len = std::strftime(buffer, kTimeBufferLen, "%Y%m%d%H%M%S", ptm);
  if (len == 0) {
    // strftime returns 0 on failure and leaves the buffer unspecified;
    // do not return a partial/garbage timestamp.
    return "";
  }
  return std::string(buffer, len);
}
Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { | Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump::Task &task) { | ||||
GELOGI("Start dump output"); | GELOGI("Start dump output"); | ||||
if (inner_dump_info.is_task) { | if (inner_dump_info.is_task) { | ||||
// tbe or aicpu op | // tbe or aicpu op | ||||
const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); | const auto &output_descs = inner_dump_info.op->GetAllOutputsDesc(); | ||||
const auto input_size = inner_dump_info.op->GetAllInputsDesc().size(); | |||||
const std::vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); | const std::vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(runtime_param_, inner_dump_info.op, false); | ||||
if (output_descs.size() != output_addrs.size()) { | if (output_descs.size() != output_addrs.size()) { | ||||
GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), | GELOGE(PARAM_INVALID, "Invalid output desc addrs size %zu, op %s has %zu output desc.", output_addrs.size(), | ||||
@@ -217,8 +232,7 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: | |||||
output.set_original_output_index(origin_output_index); | output.set_original_output_index(origin_output_index); | ||||
output.set_original_output_format(static_cast<int32_t>(output_descs.at(i).GetOriginFormat())); | output.set_original_output_format(static_cast<int32_t>(output_descs.at(i).GetOriginFormat())); | ||||
output.set_original_output_data_type(static_cast<int32_t>(output_descs.at(i).GetOriginDataType())); | output.set_original_output_data_type(static_cast<int32_t>(output_descs.at(i).GetOriginDataType())); | ||||
// due to lhisi virtual addr bug, cannot use args now | |||||
output.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(output_addrs[i]))); | |||||
output.set_address(static_cast<uint64_t>(inner_dump_info.args + (i + input_size) * sizeof(void *))); | |||||
task.mutable_output()->Add(std::move(output)); | task.mutable_output()->Add(std::move(output)); | ||||
} | } | ||||
@@ -255,8 +269,8 @@ Status DataDumper::DumpOutput(const InnerDumpInfo &inner_dump_info, aicpu::dump: | |||||
GELOGE(FAILED, "Index is out of range."); | GELOGE(FAILED, "Index is out of range."); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
output.set_address( | |||||
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(output_addrs[inner_dump_info.output_anchor_index]))); | |||||
auto data_addr = inner_dump_info.args + sizeof(void *) * static_cast<uint32_t>(inner_dump_info.input_anchor_index); | |||||
output.set_address(static_cast<uint64_t>(data_addr)); | |||||
task.mutable_output()->Add(std::move(output)); | task.mutable_output()->Add(std::move(output)); | ||||
@@ -282,7 +296,7 @@ Status DataDumper::DumpInput(const InnerDumpInfo &inner_dump_info, aicpu::dump:: | |||||
input.mutable_shape()->add_dim(dim); | input.mutable_shape()->add_dim(dim); | ||||
} | } | ||||
input.set_address(static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_addrs[i]))); | |||||
input.set_address(static_cast<uint64_t>(inner_dump_info.args + sizeof(void *) * i)); | |||||
task.mutable_input()->Add(std::move(input)); | task.mutable_input()->Add(std::move(input)); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -370,7 +384,10 @@ Status DataDumper::LoadDumpInfo() { | |||||
} | } | ||||
aicpu::dump::OpMappingInfo op_mapping_info; | aicpu::dump::OpMappingInfo op_mapping_info; | ||||
op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + std::to_string(device_id_) + "/"); | |||||
std::string time_now = GetCurrentTime(); | |||||
GELOGI("Time is %s now", time_now.c_str()); | |||||
op_mapping_info.set_dump_path(PropertiesManager::Instance().GetDumpOutputPath() + time_now + "/" + | |||||
std::to_string(device_id_) + "/"); | |||||
op_mapping_info.set_model_name(model_name_); | op_mapping_info.set_model_name(model_name_); | ||||
op_mapping_info.set_model_id(model_id_); | op_mapping_info.set_model_id(model_id_); | ||||
op_mapping_info.set_flag(kAicpuLoadFlag); | op_mapping_info.set_flag(kAicpuLoadFlag); | ||||
@@ -45,6 +45,7 @@ | |||||
#include "graph/load/output/output.h" | #include "graph/load/output/output.h" | ||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/manager/trans_var_data_utils.h" | |||||
#include "graph/manager/util/debug.h" | #include "graph/manager/util/debug.h" | ||||
#include "graph/model_serialize.h" | #include "graph/model_serialize.h" | ||||
#include "graph/node.h" | #include "graph/node.h" | ||||
@@ -75,6 +76,7 @@ | |||||
namespace ge { | namespace ge { | ||||
namespace { | namespace { | ||||
const uint32_t kDataIndex = 0; | const uint32_t kDataIndex = 0; | ||||
const uint32_t kOutputNum = 1; | |||||
const uint32_t kTrueBranchStreamNum = 1; | const uint32_t kTrueBranchStreamNum = 1; | ||||
const uint32_t kThreadNum = 16; | const uint32_t kThreadNum = 16; | ||||
const uint32_t kAddrLen = sizeof(void *); | const uint32_t kAddrLen = sizeof(void *); | ||||
@@ -83,275 +85,6 @@ const int kBytes = 8; | |||||
const uint32_t kDataMemAlignSizeCompare = 64; | const uint32_t kDataMemAlignSizeCompare = 64; | ||||
const char *const kDefaultBatchLable = "Batch_default"; | const char *const kDefaultBatchLable = "Batch_default"; | ||||
class RtContextSwitchGuard { | |||||
public: | |||||
RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) { | |||||
auto ret = rtCtxGetCurrent(&last_); | |||||
if (ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret); | |||||
return; | |||||
} | |||||
ret = rtCtxCreate(¤t_, mode, static_cast<int32_t>(device_id)); | |||||
if (ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret); | |||||
return; | |||||
} | |||||
ret = rtCtxSetCurrent(current_); | |||||
if (ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id); | |||||
return; | |||||
} | |||||
GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_); | |||||
} | |||||
~RtContextSwitchGuard() { | |||||
if (current_ != nullptr) { | |||||
auto ret = rtCtxDestroy(current_); | |||||
GELOGD("Destory current context %p result %d", current_, ret); | |||||
} | |||||
if (last_ != nullptr) { | |||||
auto ret = rtCtxSetCurrent(last_); | |||||
GELOGD("Recovery last context %p result %d.", last_, ret); | |||||
} | |||||
} | |||||
private: | |||||
rtContext_t last_; | |||||
rtContext_t current_; | |||||
}; | |||||
int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { | |||||
int64_t var_size = GetSizeByDataType(desc.GetDataType()); | |||||
if (var_size <= 0) { | |||||
GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s", | |||||
TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str()); | |||||
return -1; | |||||
} | |||||
auto shape = desc.GetShape(); | |||||
auto dim_num = shape.GetDimNum(); | |||||
for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { | |||||
var_size *= shape.GetDim(dim_index); | |||||
} | |||||
return var_size; | |||||
} | |||||
// Reads a variable's current value from device memory into a freshly
// allocated host buffer, returned through var_data. The buffer is sized from
// input_desc via CalcVarSizeInBytes.
Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
                         const GeTensorDesc &input_desc) {
  GE_CHECK_NOTNULL(var);

  // Step 1: ask the var manager for the variable's logic address.
  uint8_t *logic_addr = nullptr;
  auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &logic_addr);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR,
           "Failed to copy var %s from device, can not find it"
           " from var manager %u",
           var->GetName().c_str(), ret);
    return INTERNAL_ERROR;
  }

  // Step 2: translate the logic address to the real device address.
  uint8_t *device_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(logic_addr, RT_MEMORY_HBM);
  if (device_addr == nullptr) {
    GELOGE(INTERNAL_ERROR,
           "Failed to copy var %s from device, cant not get "
           "var addr from logic addr %p",
           var->GetName().c_str(), logic_addr);
    return INTERNAL_ERROR;
  }

  // Step 3: size and allocate the host-side buffer.
  const int64_t size_in_bytes = CalcVarSizeInBytes(input_desc);
  if (size_in_bytes <= 0) {
    return INTERNAL_ERROR;
  }
  std::unique_ptr<uint8_t[]> host_buf(new (std::nothrow) uint8_t[size_in_bytes]);
  if (host_buf == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", size_in_bytes);
    return OUT_OF_MEMORY;
  }

  // Step 4: copy device -> host.
  ret = rtMemcpy(reinterpret_cast<void *>(host_buf.get()), size_in_bytes, reinterpret_cast<void *>(device_addr),
                 size_in_bytes, RT_MEMCPY_DEVICE_TO_HOST);
  if (ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED,
           "Failed to copy var memory from device, var %s, size %ld,"
           " rt-error-code %u",
           var->GetName().c_str(), size_in_bytes, ret);
    return RT_FAILED;
  }

  GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), size_in_bytes);
  var_data.swap(host_buf);
  GELOGI("var_logic:%p, var_addr:%p", logic_addr, device_addr);
  return SUCCESS;
}
// Writes host-side transformed variable data back to the given device
// address. trans_result supplies both the bytes and their length.
Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
  GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
  const auto rt_ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
                               trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
  if (rt_ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
    return RT_FAILED;
  }
  return SUCCESS;
}
// Runs the whole trans road (format and/or data-type conversions) on the
// host. var_data holds the raw bytes read back from device; the final
// converted bytes are returned through `result`. RESHAPE/REFORMAT steps do
// not move data and are skipped.
// Fix: the per-step result is now moved instead of copied
// (`result_last_time = tmp_result` copied the TransResult every iteration;
// the source is dead after the assignment, so moving is behavior-identical).
Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
  formats::TransResult result_last_time{};
  bool use_init_data = true;
  for (const auto &trans_info : trans_road) {
    if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
      GELOGD("Skip to trans variable data on the reshape/reformat node");
      continue;
    }
    // The first real step consumes the original device snapshot; subsequent
    // steps chain on the previous step's output.
    uint8_t *src_data = nullptr;
    if (use_init_data) {
      src_data = var_data;
      use_init_data = false;
    } else {
      src_data = result_last_time.data.get();
    }

    formats::TransResult tmp_result{};
    if (trans_info.node_type == TRANSDATA) {
      auto src_format = trans_info.input.GetFormat();
      auto src_shape = trans_info.input.GetShape().GetDims();
      auto dst_format = trans_info.output.GetFormat();
      auto dst_shape = trans_info.output.GetShape().GetDims();
      auto data_type = trans_info.input.GetDataType();
      GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
             TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
             formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str());
      auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
      if (ret != SUCCESS) {
        GELOGE(INTERNAL_ERROR,
               "Failed to trans format from %s to %s, shape %s to %s, "
               "data type %s error code %u",
               TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
               formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
               TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
        return ret;
      }
    } else if (trans_info.node_type == CAST) {
      auto input_shape = trans_info.input.GetShape();
      // A scalar (shape size 0) still carries one element.
      auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
      auto src_data_type = trans_info.input.GetDataType();
      auto dst_data_type = trans_info.output.GetDataType();
      GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
             TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
             TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
             src_data_size);
      auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
                                        tmp_result);
      if (ret != SUCCESS) {
        GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
               TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
               TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
               src_data_size, ret);
        return ret;
      }
    } else {
      GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported",
             trans_info.node_type.c_str());
      return UNSUPPORTED;
    }
    result_last_time = std::move(tmp_result);
  }

  result = std::move(result_last_time);
  return SUCCESS;
}
/// Resolves the device address a variable should occupy after transformation,
/// using the var manager's re-assigned allocation for `tensor_desc`.
/// Freeing the original var memory is not supported by the var manager yet.
/// @param session_id  session owning the variable
/// @param var_name    name of the variable to resolve
/// @param tensor_desc descriptor the var manager keys the allocation on
/// @param var_device  out parameter receiving the real device address
/// @return SUCCESS on success, INTERNAL_ERROR otherwise
Status ReAssignVarAddr(uint64_t session_id, const std::string &var_name, const GeTensorDesc &tensor_desc,
                       void **var_device) {
  // Logic address first ...
  uint8_t *logic_addr = nullptr;
  const Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &logic_addr);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR,
           "Failed to get var %s device addr, can not find it"
           " from var manager %u",
           var_name.c_str(), ret);
    return INTERNAL_ERROR;
  }

  // ... then translate it to a real device address.
  uint8_t *real_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(logic_addr, RT_MEMORY_HBM);
  if (real_addr == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
    return INTERNAL_ERROR;
  }

  *var_device = real_addr;
  GELOGI("var_logic:%p, var_addr:%p", logic_addr, real_addr);
  return SUCCESS;
}
// Applies a variable's trans road end to end: read the variable from device,
// convert it on host, re-resolve the device address for the converted layout,
// and write the converted bytes back.
// An empty road, or one consisting only of RESHAPE/REFORMAT steps, moves no
// data and returns SUCCESS immediately.
// Fix: removed the old `trans_road.size() == 0` error branch — it was
// unreachable, because an empty road leaves need_trans false and the function
// has already returned SUCCESS by then.
Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) {
  GE_CHECK_NOTNULL(var);

  // Nothing to do unless some step actually transforms data.
  bool need_trans = false;
  for (auto &road : trans_road) {
    if (road.node_type != RESHAPE && road.node_type != REFORMAT) {
      need_trans = true;
      break;
    }
  }
  if (!need_trans) {
    return SUCCESS;
  }

  // need_trans implies the road is non-empty, so begin()/rbegin() are valid.
  const GeTensorDesc &input_desc = trans_road.begin()->input;

  // Sync var data from device.
  std::unique_ptr<uint8_t[]> var_data;
  auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc);
  if (ret != SUCCESS) {
    return ret;
  }

  formats::TransResult trans_result{};
  ret = TransVarOnHost(var_data.get(), trans_road, trans_result);
  if (ret != SUCCESS) {
    GELOGE(ret, "Failed to trans var data on host, error code %u", ret);
    return ret;
  }

  void *var_device = nullptr;
  /// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager
  /// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the
  /// size of the converted variable. To complete the final solution, the dependency of the variable manager on
  /// TensorDesc needs to be removed. This change is large and needs to be performed step by step.
  ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device);
  if (ret != SUCCESS) {
    GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length);
    return ret;
  }

  // Sync new data to device.
  ret = CopyVarToDevice(var, trans_result, var_device);
  if (ret != SUCCESS) {
    GELOGE(ret, "Failed to send var data to device");
    return ret;
  }
  return SUCCESS;
}
inline bool IsDataOp(const std::string &node_type) { | inline bool IsDataOp(const std::string &node_type) { | ||||
return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; | return node_type == DATA_TYPE || node_type == AIPP_DATA_TYPE || node_type == ANN_DATA_TYPE; | ||||
} | } | ||||
@@ -474,6 +207,14 @@ DavinciModel::~DavinciModel() { | |||||
CleanTbeHandle(); | CleanTbeHandle(); | ||||
var_mem_base_ = nullptr; | var_mem_base_ = nullptr; | ||||
if (known_node_) { | |||||
if (args_ != nullptr) { | |||||
GE_CHK_RT(rtFree(args_)); | |||||
} | |||||
if (args_host_ != nullptr) { | |||||
GE_CHK_RT(rtFreeHost(args_host_)); | |||||
} | |||||
} | |||||
} catch (...) { | } catch (...) { | ||||
GELOGW("DavinciModel::~DavinciModel: clear op_list catch exception."); | GELOGW("DavinciModel::~DavinciModel: clear op_list catch exception."); | ||||
} | } | ||||
@@ -574,6 +315,14 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p | |||||
GELOGI("copy weights data to device"); | GELOGI("copy weights data to device"); | ||||
} | } | ||||
GE_CHK_STATUS_RET(InitVariableMem(), "init variable mem failed."); | |||||
runtime_param_.mem_base = mem_base_; | |||||
runtime_param_.weight_base = weights_mem_base_; | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::InitVariableMem() { | |||||
// malloc variable memory base | |||||
var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | ||||
if (TotalVarMemSize() && var_mem_base_ == nullptr) { | if (TotalVarMemSize() && var_mem_base_ == nullptr) { | ||||
Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize()); | Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize()); | ||||
@@ -582,12 +331,9 @@ Status DavinciModel::InitModelMem(void *dev_ptr, size_t mem_size, void *weight_p | |||||
return ret; | return ret; | ||||
} | } | ||||
var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | ||||
GELOGI("[IMAS]InitModelMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
GELOGI("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
var_mem_base_, TotalVarMemSize()); | var_mem_base_, TotalVarMemSize()); | ||||
} | } | ||||
runtime_param_.mem_base = mem_base_; | |||||
runtime_param_.weight_base = weights_mem_base_; | |||||
runtime_param_.var_base = var_mem_base_; | runtime_param_.var_base = var_mem_base_; | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -618,11 +364,15 @@ void DavinciModel::InitRuntimeParams() { | |||||
ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_VAR_SIZE, value); | ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_VAR_SIZE, value); | ||||
runtime_param_.var_size = ret ? (uint64_t)value : 0; | runtime_param_.var_size = ret ? (uint64_t)value : 0; | ||||
session_id_ = runtime_param_.session_id; | session_id_ = runtime_param_.session_id; | ||||
GELOGI("InitRuntimeParams(), memory_size:%lu, weight_size:%lu, stream_num:%u, session_id:%u, var_size:%lu.", | |||||
runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.stream_num, runtime_param_.session_id, | |||||
runtime_param_.var_size); | |||||
GELOGI("InitRuntimeParams(), event_num:%u, label_num:%u", runtime_param_.event_num, runtime_param_.label_num); | |||||
GELOGI( | |||||
"InitRuntimeParams(), memory_size:%lu, weight_size:%lu, session_id:%u, var_size:%lu, logic_var_base:%lu, " | |||||
"logic_mem_base:%lu.", | |||||
runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.session_id, runtime_param_.var_size, | |||||
runtime_param_.logic_var_base, runtime_param_.logic_mem_base); | |||||
GELOGI("InitRuntimeParams(), stream_num:%lu, event_num:%u, label_num:%u", runtime_param_.stream_num, | |||||
runtime_param_.event_num, runtime_param_.label_num); | |||||
} | } | ||||
void DavinciModel::CheckHasHcomOp() { | void DavinciModel::CheckHasHcomOp() { | ||||
@@ -639,7 +389,9 @@ void DavinciModel::CheckHasHcomOp() { | |||||
GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGW("Node OpDesc is nullptr"); continue); | GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGW("Node OpDesc is nullptr"); continue); | ||||
GE_IF_BOOL_EXEC(((op_desc->GetType() == HCOMBROADCAST) || (op_desc->GetType() == HCOMALLGATHER) || | GE_IF_BOOL_EXEC(((op_desc->GetType() == HCOMBROADCAST) || (op_desc->GetType() == HCOMALLGATHER) || | ||||
(op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || | (op_desc->GetType() == HCOMALLREDUCE) || (op_desc->GetType() == HCOMSEND) || | ||||
(op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER)), | |||||
(op_desc->GetType() == HCOMRECEIVE) || (op_desc->GetType() == HCOMREDUCESCATTER) || | |||||
(op_desc->GetType() == HVDCALLBACKALLREDUCE) || (op_desc->GetType() == HVDCALLBACKALLGATHER) || | |||||
(op_desc->GetType() == HVDCALLBACKBROADCAST) || (op_desc->GetType() == HVDWAIT)), | |||||
uint32_t stream_id = static_cast<uint32_t>(op_desc->GetStreamId()); | uint32_t stream_id = static_cast<uint32_t>(op_desc->GetStreamId()); | ||||
(void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); | (void)hcom_streams_.emplace(stream_id); GELOGD("hcom stream: %u.", stream_id); continue); | ||||
@@ -692,6 +444,10 @@ Status DavinciModel::DoTaskSink() { | |||||
GELOGI("do task_sink."); | GELOGI("do task_sink."); | ||||
GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); | GE_CHK_STATUS_RET(BindModelStream(), "Bind model stream failed."); | ||||
if (known_node_) { | |||||
GE_CHK_STATUS_RET(MallocKnownArgs(), "Mallloc known node args failed."); | |||||
} | |||||
GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo failed."); | GE_CHK_STATUS_RET(InitTaskInfo(*model_task_def_.get()), "InitTaskInfo failed."); | ||||
GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); | GE_CHK_STATUS_RET(LoadWithQueue(), "LoadWithQueue failed."); | ||||
@@ -787,12 +543,14 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); | GE_CHK_STATUS_RET(CopyVarData(compute_graph_), "copy var data failed."); | ||||
GE_TIMESTAMP_START(InitModelMem); | GE_TIMESTAMP_START(InitModelMem); | ||||
GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); | |||||
GELOGI("known_node is %d", known_node_); | |||||
if (!known_node_) { | |||||
GE_CHK_STATUS_RET_NOLOG(InitModelMem(dev_ptr, mem_size, weight_ptr, weight_size)); | |||||
data_inputer_ = new (std::nothrow) DataInputer(); | |||||
GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); | |||||
} | |||||
GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); | GE_TIMESTAMP_END(InitModelMem, "GraphLoader::InitModelMem"); | ||||
data_inputer_ = new (std::nothrow) DataInputer(); | |||||
GE_CHK_BOOL_RET_STATUS(data_inputer_ != nullptr, INTERNAL_ERROR, "data_inputer_ is nullptr."); | |||||
for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { | for (const ge::NodePtr &node : compute_graph_->GetDirectNode()) { | ||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | ||||
GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); | GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); | ||||
@@ -817,7 +575,6 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
} | } | ||||
SetDataDumperArgs(); | SetDataDumperArgs(); | ||||
GE_TIMESTAMP_START(DoTaskSink); | GE_TIMESTAMP_START(DoTaskSink); | ||||
auto ret = DoTaskSink(); | auto ret = DoTaskSink(); | ||||
GE_TIMESTAMP_END(DoTaskSink, "GraphLoader::DoTaskSink"); | GE_TIMESTAMP_END(DoTaskSink, "GraphLoader::DoTaskSink"); | ||||
@@ -832,6 +589,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
} | } | ||||
ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); | ProfilingManager::Instance().ReportProfilingData(GetTaskDescInfo(), compute_graph_desc_info); | ||||
} | } | ||||
GELOGI("davinci model init success."); | |||||
return ret; | return ret; | ||||
} | } | ||||
@@ -935,6 +693,10 @@ Status DavinciModel::InitNodes(const ComputeGraphPtr &compute_graph) { | |||||
Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { | Status DavinciModel::InitDataOp(const NodePtr &node, uint32_t &data_op_index) { | ||||
// op_desc Checked by Init: Data, valid. | // op_desc Checked by Init: Data, valid. | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
if (known_node_) { | |||||
data_op_list_.push_back(op_desc); | |||||
return SUCCESS; | |||||
} | |||||
uint32_t parent_index = 0; // Ignore subgraph Data Node. | uint32_t parent_index = 0; // Ignore subgraph Data Node. | ||||
if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | if (AttrUtils::GetInt(op_desc, ATTR_NAME_PARENT_NODE_INDEX, parent_index)) { | ||||
GELOGI("Skip subgraph Data node: %s.", op_desc->GetName().c_str()); | GELOGI("Skip subgraph Data node: %s.", op_desc->GetName().c_str()); | ||||
@@ -1015,6 +777,10 @@ Status DavinciModel::InitInputZeroCopyNodes(const NodePtr &node) { | |||||
Status DavinciModel::InitNetOutput(const NodePtr &node) { | Status DavinciModel::InitNetOutput(const NodePtr &node) { | ||||
// node->GetOpDesc Checked by Init: NetOutput, valid. | // node->GetOpDesc Checked by Init: NetOutput, valid. | ||||
auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
if (known_node_) { | |||||
output_op_list_.push_back(op_desc); | |||||
return SUCCESS; | |||||
} | |||||
ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); | ComputeGraphPtr owner_graph = node->GetOwnerComputeGraph(); | ||||
GE_CHECK_NOTNULL(owner_graph); | GE_CHECK_NOTNULL(owner_graph); | ||||
if (owner_graph->GetParentGraph() != nullptr) { | if (owner_graph->GetParentGraph() != nullptr) { | ||||
@@ -1024,7 +790,6 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
} | } | ||||
output_op_list_.push_back(op_desc); | output_op_list_.push_back(op_desc); | ||||
// Make information for copy output data. | // Make information for copy output data. | ||||
const vector<int64_t> input_size_list = ModelUtils::GetInputSize(op_desc); | const vector<int64_t> input_size_list = ModelUtils::GetInputSize(op_desc); | ||||
const vector<void *> virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); | const vector<void *> virtual_addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, op_desc, false); | ||||
@@ -1048,6 +813,7 @@ Status DavinciModel::InitNetOutput(const NodePtr &node) { | |||||
GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); | GELOGE(PARAM_INVALID, "Output zero copy nodes init failed!"); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
GELOGI("DavinciModel::InitNetoutput success."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -1605,7 +1371,9 @@ Status DavinciModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, | |||||
for (size_t i = 0; i < output_op_list_.size(); i++) { | for (size_t i = 0; i < output_op_list_.size(); i++) { | ||||
auto &op_desc = output_op_list_[i]; | auto &op_desc = output_op_list_[i]; | ||||
uint32_t out_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | uint32_t out_size = static_cast<uint32_t>(op_desc->GetInputsSize()); | ||||
// get real out nodes from model | |||||
vector<std::string> out_node_name; | |||||
(void)ge::AttrUtils::GetListStr(ge_model_, ATTR_MODEL_OUT_NODES_NAME, out_node_name); | |||||
for (uint32_t index = 0; index < out_size; index++) { | for (uint32_t index = 0; index < out_size; index++) { | ||||
string output_name; | string output_name; | ||||
InputOutputDescInfo output; | InputOutputDescInfo output; | ||||
@@ -1616,10 +1384,14 @@ Status DavinciModel::GetOutputDescInfo(vector<InputOutputDescInfo> &output_desc, | |||||
std::vector<int64_t> src_index = op_desc->GetSrcIndex(); | std::vector<int64_t> src_index = op_desc->GetSrcIndex(); | ||||
GE_CHK_BOOL_RET_STATUS(src_name.size() > index && src_index.size() > index, INTERNAL_ERROR, | GE_CHK_BOOL_RET_STATUS(src_name.size() > index && src_index.size() > index, INTERNAL_ERROR, | ||||
"construct output_name failed."); | "construct output_name failed."); | ||||
output_name = | |||||
std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + std::to_string(src_index[index]); | |||||
// forward compatbility, if old om has no out_node_name, need to return output follow origin way | |||||
if (out_size == out_node_name.size()) { | |||||
output_name = out_node_name[index] + ":" + std::to_string(src_index[index]); | |||||
} else { | |||||
output_name = std::string("output_") + std::to_string(index) + "_" + src_name[index] + "_" + | |||||
std::to_string(src_index[index]); | |||||
} | |||||
output.name = output_name; | output.name = output_name; | ||||
output_desc.push_back(output); | output_desc.push_back(output); | ||||
formats.push_back(format_result); | formats.push_back(format_result); | ||||
} | } | ||||
@@ -1653,8 +1425,8 @@ Status DavinciModel::CopyInputData(const InputData &input_data, bool device_data | |||||
"input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, | "input data size(%u) does not match model required size(%u), ret failed.", data_buf.length, | ||||
mem_size); | mem_size); | ||||
GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] memaddr[%p] mem_size[%u] datasize[%u]", | |||||
runtime_param_.graph_id, data.first, mem_addr, mem_size, data_buf.length); | |||||
GELOGI("[IMAS]CopyPlainData memcpy graph_%u type[F] input[%u] dst[%p] src[%p] mem_size[%u] datasize[%u]", | |||||
runtime_param_.graph_id, data.first, mem_addr, data_buf.data, mem_size, data_buf.length); | |||||
if (data_buf.length == 0) { | if (data_buf.length == 0) { | ||||
GELOGW("No data need to memcpy!"); | GELOGW("No data need to memcpy!"); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -2000,7 +1772,7 @@ Status DavinciModel::CopyOutputData(uint32_t data_id, OutputData &output_data) { | |||||
uint32_t output_data_index = 0; | uint32_t output_data_index = 0; | ||||
for (auto &op_desc : output_op_list_) { | for (auto &op_desc : output_op_list_) { | ||||
ret = CopyOutputDataToUser(op_desc, output_data.blobs, output_data_index); | ret = CopyOutputDataToUser(op_desc, output_data.blobs, output_data_index); | ||||
GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy input data to model ret failed, index:%u, model id:%u", | |||||
GE_CHK_BOOL_EXEC(ret == SUCCESS, break, "Copy output data to model ret failed, index:%u, model id:%u", | |||||
output_data.index, output_data.model_id); | output_data.index, output_data.model_id); | ||||
} | } | ||||
} | } | ||||
@@ -2032,8 +1804,10 @@ Status DavinciModel::CopyOutputDataToUser(OpDescPtr &op_desc, std::vector<DataBu | |||||
"Model output data size(%u) does not match required size(%u).", v_output_size[i], | "Model output data size(%u) does not match required size(%u).", v_output_size[i], | ||||
data_buf.length); | data_buf.length); | ||||
GELOGI("CopyOutputDataToUser memcpy graph_%u type[F] name[%s] output[%lu] memaddr[%p] mem_size[%u] datasize[%u]", | |||||
runtime_param_.graph_id, op_desc->GetName().c_str(), i, data_buf.data, data_buf.length, v_output_size[i]); | |||||
GELOGI( | |||||
"CopyOutputDataToUser memcpy graph_%u type[F] name[%s] output[%lu] dst[%p] src[%p] mem_size[%u] datasize[%u]", | |||||
runtime_param_.graph_id, op_desc->GetName().c_str(), i, data_buf.data, v_output_data_addr[i], data_buf.length, | |||||
v_output_size[i]); | |||||
GE_CHK_RT_RET(rtMemcpy(data_buf.data, size, v_output_data_addr[i], size, RT_MEMCPY_DEVICE_TO_DEVICE)); | GE_CHK_RT_RET(rtMemcpy(data_buf.data, size, v_output_data_addr[i], size, RT_MEMCPY_DEVICE_TO_DEVICE)); | ||||
} | } | ||||
@@ -2104,14 +1878,9 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
OutputData *output_data) { | OutputData *output_data) { | ||||
GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null."); | GE_CHK_BOOL_EXEC(listener_ != nullptr, return PARAM_INVALID, "listener_ is null."); | ||||
std::vector<ge::OutputTensorInfo> outputs; | std::vector<ge::OutputTensorInfo> outputs; | ||||
if (seq_end_flag) { | |||||
GELOGW("End of sequence, model id: %u", model_id_); | |||||
GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnComputeDone failed"); | |||||
return END_OF_SEQUENCE; | |||||
} | |||||
// return result is not required | // return result is not required | ||||
if (!rslt_flg) { | |||||
if (!rslt_flg && !seq_end_flag) { | |||||
GELOGW("Compute failed, model id: %u", model_id_); | GELOGW("Compute failed, model id: %u", model_id_); | ||||
GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); | GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, INTERNAL_ERROR, outputs), "OnComputeDone failed."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
@@ -2146,7 +1915,11 @@ Status DavinciModel::ReturnResult(uint32_t data_id, const bool rslt_flg, const b | |||||
} | } | ||||
GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); | GE_IF_BOOL_EXEC((DumpOpInputOutput() != SUCCESS), GELOGW("dump op failed, model_id: %u", model_id_);); | ||||
if (seq_end_flag) { | |||||
GELOGW("End of sequence, model id: %u", model_id_); | |||||
GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, END_OF_SEQUENCE, outputs), "OnCompute Done failed."); | |||||
return END_OF_SEQUENCE; | |||||
} | |||||
GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed"); | GE_CHK_STATUS(listener_->OnComputeDone(model_id_, data_id, SUCCESS, outputs), "OnComputeDone failed"); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -2493,6 +2266,87 @@ void DavinciModel::UnbindTaskSinkStream() { | |||||
return; | return; | ||||
} | } | ||||
Status DavinciModel::CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs) { | |||||
GELOGI("DavinciModel::CreateKnownZeroCopyMap in."); | |||||
if (inputs.size() != data_op_list_.size()) { | |||||
GELOGE(FAILED, "input data addr %u is not equal to input op number %u.", inputs.size(), data_op_list_.size()); | |||||
return FAILED; | |||||
} | |||||
for (size_t i = 0; i < data_op_list_.size(); ++i) { | |||||
const vector<void *> addr_list = ModelUtils::GetOutputDataAddrs(runtime_param_, data_op_list_[i]); | |||||
knonw_input_data_info_[addr_list[kDataIndex]] = inputs[i]; | |||||
GELOGI("DavinciModel::CreateKnownZeroCopyMap input %d,v addr %p,p addr %p .", i, addr_list[kDataIndex], inputs[i]); | |||||
} | |||||
if (output_op_list_.size() != kOutputNum) { | |||||
GELOGE(FAILED, "output op num is %u, not equal %u.", outputs.size(), kOutputNum); | |||||
return FAILED; | |||||
} | |||||
const vector<void *> addr_list = ModelUtils::GetInputDataAddrs(runtime_param_, output_op_list_[kDataIndex]); | |||||
if (outputs.size() != addr_list.size()) { | |||||
GELOGE(FAILED, "output data addr %u is not equal to output op number %u.", outputs.size(), addr_list.size()); | |||||
return FAILED; | |||||
} | |||||
for (size_t i = 0; i < addr_list.size(); ++i) { | |||||
knonw_output_data_info_[addr_list[i]] = outputs[i]; | |||||
GELOGI("DavinciModel::CreateKnownZeroCopyMap output %d,v addr %p,p addr %p .", i, addr_list[i], outputs[i]); | |||||
} | |||||
GELOGI("DavinciModel::CreateKnownZeroCopyMap success."); | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::UpdateKnownZeroCopyAddr(vector<void *> &io_addrs, uint32_t args_offset) { | |||||
for (size_t i = 0; i < io_addrs.size(); ++i) { | |||||
auto it_in = knonw_input_data_info_.find(io_addrs[i]); | |||||
if (it_in != knonw_input_data_info_.end()) { | |||||
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr input %d,v addr %p,p addr %p .", i, io_addrs[i], | |||||
knonw_input_data_info_.at(io_addrs[i])); | |||||
io_addrs[i] = knonw_input_data_info_.at(io_addrs[i]); | |||||
} | |||||
auto it_out = knonw_output_data_info_.find(io_addrs[i]); | |||||
if (it_out != knonw_output_data_info_.end()) { | |||||
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr output %d,v addr %p,p addr %p .", i, io_addrs[i], | |||||
knonw_output_data_info_.at(io_addrs[i])); | |||||
io_addrs[i] = knonw_output_data_info_.at(io_addrs[i]); | |||||
} | |||||
} | |||||
// may args_size is equal to src_args_size? | |||||
uint32_t src_args_size = io_addrs.size() * sizeof(uint64_t); | |||||
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr args host %p, src_args_size %u, args_offset %u", args_host_, | |||||
src_args_size, args_offset); | |||||
errno_t sec_ret = | |||||
memcpy_s(static_cast<char *>(args_host_) + args_offset, src_args_size, io_addrs.data(), src_args_size); | |||||
if (sec_ret != EOK) { | |||||
GELOGE(FAILED, "Call memcpy_s failed, ret: %d", sec_ret); | |||||
return FAILED; | |||||
} | |||||
GELOGI("DavinciModel::UpdateKnownZeroCopyAddr success."); | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs) { | |||||
GELOGI("DavinciModel::UpdateKnownNodeArgs in"); | |||||
GE_CHK_STATUS_RET(CreateKnownZeroCopyMap(inputs, outputs), | |||||
"DavinciModel::UpdateKnownNodeArgs create map for input/output zero copy."); | |||||
for (size_t task_index = 0; task_index < task_list_.size(); ++task_index) { | |||||
auto &task = task_list_[task_index]; | |||||
if (task != nullptr) { | |||||
Status ret = task->UpdateArgs(); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "task %d created by davinci model is nullptr.", task_index); | |||||
return FAILED; | |||||
} | |||||
} | |||||
} | |||||
GELOGI("DavinciModel::UpdateKnownNodeArgs device args %p, size %u, host args %p, size %u", args_, total_args_size_, | |||||
args_host_, total_args_size_); | |||||
// copy continuous args from host to device | |||||
Status rt_ret = rtMemcpy(args_, total_args_size_, args_host_, total_args_size_, RT_MEMCPY_HOST_TO_DEVICE); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) | |||||
GELOGI("DavinciModel::UpdateKnownNodeArgs success"); | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { | Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { | ||||
GELOGI("InitTaskInfo in,task size %zu", model_task_def.task().size()); | GELOGI("InitTaskInfo in,task size %zu", model_task_def.task().size()); | ||||
task_list_.resize(model_task_def.task_size()); | task_list_.resize(model_task_def.task_size()); | ||||
@@ -2513,13 +2367,13 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { | |||||
GELOGE(RT_FAILED, "Failed to set context from rt, error-code 0x%X.", rt_ret); | GELOGE(RT_FAILED, "Failed to set context from rt, error-code 0x%X.", rt_ret); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
model->task_list_[idx] = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task.type())); | |||||
Status ret = FAILED; | Status ret = FAILED; | ||||
if (model->task_list_[idx] != nullptr) { | |||||
ret = model->task_list_[idx]->Init(task, model); | |||||
// dynamic shape will create task_list_ before | |||||
if (model->task_list_[idx] == nullptr) { | |||||
model->task_list_[idx] = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(task.type())); | |||||
GE_CHECK_NOTNULL(model->task_list_[idx]); | |||||
} | } | ||||
ret = model->task_list_[idx]->Init(task, model); | |||||
return ret; | return ret; | ||||
}, | }, | ||||
model_task_def.task(i), this, ctx, i); | model_task_def.task(i), this, ctx, i); | ||||
@@ -2543,6 +2397,39 @@ Status DavinciModel::InitTaskInfo(domi::ModelTaskDef &model_task_def) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DavinciModel::MallocKnownArgs() { | |||||
GELOGI("DavinciModel::MallocKnownArgs in"); | |||||
if (model_task_def_->task_size() == 0) { | |||||
GELOGW("DavinciModel::MallocKnownArgs davincimodel has no task info."); | |||||
return SUCCESS; | |||||
} | |||||
task_list_.resize(model_task_def_->task_size()); | |||||
for (int32_t i = 0; i < model_task_def_->task_size(); ++i) { | |||||
const domi::TaskDef &taskdef = model_task_def_->task(i); | |||||
task_list_[i] = TaskInfoFactory::Instance().Create(static_cast<rtModelTaskType_t>(taskdef.type())); | |||||
GE_CHECK_NOTNULL(task_list_[i]); | |||||
Status ret = task_list_[i]->CalculateArgs(taskdef, this); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "TaskInfo CalculateArgs failed."); | |||||
return ret; | |||||
} | |||||
} | |||||
// malloc args memory | |||||
rtError_t rt_ret = rtMalloc(&args_, total_args_size_, RT_MEMORY_HBM); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
// malloc args host memory | |||||
rt_ret = rtMallocHost(&args_host_, total_args_size_); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Call rtMallocHost failed, ret: 0x%X", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
GELOGI("DavinciModel::MallocKnownArgs success, total args size %u.", total_args_size_); | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::DistributeTask() { | Status DavinciModel::DistributeTask() { | ||||
GELOGI("do Distribute."); | GELOGI("do Distribute."); | ||||
for (auto &task : cpu_task_list_) { | for (auto &task : cpu_task_list_) { | ||||
@@ -3117,7 +3004,7 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { | |||||
GE_RT_FALSE_CHECK_NOTNULL(in_anchor); | GE_RT_FALSE_CHECK_NOTNULL(in_anchor); | ||||
ge::NodePtr dst_node = in_anchor->GetOwnerNode(); | ge::NodePtr dst_node = in_anchor->GetOwnerNode(); | ||||
GE_RT_FALSE_CHECK_NOTNULL(dst_node); | GE_RT_FALSE_CHECK_NOTNULL(dst_node); | ||||
if (dst_node->GetType() == HCOMBROADCAST) { | |||||
if (dst_node->GetType() == HCOMBROADCAST || dst_node->GetType() == HVDCALLBACKBROADCAST) { | |||||
return true; | return true; | ||||
} | } | ||||
} | } | ||||
@@ -3126,32 +3013,15 @@ bool DavinciModel::IsBroadCastOpData(const ge::NodePtr &var_node) { | |||||
} | } | ||||
void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { | void DavinciModel::InitZeroCopyUtil(bool is_dynamic_batch, bool &input_zero_copy, bool &output_zero_copy) { | ||||
auto dump_path = PropertiesManager::Instance().GetDumpOutputPath(); | |||||
auto enable_dump = !dump_path.empty(); | |||||
auto dump_op_env = std::getenv("DUMP_OP"); | |||||
if (dump_op_env != nullptr) { | |||||
string dump_op_flag(dump_op_env); | |||||
if (dump_op_flag == "1") { | |||||
enable_dump = true; | |||||
} | |||||
} | |||||
GELOGI("dump path: %s, dump_op_env: %s", dump_path.c_str(), dump_op_env); | |||||
if (!is_dynamic_batch) { | if (!is_dynamic_batch) { | ||||
zero_copy_batch_label_addrs_.clear(); | zero_copy_batch_label_addrs_.clear(); | ||||
} | } | ||||
if (enable_dump) { | |||||
input_zero_copy = false; | |||||
output_zero_copy = false; | |||||
} else { | |||||
for (const auto &addrs : output_outside_addrs_) { | |||||
const auto &used_list = addrs.second; | |||||
if (used_list.empty()) { | |||||
output_zero_copy = false; | |||||
break; | |||||
} | |||||
for (const auto &addrs : output_outside_addrs_) { | |||||
const auto &used_list = addrs.second; | |||||
if (used_list.empty()) { | |||||
output_zero_copy = false; | |||||
break; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -3244,11 +3114,11 @@ Status DavinciModel::NnExecute(rtStream_t stream, bool async_mode, const InputDa | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
uint8_t *DavinciModel::MallocFeatureMapMem(uint64_t data_size) { | |||||
uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { | |||||
uint8_t *mem_base = nullptr; | uint8_t *mem_base = nullptr; | ||||
const string purpose("feature map,used for op input and output."); | const string purpose("feature map,used for op input and output."); | ||||
if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { | if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { | ||||
data_size = static_cast<uint64_t>(VarManager::Instance(0)->GetGraphMemoryMaxSize()); | |||||
data_size = static_cast<size_t>(VarManager::Instance(0)->GetGraphMemoryMaxSize()); | |||||
string memory_key = std::to_string(0) + "_f"; | string memory_key = std::to_string(0) + "_f"; | ||||
mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); | mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); | ||||
} else { | } else { | ||||
@@ -3261,7 +3131,7 @@ uint8_t *DavinciModel::MallocFeatureMapMem(uint64_t data_size) { | |||||
return mem_base; | return mem_base; | ||||
} | } | ||||
uint8_t *DavinciModel::MallocWeightsMem(uint32_t weights_size) { | |||||
uint8_t *DavinciModel::MallocWeightsMem(size_t weights_size) { | |||||
uint8_t *weights_mem_base = nullptr; | uint8_t *weights_mem_base = nullptr; | ||||
const string purpose("weights memory in inference network."); | const string purpose("weights memory in inference network."); | ||||
if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { | if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { | ||||
@@ -3319,10 +3189,6 @@ uint32_t DavinciModel::GetGraphID(const std::string &session_graph_id) { | |||||
Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) { | Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) { | ||||
GELOGI("TransAllVarData start: session_id:%lu, graph_id: %u.", session_id_, graph_id); | GELOGI("TransAllVarData start: session_id:%lu, graph_id: %u.", session_id_, graph_id); | ||||
ThreadPool executor(kThreadNum); | |||||
std::vector<std::future<Status>> vector_future; | |||||
rtContext_t ctx = nullptr; | rtContext_t ctx = nullptr; | ||||
rtError_t rt_ret = rtCtxGetCurrent(&ctx); | rtError_t rt_ret = rtCtxGetCurrent(&ctx); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
@@ -3330,6 +3196,7 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) | |||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
std::vector<NodePtr> variable_node_list; | |||||
for (ge::NodePtr &node : graph->GetDirectNode()) { | for (ge::NodePtr &node : graph->GetDirectNode()) { | ||||
if (node == nullptr) { | if (node == nullptr) { | ||||
continue; | continue; | ||||
@@ -3337,63 +3204,13 @@ Status DavinciModel::TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id) | |||||
if (node->GetType() != VARIABLE) { | if (node->GetType() != VARIABLE) { | ||||
continue; | continue; | ||||
} | } | ||||
std::future<Status> f = executor.commit( | |||||
[](ge::NodePtr &node, DavinciModel *model, rtContext_t ctx, uint32_t graph_id) -> Status { | |||||
if (model == nullptr) { | |||||
GELOGE(FAILED, "DavinciModel is NULL!"); | |||||
return FAILED; | |||||
} | |||||
rtError_t rt_ret = rtCtxSetCurrent(ctx); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
uint32_t allocated_graph_id = 0; | |||||
Status ret = VarManager::Instance(model->session_id_)->GetAllocatedGraphId(node->GetName(), allocated_graph_id); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(), | |||||
graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
uint32_t changed_graph_id = 0; | |||||
ret = VarManager::Instance(model->session_id_)->GetChangedGraphId(node->GetName(), changed_graph_id); | |||||
bool call_trans_var = | |||||
(ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id); | |||||
if (call_trans_var) { | |||||
GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); | |||||
VarTransRoad *trans_road = VarManager::Instance(model->session_id_)->GetTransRoad(node->GetName()); | |||||
if (trans_road == nullptr) { | |||||
GELOGI("The variable %s does not have any trans road", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
ret = TransVarData(node, *trans_road, model->session_id_); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
VarManager::Instance(model->session_id_)->RemoveChangedGraphId(node->GetName()); | |||||
} | |||||
return SUCCESS; | |||||
}, | |||||
node, this, ctx, graph_id); | |||||
if (!f.valid()) { | |||||
GELOGE(FAILED, "Future is invalid"); | |||||
return FAILED; | |||||
} | |||||
vector_future.push_back(std::move(f)); | |||||
variable_node_list.emplace_back(node); | |||||
} | } | ||||
Status ret_status; | |||||
for (size_t i = 0; i < vector_future.size(); ++i) { | |||||
ret_status = vector_future[i].get(); | |||||
if (ret_status != SUCCESS) { | |||||
GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i); | |||||
return ret_status; | |||||
} | |||||
} | |||||
GE_CHK_STATUS_RET_NOLOG( | |||||
TransVarDataUtils::TransAllVarData(variable_node_list, session_id_, ctx, graph_id, kThreadNum)); | |||||
GELOGI("TransAllVarData success."); | GELOGI("TransAllVarData success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -3457,96 +3274,8 @@ void DavinciModel::ReuseHcclFollowStream(int64_t remain_cap, int64_t &index) { | |||||
} | } | ||||
} | } | ||||
Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) { | |||||
GE_CHECK_NOTNULL(var_src); | |||||
GE_CHECK_NOTNULL(var_src->GetOpDesc()); | |||||
GE_CHECK_NOTNULL(var_dst); | |||||
GE_CHECK_NOTNULL(var_dst->GetOpDesc()); | |||||
auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize(); | |||||
auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType(); | |||||
auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType(); | |||||
GE_IF_BOOL_EXEC( | |||||
src_data_datatype != dst_data_datatype, | |||||
auto ret = formats::TransDataType( | |||||
{var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "trans var data on host failed"); | |||||
return ret; | |||||
}); | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst) { | |||||
/// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is | |||||
/// var_fp16(new). | |||||
/// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node. | |||||
/// need copy value from var_fp32 to var_fp16. | |||||
/// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr] | |||||
GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "node var is nullptr"); return FAILED); | |||||
// src_node output_desc (fp32) | |||||
GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0); | |||||
auto src_data_type = output_desc.GetDataType(); | |||||
auto src_shape = output_desc.GetShape(); | |||||
auto src_format = output_desc.GetFormat(); | |||||
GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(), | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(), | |||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
// dst_node output_desc (fp16) | |||||
GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0); | |||||
auto data_type = dst_tensor_desc.GetDataType(); | |||||
auto data_shape = dst_tensor_desc.GetShape(); | |||||
auto data_format = dst_tensor_desc.GetFormat(); | |||||
GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(), | |||||
TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
// Sync var data from device | |||||
std::unique_ptr<uint8_t[]> var_src_data; | |||||
RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id_); | |||||
// copy from src_node | |||||
auto ret = CopyVarFromDevice(session_id_, var_src, var_src_data, output_desc); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret); | |||||
// trans dtype | |||||
formats::TransResult trans_result; | |||||
ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret); | |||||
// reset src value. | |||||
void *var_device = nullptr; | |||||
ret = ReAssignVarAddr(session_id_, var_dst->GetName(), dst_tensor_desc, &var_device); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret); | |||||
// copy to device | |||||
ret = CopyVarToDevice(var_dst, trans_result, var_device); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret); | |||||
return SUCCESS; | |||||
} | |||||
Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { | Status DavinciModel::CopyVarData(ComputeGraphPtr &compute_graph) { | ||||
GELOGI("CopyVarData start: session_id:%lu.", session_id_); | |||||
if (compute_graph == nullptr) { | |||||
GELOGE(FAILED, "compute_graph is nullptr"); | |||||
return FAILED; | |||||
} | |||||
string cp_from_node; | |||||
bool copy_value = false; | |||||
for (auto &node : compute_graph->GetAllNodes()) { | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue); | |||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node), | |||||
GELOGI("Get original type of cp_from_node")); | |||||
if (cp_from_node.length() != 0) { | |||||
(void)ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value | |||||
if (!copy_value) { | |||||
auto src_node = compute_graph->FindNode(cp_from_node); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(), | |||||
src_node->GetName().c_str()); | |||||
auto ret = CopyTensorFromSrcVarNode(src_node, node); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED); | |||||
// only copy once | |||||
(void)ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
return TransVarDataUtils::CopyVarData(compute_graph, session_id_, device_id_); | |||||
} | } | ||||
Status DavinciModel::GetComputeGraphInfo(std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | Status DavinciModel::GetComputeGraphInfo(std::vector<ComputeGraphDescInfo> &compute_graph_desc_info) { | ||||
@@ -250,6 +250,11 @@ class DavinciModel { | |||||
return res; | return res; | ||||
} | } | ||||
rtStream_t GetRtModelStream() { | |||||
rtModel_t res = rt_model_stream_; | |||||
return res; | |||||
} | |||||
uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; } | uint64_t GetRtBaseAddr() const { return runtime_param_.logic_mem_base; } | ||||
uint64_t GetRtWeightAddr() const { return runtime_param_.logic_weight_base; } | uint64_t GetRtWeightAddr() const { return runtime_param_.logic_weight_base; } | ||||
@@ -427,6 +432,26 @@ class DavinciModel { | |||||
void CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap); | void CreateHcclFollowStream(rtStream_t stream, int64_t remain_cap); | ||||
void ReuseHcclFollowStream(int64_t remain_cap, int64_t &index); | void ReuseHcclFollowStream(int64_t remain_cap, int64_t &index); | ||||
void InitRuntimeParams(); | |||||
Status InitVariableMem(); | |||||
void UpdateMemBase(uint8_t *mem_base) { | |||||
runtime_param_.mem_base = mem_base; | |||||
mem_base_ = mem_base; | |||||
} | |||||
void SetTotalArgsSize(uint32_t args_size) { total_args_size_ += args_size; } | |||||
uint32_t GetTotalArgsSize() { return total_args_size_; } | |||||
void *GetCurrentArgsAddr(uint32_t offset) { | |||||
void *cur_args = static_cast<char *>(args_) + offset; | |||||
return cur_args; | |||||
} | |||||
void SetKnownNode(bool known_node) { known_node_ = known_node; } | |||||
bool IsKnownNode() { return known_node_; } | |||||
Status MallocKnownArgs(); | |||||
Status UpdateKnownNodeArgs(const vector<void *> &inputs, const vector<void *> &outputs); | |||||
Status CreateKnownZeroCopyMap(const vector<void *> &inputs, const vector<void *> &outputs); | |||||
Status UpdateKnownZeroCopyAddr(vector<void *> &io_addrs, uint32_t args_offset); | |||||
private: | private: | ||||
// memory address of weights | // memory address of weights | ||||
uint8_t *weights_mem_base_; | uint8_t *weights_mem_base_; | ||||
@@ -523,9 +548,9 @@ class DavinciModel { | |||||
Status DistributeTask(); | Status DistributeTask(); | ||||
uint8_t *MallocFeatureMapMem(uint64_t data_size); | |||||
uint8_t *MallocFeatureMapMem(size_t data_size); | |||||
uint8_t *MallocWeightsMem(uint32_t weights_size); | |||||
uint8_t *MallocWeightsMem(size_t weights_size); | |||||
void FreeFeatureMapMem(); | void FreeFeatureMapMem(); | ||||
@@ -690,8 +715,6 @@ class DavinciModel { | |||||
/// | /// | ||||
Status CpuModelRepeat(); | Status CpuModelRepeat(); | ||||
void InitRuntimeParams(); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief set ts device. | /// @brief set ts device. | ||||
@@ -709,7 +732,6 @@ class DavinciModel { | |||||
Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); | Status TransAllVarData(ComputeGraphPtr &graph, uint32_t graph_id); | ||||
Status CopyVarData(ComputeGraphPtr &graph); | Status CopyVarData(ComputeGraphPtr &graph); | ||||
Status CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst); | |||||
// get desc info of graph for profiling | // get desc info of graph for profiling | ||||
Status GetComputeGraphInfo(vector<ComputeGraphDescInfo> &compute_graph_desc_info); | Status GetComputeGraphInfo(vector<ComputeGraphDescInfo> &compute_graph_desc_info); | ||||
@@ -827,6 +849,13 @@ class DavinciModel { | |||||
DataDumper data_dumper_; | DataDumper data_dumper_; | ||||
uint64_t iterator_count_; | uint64_t iterator_count_; | ||||
bool is_l1_fusion_enable_; | bool is_l1_fusion_enable_; | ||||
bool known_node_ = false; | |||||
uint32_t total_args_size_ = 0; | |||||
void *args_ = nullptr; | |||||
void *args_host_ = nullptr; | |||||
std::map<const void *, void *> knonw_input_data_info_; | |||||
std::map<const void *, void *> knonw_output_data_info_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_DAVINCI_MODEL_H_ |
@@ -24,6 +24,7 @@ | |||||
#include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
#include "graph/load/new_model_manager/davinci_model.h" | #include "graph/load/new_model_manager/davinci_model.h" | ||||
#include "graph/load/new_model_manager/davinci_model_parser.h" | #include "graph/load/new_model_manager/davinci_model_parser.h" | ||||
#include "model/ge_root_model.h" | |||||
namespace ge { | namespace ge { | ||||
thread_local uint32_t device_count = 0; | thread_local uint32_t device_count = 0; | ||||
@@ -68,8 +69,6 @@ Status ModelManager::KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, u | |||||
GE_CHK_RT(rtFree(aicpu_kernel_addr)); return FAILED;) | GE_CHK_RT(rtFree(aicpu_kernel_addr)); return FAILED;) | ||||
uint64_t kernel_id_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(aicpu_kernel_addr)); | uint64_t kernel_id_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(aicpu_kernel_addr)); | ||||
param_base.fwkKernelBase.fwk_kernel.kernelID = kernel_id_addr; | param_base.fwkKernelBase.fwk_kernel.kernelID = kernel_id_addr; | ||||
// Remove model key from map | |||||
model_aicpu_kernel_.erase(iter); | |||||
} | } | ||||
} | } | ||||
@@ -214,18 +213,38 @@ Status ModelManager::SetDevice(int32_t deviceId) const { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
ge::Status ModelManager::DoLoadHybridModelOnline(uint32_t model_id, const shared_ptr<ge::GeRootModel> &ge_root_model, | |||||
const shared_ptr<ModelListener> &listener) { | |||||
auto hybrid_model = hybrid::HybridDavinciModel::Create(ge_root_model); | |||||
GE_CHECK_NOTNULL(hybrid_model); | |||||
hybrid_model->SetListener(listener); | |||||
hybrid_model->SetModelId(model_id); | |||||
hybrid_model->SetDeviceId(GetContext().DeviceId()); | |||||
GE_CHK_STATUS_RET(hybrid_model->Init(), "Failed to init hybrid model. model_id = %u", model_id); | |||||
auto shared_model = std::shared_ptr<hybrid::HybridDavinciModel>(hybrid_model.release()); | |||||
InsertModel(model_id, shared_model); | |||||
return SUCCESS; | |||||
} | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief load model online | /// @brief load model online | ||||
/// @return Status run result | /// @return Status run result | ||||
/// | /// | ||||
Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::GeModel> &ge_model, | |||||
Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::GeRootModel> &ge_root_model, | |||||
std::shared_ptr<ModelListener> listener) { | std::shared_ptr<ModelListener> listener) { | ||||
GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); | GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "Param incorrect, listener is null"); | ||||
if (model_id == INVALID_MODEL_ID) { | if (model_id == INVALID_MODEL_ID) { | ||||
GenModelId(&model_id); | GenModelId(&model_id); | ||||
} | } | ||||
bool is_shape_unknown = false; | |||||
GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_shape_unknown), "CheckIsUnknownShape failed, model id:%u", | |||||
model_id); | |||||
if (is_shape_unknown) { | |||||
return DoLoadHybridModelOnline(model_id, ge_root_model, listener); | |||||
} | |||||
GE_CHK_STATUS_RET(SetDevice(static_cast<int32_t>(GetContext().DeviceId())), "Set device failed, model id:%u.", | GE_CHK_STATUS_RET(SetDevice(static_cast<int32_t>(GetContext().DeviceId())), "Set device failed, model id:%u.", | ||||
model_id); | model_id); | ||||
mmTimespec timespec = mmGetTickCount(); | mmTimespec timespec = mmGetTickCount(); | ||||
@@ -238,6 +257,11 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
davinci_model->SetId(model_id); | davinci_model->SetId(model_id); | ||||
davinci_model->SetDeviceId(GetContext().DeviceId()); | davinci_model->SetDeviceId(GetContext().DeviceId()); | ||||
auto root_graph = ge_root_model->GetRootGraph(); | |||||
GE_CHECK_NOTNULL(root_graph); | |||||
string root_model_name = root_graph->GetName(); | |||||
auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | |||||
GeModelPtr ge_model = name_to_model[root_model_name]; | |||||
Status ret = SUCCESS; | Status ret = SUCCESS; | ||||
do { | do { | ||||
GE_TIMESTAMP_START(Assign); | GE_TIMESTAMP_START(Assign); | ||||
@@ -274,16 +298,26 @@ void ModelManager::InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davin | |||||
model_map_[id] = davinci_model; | model_map_[id] = davinci_model; | ||||
} | } | ||||
void ModelManager::InsertModel(uint32_t id, shared_ptr<hybrid::HybridDavinciModel> &hybrid_model) { | |||||
GE_CHK_BOOL_EXEC(hybrid_model != nullptr, return, "hybrid_model ptr is null, id: %u", id); | |||||
std::lock_guard<std::mutex> lock(map_mutex_); | |||||
hybrid_model_map_[id] = hybrid_model; | |||||
} | |||||
Status ModelManager::DeleteModel(uint32_t id) { | Status ModelManager::DeleteModel(uint32_t id) { | ||||
std::lock_guard<std::mutex> lock(map_mutex_); | std::lock_guard<std::mutex> lock(map_mutex_); | ||||
auto it = model_map_.find(id); | auto it = model_map_.find(id); | ||||
if (it == model_map_.end()) { | |||||
auto hybrid_model_it = hybrid_model_map_.find(id); | |||||
if (it != model_map_.end()) { | |||||
(void)model_map_.erase(it); | |||||
} else if (hybrid_model_it != hybrid_model_map_.end()) { | |||||
(void)hybrid_model_map_.erase(hybrid_model_it); | |||||
} else { | |||||
GELOGE(PARAM_INVALID, "model id %u does not exists.", id); | GELOGE(PARAM_INVALID, "model id %u does not exists.", id); | ||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
(void)model_map_.erase(it); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -294,6 +328,13 @@ std::shared_ptr<DavinciModel> ModelManager::GetModel(uint32_t id) { | |||||
return (it == model_map_.end()) ? nullptr : it->second; | return (it == model_map_.end()) ? nullptr : it->second; | ||||
} | } | ||||
std::shared_ptr<hybrid::HybridDavinciModel> ModelManager::GetHybridModel(uint32_t id) { | |||||
std::lock_guard<std::mutex> lock(map_mutex_); | |||||
auto it = hybrid_model_map_.find(id); | |||||
return (it == hybrid_model_map_.end()) ? nullptr : it->second; | |||||
} | |||||
Status ModelManager::Unload(uint32_t model_id) { | Status ModelManager::Unload(uint32_t model_id) { | ||||
GE_CHK_STATUS_RET(DeleteModel(model_id), "failed to unload model id: %u", model_id); | GE_CHK_STATUS_RET(DeleteModel(model_id), "failed to unload model id: %u", model_id); | ||||
if (device_count > 0) { | if (device_count > 0) { | ||||
@@ -349,7 +390,10 @@ Status ModelManager::DataInput(const InputData &input_data, OutputData &output_d | |||||
/// | /// | ||||
Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) { | Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputTensorInfo> &inputs) { | ||||
std::shared_ptr<DavinciModel> model = GetModel(model_id); | std::shared_ptr<DavinciModel> model = GetModel(model_id); | ||||
GE_CHECK_NOTNULL(model); | |||||
auto hybrid_model = GetHybridModel(model_id); | |||||
if (hybrid_model == nullptr) { | |||||
GE_CHECK_NOTNULL(model); | |||||
} | |||||
InputData input_data; | InputData input_data; | ||||
input_data.model_id = model_id; | input_data.model_id = model_id; | ||||
@@ -374,6 +418,12 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT | |||||
GE_CHK_STATUS_EXEC(data_wrap->Init(input_data, output_data), return domi::PUSH_DATA_FAILED, | GE_CHK_STATUS_EXEC(data_wrap->Init(input_data, output_data), return domi::PUSH_DATA_FAILED, | ||||
"Init InputDataWrapper failed,input data model_id is : %u.", model_id); | "Init InputDataWrapper failed,input data model_id is : %u.", model_id); | ||||
if (hybrid_model != nullptr) { | |||||
GE_CHK_STATUS_RET(hybrid_model->EnqueueData(data_wrap), "Data queue is full, please call again later, model_id %u ", | |||||
model_id); | |||||
return SUCCESS; | |||||
} | |||||
GE_CHK_BOOL_RET_STATUS(model != nullptr, PARAM_INVALID, "Invalid Model ID %u in InputData! ", model_id); | GE_CHK_BOOL_RET_STATUS(model != nullptr, PARAM_INVALID, "Invalid Model ID %u in InputData! ", model_id); | ||||
DataInputer *inputer = model->GetDataInputer(); | DataInputer *inputer = model->GetDataInputer(); | ||||
@@ -395,6 +445,13 @@ Status ModelManager::DataInputTensor(uint32_t model_id, const std::vector<InputT | |||||
/// @author | /// @author | ||||
/// | /// | ||||
Status ModelManager::Start(uint32_t model_id) { | Status ModelManager::Start(uint32_t model_id) { | ||||
auto hybrid_model = GetHybridModel(model_id); | |||||
if (hybrid_model != nullptr) { | |||||
GE_CHK_STATUS_RET_NOLOG(hybrid_model->ModelRunStart()); | |||||
GELOGI("Start hybrid model %u success.", model_id); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); | GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to start! ", model_id); | ||||
@@ -416,6 +473,13 @@ Status ModelManager::Start(uint32_t model_id) { | |||||
/// @author | /// @author | ||||
/// | /// | ||||
Status ModelManager::Stop(uint32_t model_id) { | Status ModelManager::Stop(uint32_t model_id) { | ||||
auto hybrid_model = GetHybridModel(model_id); | |||||
if (hybrid_model != nullptr) { | |||||
GE_CHK_STATUS_RET_NOLOG(hybrid_model->ModelRunStop()); | |||||
GELOGI("Stop hybrid model %u success.", model_id); | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to stop!", model_id); | GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "Invalid Model ID %u to stop!", model_id); | ||||
@@ -581,6 +645,13 @@ Status ModelManager::HandleDumpCommand(const Command &command) { | |||||
} | } | ||||
Status ModelManager::GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_size) { | Status ModelManager::GetMaxUsedMemory(const uint32_t model_id, uint64_t &max_size) { | ||||
auto hybrid_model = GetHybridModel(model_id); | |||||
if (hybrid_model != nullptr) { | |||||
// TODO hybrid use dynamic memory allocation | |||||
max_size = 0; | |||||
return SUCCESS; | |||||
} | |||||
std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | std::shared_ptr<DavinciModel> davinci_model = GetModel(model_id); | ||||
GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetMaxUsedMemory Failed, Invalid Model ID %u !", | GE_CHK_BOOL_RET_STATUS(davinci_model != nullptr, PARAM_INVALID, "GetMaxUsedMemory Failed, Invalid Model ID %u !", | ||||
model_id); | model_id); | ||||
@@ -25,6 +25,7 @@ | |||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <model/ge_root_model.h> | |||||
#include "cce/aicpu_engine_struct.h" | #include "cce/aicpu_engine_struct.h" | ||||
#include "common/ge_inner_error_codes.h" | #include "common/ge_inner_error_codes.h" | ||||
#include "common/ge_types.h" | #include "common/ge_types.h" | ||||
@@ -34,10 +35,10 @@ | |||||
#include "ge/ge_api_types.h" | #include "ge/ge_api_types.h" | ||||
#include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
#include "graph/model.h" | #include "graph/model.h" | ||||
#include "hybrid/hybrid_davinci_model.h" | |||||
#include "runtime/base.h" | #include "runtime/base.h" | ||||
namespace ge { | namespace ge { | ||||
class DavinciModel; | class DavinciModel; | ||||
class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | ||||
@@ -69,9 +70,12 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
/// @return Status run result | /// @return Status run result | ||||
/// @author @ | /// @author @ | ||||
/// | /// | ||||
ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeModel> &model, | |||||
ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | |||||
std::shared_ptr<ModelListener> listener); | std::shared_ptr<ModelListener> listener); | ||||
ge::Status DoLoadHybridModelOnline(uint32_t model_id, const shared_ptr<ge::GeRootModel> &ge_root_model, | |||||
const std::shared_ptr<ModelListener> &listener); | |||||
/// | /// | ||||
/// @ingroup ge | /// @ingroup ge | ||||
/// @brief ACL case, Load task list with queue. | /// @brief ACL case, Load task list with queue. | ||||
@@ -206,6 +210,8 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
/// | /// | ||||
std::shared_ptr<DavinciModel> GetModel(uint32_t id); | std::shared_ptr<DavinciModel> GetModel(uint32_t id); | ||||
std::shared_ptr<hybrid::HybridDavinciModel> GetHybridModel(uint32_t id); | |||||
ge::Status KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id); | ge::Status KernelLaunchEx(aicpu::FWKAdapter::FWKOperateType op_type, uint64_t session_id, uint32_t model_id); | ||||
ge::Status CreateAicpuSession(uint64_t session_id); | ge::Status CreateAicpuSession(uint64_t session_id); | ||||
@@ -238,6 +244,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
/// @brief insert new model into model manager set | /// @brief insert new model into model manager set | ||||
/// | /// | ||||
void InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model); | void InsertModel(uint32_t id, std::shared_ptr<DavinciModel> &davinci_model); | ||||
void InsertModel(uint32_t id, std::shared_ptr<hybrid::HybridDavinciModel> &hybrid_model); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
@@ -248,6 +255,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
void GenModelId(uint32_t *id); | void GenModelId(uint32_t *id); | ||||
std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_; | std::map<uint32_t, std::shared_ptr<DavinciModel>> model_map_; | ||||
std::map<uint32_t, std::shared_ptr<hybrid::HybridDavinciModel>> hybrid_model_map_; | |||||
std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | std::map<std::string, std::vector<uint64_t>> model_aicpu_kernel_; | ||||
uint32_t max_model_id_; | uint32_t max_model_id_; | ||||
std::mutex map_mutex_; | std::mutex map_mutex_; | ||||
@@ -474,7 +474,7 @@ vector<void *> ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param | |||||
int64_t workspace_bytes = v_workspace_bytes[i]; | int64_t workspace_bytes = v_workspace_bytes[i]; | ||||
uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : mem_base + workspace_offset; | uint8_t *mem_addr = workspace_bytes == 0 ? nullptr : mem_base + workspace_offset; | ||||
v_workspace_data_addr.push_back(mem_addr); | v_workspace_data_addr.push_back(mem_addr); | ||||
GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] output[%zu] offset[%ld] bytes[%ld] memaddr[%p]", | |||||
GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", | |||||
model_param.graph_id, op_desc->GetName().c_str(), i, workspace_offset, workspace_bytes, mem_addr); | model_param.graph_id, op_desc->GetName().c_str(), i, workspace_offset, workspace_bytes, mem_addr); | ||||
} | } | ||||
} | } | ||||
@@ -37,17 +37,12 @@ HcclTaskInfo::~HcclTaskInfo() { | |||||
if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rtFree Fail, ret = 0x%X.", ret); | GELOGE(RT_FAILED, "Call rtFree Fail, ret = 0x%X.", ret); | ||||
} | } | ||||
private_def_ = nullptr; | private_def_ = nullptr; | ||||
} | } | ||||
input_data_addr_ = nullptr; | |||||
davinci_model_ = nullptr; | davinci_model_ = nullptr; | ||||
ops_kernel_store_ = nullptr; | ops_kernel_store_ = nullptr; | ||||
output_data_addr_ = nullptr; | |||||
workspace_addr_ = nullptr; | |||||
max_node_of_hccl_stream_ = 0; | max_node_of_hccl_stream_ = 0; | ||||
} | } | ||||
Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_model) { | ||||
GELOGI("HcclTaskInfo Init Start."); | GELOGI("HcclTaskInfo Init Start."); | ||||
if (davinci_model == nullptr) { | if (davinci_model == nullptr) { | ||||
@@ -55,63 +50,75 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
davinci_model_ = davinci_model; | davinci_model_ = davinci_model; | ||||
Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | Status ret = SetStream(task_def.stream_id(), davinci_model->GetStreamList()); | ||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
return ret; | return ret; | ||||
} | } | ||||
GetPrivateDefByTaskDef(task_def); | GetPrivateDefByTaskDef(task_def); | ||||
auto hccl_def = task_def.kernel_hccl(); | auto hccl_def = task_def.kernel_hccl(); | ||||
hcclDataType_t data_type; | |||||
int32_t count; | |||||
uint32_t op_index = hccl_def.op_index(); | uint32_t op_index = hccl_def.op_index(); | ||||
GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); | GELOGI("HcclTaskInfo Init, op_index is: %u", op_index); | ||||
std::string hccl_type = hccl_def.hccl_type(); | |||||
// Get HCCL op | // Get HCCL op | ||||
OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); | OpDescPtr op_desc = davinci_model->GetOpByIndex(op_index); | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
Status dmrt = HcomOmeUtil::GetHcomDataType(op_desc, data_type); | |||||
// Create the kernel hccl infos | |||||
CreateKernelHcclInfo(op_desc); | |||||
// Initialize the hccl_type of all kernel hccl info | |||||
HcomOmeUtil::GetHcclType(task_def, kernel_hccl_infos_); | |||||
// Only in Horovod scenario should get the inputName and GeShape | |||||
ret = HcomOmeUtil::GetHorovodInputs(op_desc, kernel_hccl_infos_); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: GetHorovodInputs fail! domi error: %u", ret); | |||||
return FAILED; | |||||
} | |||||
Status dmrt = HcomOmeUtil::GetHcclDataType(op_desc, kernel_hccl_infos_); | |||||
if (dmrt != SUCCESS) { | if (dmrt != SUCCESS) { | ||||
GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); | GELOGE(FAILED, "davinci_model: GetHcomDataType fail! domi error: %u", dmrt); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
dmrt = HcomOmeUtil::GetHcomCount(op_desc, data_type, (hccl_type == HCOMALLGATHER), count); | |||||
dmrt = HcomOmeUtil::GetHcclCount(op_desc, kernel_hccl_infos_); | |||||
if (dmrt != SUCCESS) { | if (dmrt != SUCCESS) { | ||||
GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); | GELOGE(FAILED, "davinci_model: GetHcomCount fail! domi error: %u", dmrt); | ||||
return FAILED; | return FAILED; | ||||
} | } | ||||
ret = SetAddrs(hccl_type, op_desc); | |||||
// Only HCOMBROADCAST and HVDCALLBACKBROADCAST need to get the rootId | |||||
dmrt = HcomOmeUtil::GetAllRootId(op_desc, kernel_hccl_infos_); | |||||
if (dmrt != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: Get rootId fail! domi error: %u", dmrt); | |||||
return FAILED; | |||||
} | |||||
ret = SetAddrs(op_desc, kernel_hccl_infos_); | |||||
if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
GELOGE(ret, "Setaddrs Fail."); | GELOGE(ret, "Setaddrs Fail."); | ||||
return ret; | return ret; | ||||
} | } | ||||
count_ = count; | |||||
hccl_type_ = hccl_type; | |||||
data_type_ = data_type; | |||||
// GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace | // GE's new process: hccl declares the need for Workspace size, and GE allocates Workspace | ||||
auto workspace_bytes = op_desc->GetWorkspaceBytes(); | |||||
if (!workspace_bytes.empty()) { | |||||
uint64_t workspace_mem_size_tmp = workspace_bytes[0]; | |||||
GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); | |||||
if (workspace_mem_size_tmp != 0) { | |||||
workspace_mem_size_ = workspace_mem_size_tmp; | |||||
vector<void *> workspace_data_addrs = | |||||
ModelUtils::GetWorkspaceDataAddrs(davinci_model->GetRuntimeParam(), op_desc); | |||||
if (!workspace_data_addrs.empty()) { | |||||
GELOGI("Get workSpaceAddr"); | |||||
workspace_addr_ = workspace_data_addrs[0]; | |||||
} | |||||
} | |||||
ret = SetWorkspace(op_desc, kernel_hccl_infos_); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "SetWorkspace Fail."); | |||||
return ret; | |||||
} | } | ||||
// GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl | // GE's new process: hccl declares the number of streams required, creates a stream by GE, and sends it to hccl | ||||
ret = SetFollowStream(op_desc, davinci_model); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "SetStream Fail."); | |||||
return ret; | |||||
} | |||||
GELOGI("HcclTaskInfo Init Success"); | |||||
return SUCCESS; | |||||
} | |||||
Status HcclTaskInfo::SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciModel *davinci_model) { | |||||
if (!HcomOmeUtil::IsHCOMOp(op_desc->GetType())) { | |||||
GELOGI("Node %s Optye %s no need to create slave streams.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status ret; | |||||
int64_t hccl_stream_num = 0; | int64_t hccl_stream_num = 0; | ||||
if (!ge::AttrUtils::GetInt(op_desc, "used_stream_num", hccl_stream_num)) { | if (!ge::AttrUtils::GetInt(op_desc, "used_stream_num", hccl_stream_num)) { | ||||
GELOGI("op_desc has no attr used_stream_num!"); | GELOGI("op_desc has no attr used_stream_num!"); | ||||
@@ -142,8 +149,7 @@ Status HcclTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci_m | |||||
return FAILED; | return FAILED; | ||||
} | } | ||||
} | } | ||||
GELOGI("HcclTaskInfo Init Success, hcclStreamNum =%ld", hccl_stream_num); | |||||
GELOGI("Initialize hccl slave stream success, hcclStreamNum =%ld", hccl_stream_num); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
@@ -167,14 +173,12 @@ Status HcclTaskInfo::CreateStream(int64_t stream_num, DavinciModel *davinci_mode | |||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
// Create slave stream, inactive by default, activated by hccl | // Create slave stream, inactive by default, activated by hccl | ||||
rt_ret = rtModelBindStream(davinci_model->GetRtModelHandle(), stream, RT_MODEL_WAIT_ACTIVE_STREAM); | rt_ret = rtModelBindStream(davinci_model->GetRtModelHandle(), stream, RT_MODEL_WAIT_ACTIVE_STREAM); | ||||
if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | ||||
return RT_FAILED; | return RT_FAILED; | ||||
} | } | ||||
GELOGD("hccl_stream addr is=%p", stream); | GELOGD("hccl_stream addr is=%p", stream); | ||||
int64_t remain_cap = max_node_of_hccl_stream_ - 1; | int64_t remain_cap = max_node_of_hccl_stream_ - 1; | ||||
davinci_model->CreateHcclFollowStream(stream, remain_cap); | davinci_model->CreateHcclFollowStream(stream, remain_cap); | ||||
@@ -192,7 +196,6 @@ Status HcclTaskInfo::Distribute() { | |||||
GELOGE(INTERNAL_ERROR, "ops kernel store is null."); | GELOGE(INTERNAL_ERROR, "ops kernel store is null."); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<OpsKernelInfoStore *>(ops_kernel_store_); | OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<OpsKernelInfoStore *>(ops_kernel_store_); | ||||
GE_CHECK_NOTNULL(ops_kernel_info_store); | GE_CHECK_NOTNULL(ops_kernel_info_store); | ||||
GETaskInfo ge_task; | GETaskInfo ge_task; | ||||
@@ -202,81 +205,62 @@ Status HcclTaskInfo::Distribute() { | |||||
GELOGE(INTERNAL_ERROR, "davinci_model : load task fail, return ret: %u", result); | GELOGE(INTERNAL_ERROR, "davinci_model : load task fail, return ret: %u", result); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
GELOGI("HcclTaskInfo Distribute Success."); | GELOGI("HcclTaskInfo Distribute Success."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HcclTaskInfo::SetAddrs(const std::string &hccl_type, const std::shared_ptr<OpDesc> &op_desc) { | |||||
Status HcclTaskInfo::SetAddrs(const std::shared_ptr<OpDesc> &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (HcomOmeUtil::CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGI("Set hccl task input output address, node[%s}, type[%s] kernel_hccl_infos.size[%zu].", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); | |||||
if (op_desc->GetType() == HVDWAIT) { | |||||
return SUCCESS; | |||||
} | |||||
domi::Status dmrt; | domi::Status dmrt; | ||||
hcclRedOp_t op_type; | |||||
hcclRedOp_t op_type = HCCL_REP_OP_SUM; | |||||
GE_CHECK_NOTNULL(davinci_model_); | GE_CHECK_NOTNULL(davinci_model_); | ||||
GELOGI("Calc opType[%s] input address before. Node name[%s]", op_desc->GetType().c_str(), op_desc->GetName().c_str()); | |||||
auto input_data_addr_list = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | auto input_data_addr_list = ModelUtils::GetInputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | ||||
if (!input_data_addr_list.empty()) { | |||||
input_data_addr_ = input_data_addr_list[0]; | |||||
} | |||||
void *output_data_addr = nullptr; | |||||
auto output_data_addr_list = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | auto output_data_addr_list = ModelUtils::GetOutputDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | ||||
if (!output_data_addr_list.empty()) { | |||||
output_data_addr = output_data_addr_list[0]; | |||||
} | |||||
if (hccl_type == HCOMBROADCAST) { | |||||
int64_t root_id; | |||||
dmrt = HcomOmeUtil::GetHcomRootId(op_desc, root_id); | |||||
if (dmrt != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); | |||||
return FAILED; | |||||
} | |||||
root_id_ = root_id; | |||||
} else if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE) { | |||||
output_data_addr_ = output_data_addr; | |||||
} else if (hccl_type == HCOMALLREDUCE) { | |||||
dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); | |||||
if (dmrt != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); | |||||
return FAILED; | |||||
} | |||||
output_data_addr_ = output_data_addr; | |||||
op_type_ = op_type; | |||||
} else if (hccl_type == HCOMREDUCESCATTER) { | |||||
dmrt = HcomOmeUtil::GetHcomOperationType(op_desc, op_type); | |||||
if (dmrt != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); | |||||
return FAILED; | |||||
// initialize every kernel_hccl_info inputDataAddr | |||||
for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { | |||||
std::string hccl_type = kernel_hccl_infos[i].hccl_type; | |||||
void *input_data_addr = input_data_addr_list.empty() ? nullptr : input_data_addr_list[i]; | |||||
kernel_hccl_infos[i].inputDataAddr = input_data_addr; | |||||
void *output_data_addr = output_data_addr_list.empty() ? nullptr : output_data_addr_list[i]; | |||||
if (hccl_type == HCOMALLGATHER || hccl_type == HCOMRECEIVE || hccl_type == HVDCALLBACKALLGATHER) { | |||||
kernel_hccl_infos[i].outputDataAddr = output_data_addr; | |||||
} else if (hccl_type == HCOMALLREDUCE || hccl_type == HCOMREDUCESCATTER || hccl_type == HVDCALLBACKALLREDUCE) { | |||||
dmrt = HcomOmeUtil::GetHcclOperationType(op_desc, op_type); | |||||
if (dmrt != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: GetHcomOperationType fail! domi error: %u", dmrt); | |||||
return FAILED; | |||||
} | |||||
kernel_hccl_infos[i].outputDataAddr = output_data_addr; | |||||
kernel_hccl_infos[i].opType = op_type; | |||||
} | } | ||||
output_data_addr_ = output_data_addr; | |||||
op_type_ = op_type; | |||||
davinci_model_->DisableZeroCopy(input_data_addr); | |||||
} | } | ||||
davinci_model_->DisableZeroCopy(input_data_addr_); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
void HcclTaskInfo::TransToGETaskInfo(GETaskInfo &ge_task) { | void HcclTaskInfo::TransToGETaskInfo(GETaskInfo &ge_task) { | ||||
ge_task.id = id_; | ge_task.id = id_; | ||||
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL); | ||||
ge_task.stream = stream_; | ge_task.stream = stream_; | ||||
ge_task.kernelHcclInfo.hccl_type = hccl_type_; | |||||
ge_task.kernelHcclInfo.inputDataAddr = input_data_addr_; | |||||
ge_task.kernelHcclInfo.outputDataAddr = output_data_addr_; | |||||
ge_task.kernelHcclInfo.workSpaceAddr = workspace_addr_; | |||||
ge_task.kernelHcclInfo.count = count_; | |||||
ge_task.kernelHcclInfo.dataType = static_cast<int32_t>(data_type_); | |||||
ge_task.kernelHcclInfo.opType = static_cast<int32_t>(op_type_); | |||||
ge_task.kernelHcclInfo.rootId = root_id_; | |||||
ge_task.kernelHcclInfo.workSpaceMemSize = workspace_mem_size_; | |||||
ge_task.kernelHcclInfo.hcclStreamList = hccl_stream_list_; | |||||
ge_task.kernelHcclInfo = kernel_hccl_infos_; | |||||
ge_task.privateDef = private_def_; | ge_task.privateDef = private_def_; | ||||
ge_task.privateDefLen = private_def_len_; | ge_task.privateDefLen = private_def_len_; | ||||
ge_task.opsKernelStorePtr = ops_kernel_store_; | ge_task.opsKernelStorePtr = ops_kernel_store_; | ||||
for (size_t i = 0; i < ge_task.kernelHcclInfo.size(); i++) { | |||||
ge_task.kernelHcclInfo[i].hcclStreamList = hccl_stream_list_; | |||||
} | |||||
} | } | ||||
void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { | void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { | ||||
// Get privateDef and opsKernelStorePtr from taskDef and save them in taskInfo | // Get privateDef and opsKernelStorePtr from taskDef and save them in taskInfo | ||||
GELOGI("get custom info in modelTaskDef."); | GELOGI("get custom info in modelTaskDef."); | ||||
@@ -299,11 +283,54 @@ void HcclTaskInfo::GetPrivateDefByTaskDef(const domi::TaskDef &task) { | |||||
GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); | GELOGE(RT_FAILED, "Call rtMemcpy Fail, ret = 0x%X.", ret); | ||||
return; | return; | ||||
} | } | ||||
GELOGI("The first address of the custom info, privateDef=%p.", private_def_); | GELOGI("The first address of the custom info, privateDef=%p.", private_def_); | ||||
} | } | ||||
} | } | ||||
} | } | ||||
void HcclTaskInfo::CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc) { | |||||
GE_CHECK_NOTNULL_JUST_RETURN(op_desc); | |||||
if (HcomOmeUtil::IsHCOMOp(op_desc->GetType())) { | |||||
GETaskKernelHcclInfo kernel_hccl_info; | |||||
kernel_hccl_infos_.emplace_back(kernel_hccl_info); | |||||
} else if (HcomOmeUtil::IsHorovodOp(op_desc->GetType())) { | |||||
// Horovod wait do not have any input, but create a GETaskKernelHcclInfo to record hccl_type. | |||||
// Other Operator need to check that the number of GETaskKernelHcclInfo must equals to number of inputs | |||||
if (op_desc->GetType() == HVDWAIT) { | |||||
GETaskKernelHcclInfo kernel_hccl_info; | |||||
kernel_hccl_infos_.emplace_back(kernel_hccl_info); | |||||
return; | |||||
} | |||||
for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | |||||
GETaskKernelHcclInfo kernel_hccl_info; | |||||
kernel_hccl_infos_.emplace_back(kernel_hccl_info); | |||||
} | |||||
} | |||||
} | |||||
Status HcclTaskInfo::SetWorkspace(const std::shared_ptr<OpDesc> &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
GELOGI("SetWorkspace Node[%s] opType[%s] set workspace.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
uint64_t workspace_mem_size = 0; | |||||
void *workspace_addr = nullptr; | |||||
auto workspace_bytes = op_desc->GetWorkspaceBytes(); | |||||
if (!workspace_bytes.empty()) { | |||||
uint64_t workspace_mem_size_tmp = workspace_bytes[0]; | |||||
GELOGI("hccl need workSpaceMemSize=%lu", workspace_mem_size_tmp); | |||||
if (workspace_mem_size_tmp != 0) { | |||||
workspace_mem_size = workspace_mem_size_tmp; | |||||
vector<void *> workspace_data_addrs = | |||||
ModelUtils::GetWorkspaceDataAddrs(davinci_model_->GetRuntimeParam(), op_desc); | |||||
if (!workspace_data_addrs.empty()) { | |||||
GELOGI("Get workSpaceAddr"); | |||||
workspace_addr = workspace_data_addrs[0]; | |||||
} | |||||
} | |||||
} | |||||
for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { | |||||
kernel_hccl_infos[i].workSpaceMemSize = workspace_mem_size; | |||||
kernel_hccl_infos[i].workSpaceAddr = workspace_addr; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
REGISTER_TASK_INFO(RT_MODEL_TASK_HCCL, HcclTaskInfo); | REGISTER_TASK_INFO(RT_MODEL_TASK_HCCL, HcclTaskInfo); | ||||
} // namespace ge | } // namespace ge |
@@ -18,9 +18,9 @@ | |||||
#define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ | #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ | ||||
#include <memory> | #include <memory> | ||||
#include <mutex> | |||||
#include <string> | #include <string> | ||||
#include <vector> | #include <vector> | ||||
#include <mutex> | |||||
#include "common/opskernel/ge_task_info.h" | #include "common/opskernel/ge_task_info.h" | ||||
#include "graph/load/new_model_manager/task_info/task_info.h" | #include "graph/load/new_model_manager/task_info/task_info.h" | ||||
@@ -30,16 +30,7 @@ class HcclTaskInfo : public TaskInfo { | |||||
public: | public: | ||||
HcclTaskInfo() | HcclTaskInfo() | ||||
: davinci_model_(nullptr), | : davinci_model_(nullptr), | ||||
hccl_type_(""), | |||||
input_data_addr_(nullptr), | |||||
output_data_addr_(nullptr), | |||||
count_(0), | |||||
data_type_(HCCL_DATA_TYPE_INT8), | |||||
op_type_(HCCL_REP_OP_SUM), | |||||
root_id_(0), | |||||
id_(0), | id_(0), | ||||
workspace_addr_(nullptr), | |||||
workspace_mem_size_(0), | |||||
hccl_stream_list_(), | hccl_stream_list_(), | ||||
ops_kernel_store_(nullptr), | ops_kernel_store_(nullptr), | ||||
private_def_(nullptr), | private_def_(nullptr), | ||||
@@ -56,6 +47,8 @@ class HcclTaskInfo : public TaskInfo { | |||||
private: | private: | ||||
ge::Status SetAddrs(const std::string &hccl_type, const std::shared_ptr<OpDesc> &op); | ge::Status SetAddrs(const std::string &hccl_type, const std::shared_ptr<OpDesc> &op); | ||||
Status SetAddrs(const std::shared_ptr<OpDesc> &op_desc, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
void TransToGETaskInfo(GETaskInfo &ge_task); | void TransToGETaskInfo(GETaskInfo &ge_task); | ||||
void GetPrivateDefByTaskDef(const domi::TaskDef &task); | void GetPrivateDefByTaskDef(const domi::TaskDef &task); | ||||
@@ -64,23 +57,21 @@ class HcclTaskInfo : public TaskInfo { | |||||
ge::Status CreateStream(int64_t stream_num, DavinciModel *davinci_model); | ge::Status CreateStream(int64_t stream_num, DavinciModel *davinci_model); | ||||
Status SetFollowStream(const ge::ConstOpDescPtr &op_desc, DavinciModel *davinci_model); | |||||
void CreateKernelHcclInfo(const ge::ConstOpDescPtr &op_desc); | |||||
Status SetWorkspace(const std::shared_ptr<OpDesc> &op_desc, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
DavinciModel *davinci_model_; | DavinciModel *davinci_model_; | ||||
string hccl_type_; | |||||
void *input_data_addr_; | |||||
void *output_data_addr_; | |||||
int32_t count_; | |||||
hcclDataType_t data_type_; | |||||
hcclRedOp_t op_type_; | |||||
int64_t root_id_; | |||||
uint32_t id_; | uint32_t id_; | ||||
void *workspace_addr_; | |||||
uint64_t workspace_mem_size_; | |||||
vector<rtStream_t> hccl_stream_list_; | vector<rtStream_t> hccl_stream_list_; | ||||
void *ops_kernel_store_; | void *ops_kernel_store_; | ||||
void *private_def_; | void *private_def_; | ||||
uint32_t private_def_len_; | uint32_t private_def_len_; | ||||
static std::mutex hccl_follow_stream_mutex_; | static std::mutex hccl_follow_stream_mutex_; | ||||
static uint32_t max_node_of_hccl_stream_; | static uint32_t max_node_of_hccl_stream_; | ||||
vector<GETaskKernelHcclInfo> kernel_hccl_infos_; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_HCCL_TASK_INFO_H_ |
@@ -51,17 +51,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
GELOGE(INTERNAL_ERROR, "Init aicpu task info error, index is out of range!"); | GELOGE(INTERNAL_ERROR, "Init aicpu task info error, index is out of range!"); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { | |||||
GELOGE(FAILED, "copy task info to workspace failed."); | |||||
return FAILED; | |||||
} | |||||
const vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); | |||||
if (workspace_data_addrs.empty()) { | |||||
GELOGE(FAILED, "workspace_data_addrs is empty."); | |||||
return FAILED; | |||||
} | |||||
op_desc_ = op_desc; | |||||
// 2. Reconstruct kernelExDef.args to STR_FWK_OP_KERNEL | // 2. Reconstruct kernelExDef.args to STR_FWK_OP_KERNEL | ||||
STR_FWK_OP_KERNEL fwk_op_kernel = {0}; | STR_FWK_OP_KERNEL fwk_op_kernel = {0}; | ||||
@@ -87,7 +77,52 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
} | } | ||||
} | } | ||||
auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID; | |||||
// 2.2 Collect aicpu kernel | |||||
uint64_t kernel_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.kernelID; | |||||
GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, | |||||
GELOGE(FAILED, "CreateAicpuKernel error."); | |||||
return FAILED;) | |||||
kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | |||||
if (davinci_model_->IsKnownNode()) { | |||||
void *input_output_addr = davinci_model_->GetCurrentArgsAddr(args_offset_); | |||||
fwk_op_kernel.fwkKernelBase.fwk_kernel.inputOutputAddr = | |||||
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_output_addr)); | |||||
void *workspace_base_addr = nullptr; | |||||
rtError_t rt_ret = rtMalloc(&workspace_base_addr, kernel_ex_def.task_info_size(), RT_MEMORY_HBM); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error, ret: Ox%X", rt_ret); return FAILED;); | |||||
rt_ret = rtMemcpy(workspace_base_addr, kernel_ex_def.task_info_size(), kernel_ex_def.task_info().data(), | |||||
kernel_ex_def.task_info_size(), RT_MEMCPY_HOST_TO_DEVICE); | |||||
fwk_op_kernel.fwkKernelBase.fwk_kernel.workspaceBaseAddr = | |||||
static_cast<uint64_t>(reinterpret_cast<uintptr_t>(workspace_base_addr)); | |||||
fwk_op_kernel.fwkKernelBase.fwk_kernel.stepIDAddr = step_id_addr; | |||||
fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoNum = 0; | |||||
fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = 0; | |||||
rt_ret = rtMalloc(&kernel_buf_, kernel_buf_size_, RT_MEMORY_HBM); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) | |||||
rt_ret = rtMemcpy(kernel_buf_, kernel_buf_size_, static_cast<void *>(&fwk_op_kernel), kernel_buf_size_, | |||||
RT_MEMCPY_HOST_TO_DEVICE); | |||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMemcpy error, ret: Ox%X", rt_ret); return FAILED;) | |||||
GELOGI("KernelExTaskInfo knonw node Init Success."); | |||||
return SUCCESS; | |||||
} | |||||
// 3. Set workspaceaddr, inputOutputDataAddr | // 3. Set workspaceaddr, inputOutputDataAddr | ||||
if (CopyTaskInfo(kernel_ex_def, rts_param, op_desc) != SUCCESS) { | |||||
GELOGE(FAILED, "copy task info to workspace failed."); | |||||
return FAILED; | |||||
} | |||||
const vector<void *> workspace_data_addrs = ModelUtils::GetWorkspaceDataAddrs(rts_param, op_desc); | |||||
if (workspace_data_addrs.empty()) { | |||||
GELOGE(FAILED, "workspace_data_addrs is empty."); | |||||
return FAILED; | |||||
} | |||||
uint64_t workspace_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(workspace_data_addrs[0])); | uint64_t workspace_base_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(workspace_data_addrs[0])); | ||||
const vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | const vector<void *> input_addrs = ModelUtils::GetInputDataAddrs(rts_param, op_desc); | ||||
const vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | const vector<void *> output_addrs = ModelUtils::GetOutputDataAddrs(rts_param, op_desc); | ||||
@@ -106,8 +141,7 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { | if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { | ||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | dump_flag_ = RT_KERNEL_DUMPFLAG; | ||||
dump_args_ = | |||||
reinterpret_cast<void *>(reinterpret_cast<uintptr_t>(input_output_addr_) + sizeof(void *) * input_addrs.size()); | |||||
dump_args_ = input_output_addr_; | |||||
} | } | ||||
} | } | ||||
@@ -119,16 +153,10 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = 0; | fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = 0; | ||||
// 4. Create session | // 4. Create session | ||||
auto session_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID; | |||||
GE_CHECK_NOTNULL(ModelManager::GetInstance()); | GE_CHECK_NOTNULL(ModelManager::GetInstance()); | ||||
GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, | GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuSession(session_id) != SUCCESS, | ||||
GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); | GELOGE(FAILED, "CreateAicpuSession error. session id: %lu", session_id); | ||||
return FAILED;) | return FAILED;) | ||||
// 4.1 Collect aicpu kernel | |||||
uint64_t kernel_id = fwk_op_kernel.fwkKernelBase.fwk_kernel.kernelID; | |||||
GE_IF_BOOL_EXEC(ModelManager::GetInstance()->CreateAicpuKernel(session_id, davinci_model->Id(), kernel_id) != SUCCESS, | |||||
GELOGE(FAILED, "CreateAicpuKernel error."); | |||||
return FAILED;) | |||||
// 5. Return result | // 5. Return result | ||||
rtError_t rt_ret = rtMalloc(&kernel_buf_, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); | rtError_t rt_ret = rtMalloc(&kernel_buf_, sizeof(STR_FWK_OP_KERNEL), RT_MEMORY_HBM); | ||||
GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, GELOGE(rt_ret, "rtMalloc error: 0x%X", rt_ret); return FAILED;) | ||||
@@ -144,12 +172,46 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | virtual_io_addrs.insert(virtual_io_addrs.end(), virtual_out_addrs.begin(), virtual_out_addrs.end()); | ||||
davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); | davinci_model_->SetZeroCopyAddr(op_desc, virtual_io_addrs, io_addrs.data(), input_output_addr_, addrs_size, 0); | ||||
kernel_buf_size_ = sizeof(STR_FWK_OP_KERNEL); | |||||
GELOGI("KernelExTaskInfo Init Success. session id: %lu", session_id); | GELOGI("KernelExTaskInfo Init Success. session id: %lu", session_id); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Reserve one 64-bit address slot per input/output of the aicpu op in the
// model's combined args buffer, and remember where this task's slice begins.
Status KernelExTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
  auto ex_def = task_def.kernel_ex();
  const uint32_t index = ex_def.op_index();
  OpDescPtr op = davinci_model->GetOpByIndex(index);
  if (op == nullptr) {
    GELOGE(INTERNAL_ERROR, "Init aicpu task info error, index is out of range!");
    return INTERNAL_ERROR;
  }
  // Our offset is the current total; the model grows by our size afterwards.
  args_offset_ = davinci_model->GetTotalArgsSize();
  // aicpu kernel input/output size
  const size_t io_count = op->GetInputsSize() + op->GetOutputsSize();
  const uint32_t args_bytes = sizeof(uint64_t) * io_count;
  davinci_model->SetTotalArgsSize(args_bytes);
  GELOGI("kernel task name %s, args_size %u, args_offset %u", op->GetName().c_str(), args_bytes, args_offset_);
  return SUCCESS;
}
// Refresh the known-node zero-copy table with this op's current
// input/output device addresses at args_offset_.
Status KernelExTaskInfo::UpdateArgs() {
  GELOGI("KernelExTaskInfo::UpdateArgs in.");
  const RuntimeParam &param = davinci_model_->GetRuntimeParam();
  const vector<void *> inputs = ModelUtils::GetInputDataAddrs(param, op_desc_);
  const vector<void *> outputs = ModelUtils::GetOutputDataAddrs(param, op_desc_);
  // Concatenate inputs followed by outputs, matching the slot layout
  // reserved in CalculateArgs().
  vector<void *> all_addrs(inputs);
  all_addrs.insert(all_addrs.end(), outputs.begin(), outputs.end());
  GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(all_addrs, args_offset_),
                    "update known node %s zero copy addr failed.", op_desc_->GetName().c_str());
  GELOGI("KernelExTaskInfo::UpdateArgs success.");
  return SUCCESS;
}
Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, | Status KernelExTaskInfo::CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, | ||||
const OpDescPtr &op_desc) { | const OpDescPtr &op_desc) { | ||||
// Userspace copy need virtual address. | // Userspace copy need virtual address. | ||||
@@ -41,6 +41,10 @@ class KernelExTaskInfo : public TaskInfo { | |||||
Status Release() override; | Status Release() override; | ||||
Status UpdateArgs() override; | |||||
Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
uint32_t GetTaskID() override { return task_id_; } | uint32_t GetTaskID() override { return task_id_; } | ||||
uint32_t GetStreamId() override { return stream_id_; } | uint32_t GetStreamId() override { return stream_id_; } | ||||
@@ -61,6 +65,8 @@ class KernelExTaskInfo : public TaskInfo { | |||||
void *kernel_buf_; | void *kernel_buf_; | ||||
void *input_output_addr_; | void *input_output_addr_; | ||||
void *dump_args_; | void *dump_args_; | ||||
OpDescPtr op_desc_ = nullptr; | |||||
uint32_t args_offset_ = 0; | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ | #endif // GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_KERNEL_EX_TASK_INFO_H_ |
@@ -343,6 +343,10 @@ Status KernelTaskInfo::SuperKernelDistribute() { | |||||
Status KernelTaskInfo::Distribute() { | Status KernelTaskInfo::Distribute() { | ||||
GELOGD("KernelTaskInfo Distribute Start."); | GELOGD("KernelTaskInfo Distribute Start."); | ||||
if (davinci_model_->IsKnownNode()) { | |||||
args_ = davinci_model_->GetCurrentArgsAddr(args_offset_); | |||||
GELOGI("Known node %s args addr %p, offset %u.", op_desc_->GetName().c_str(), args_, args_offset_); | |||||
} | |||||
rtError_t rt_ret = RT_ERROR_NONE; | rtError_t rt_ret = RT_ERROR_NONE; | ||||
char *skt_enable_env = getenv("SKT_ENABLE"); | char *skt_enable_env = getenv("SKT_ENABLE"); | ||||
int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | int64_t env_flag = (skt_enable_env != nullptr) ? strtol(skt_enable_env, nullptr, 10) : 0; | ||||
@@ -380,7 +384,29 @@ Status KernelTaskInfo::Distribute() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Refresh this kernel's slice of the known-node args buffer with the op's
// current input, output and workspace device addresses (in that order).
Status KernelTaskInfo::UpdateArgs() {
  GELOGI("KernelTaskInfo::UpdateArgs in.");
  const RuntimeParam &param = davinci_model_->GetRuntimeParam();
  const vector<void *> inputs = ModelUtils::GetInputDataAddrs(param, op_desc_);
  const vector<void *> outputs = ModelUtils::GetOutputDataAddrs(param, op_desc_);
  const vector<void *> workspaces = ModelUtils::GetWorkspaceDataAddrs(param, op_desc_);
  vector<void *> all_addrs;
  all_addrs.reserve(inputs.size() + outputs.size() + workspaces.size());
  all_addrs.insert(all_addrs.end(), inputs.begin(), inputs.end());
  all_addrs.insert(all_addrs.end(), outputs.begin(), outputs.end());
  all_addrs.insert(all_addrs.end(), workspaces.begin(), workspaces.end());
  GE_CHK_STATUS_RET(davinci_model_->UpdateKnownZeroCopyAddr(all_addrs, args_offset_),
                    "update known node %s zero copy addr failed.", op_desc_->GetName().c_str());
  GELOGI("KernelTaskInfo::UpdateArgs success.");
  return SUCCESS;
}
Status KernelTaskInfo::Release() { | Status KernelTaskInfo::Release() { | ||||
if (davinci_model_ != nullptr && davinci_model_->IsKnownNode()) { | |||||
return SUCCESS; | |||||
} | |||||
FreeRtMem(&args_); | FreeRtMem(&args_); | ||||
FreeRtMem(&flowtable_); | FreeRtMem(&flowtable_); | ||||
FreeRtMem(&custom_info_.input_descs); | FreeRtMem(&custom_info_.input_descs); | ||||
@@ -439,6 +465,15 @@ Status KernelTaskInfo::UpdateL2Data(const domi::KernelDef &kernel_def) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Reserve space for this kernel's args in the model's combined args buffer.
// Records args_offset_ before growing the total so Distribute() can locate
// this task's slice via GetCurrentArgsAddr(args_offset_).
Status KernelTaskInfo::CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) {
  GE_CHECK_NOTNULL(davinci_model);
  domi::KernelDef kernel_def = task_def.kernel();
  uint32_t args_size = kernel_def.args_size();
  args_offset_ = davinci_model->GetTotalArgsSize();
  davinci_model->SetTotalArgsSize(args_size);
  // Fixed log: the original text claimed to print the task name but had no
  // corresponding format argument.
  GELOGI("kernel task args_size %u, args_offset %u", args_size, args_offset_);
  return SUCCESS;
}
Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kernel_def) { | Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kernel_def) { | ||||
GELOGD("Do InitTVMTask."); | GELOGD("Do InitTVMTask."); | ||||
GE_CHECK_NOTNULL(davinci_model_); | GE_CHECK_NOTNULL(davinci_model_); | ||||
@@ -448,6 +483,9 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
GELOGE(INTERNAL_ERROR, "InitTVMTaskInfo error, index:%u out of range!", ctx_.opIndex); | GELOGE(INTERNAL_ERROR, "InitTVMTaskInfo error, index:%u out of range!", ctx_.opIndex); | ||||
return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
} | } | ||||
if (davinci_model_->IsKnownNode()) { | |||||
return SUCCESS; | |||||
} | |||||
// Update Stub | // Update Stub | ||||
// When training, when the the second call to DavinciModel::init() comes here, stub_func_ is already valid, | // When training, when the the second call to DavinciModel::init() comes here, stub_func_ is already valid, | ||||
@@ -512,7 +550,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { | if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { | ||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | dump_flag_ = RT_KERNEL_DUMPFLAG; | ||||
dump_args_ = static_cast<char *>(args_) + offset + kAddrLen * input_data_addrs.size(); | |||||
dump_args_ = static_cast<char *>(args_) + offset; | |||||
} | } | ||||
// update origin l2 data | // update origin l2 data | ||||
@@ -771,7 +809,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { | if (PropertiesManager::Instance().IsLayerNeedDump(davinci_model_->Name(), op_desc->GetName())) { | ||||
dump_flag_ = RT_KERNEL_DUMPFLAG; | dump_flag_ = RT_KERNEL_DUMPFLAG; | ||||
dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead) + kAddrLen * input_addrs.size(); | |||||
dump_args_ = static_cast<char *>(args_) + sizeof(aicpu::AicpuParamHead); | |||||
} | } | ||||
vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | ||||
@@ -67,6 +67,10 @@ class KernelTaskInfo : public TaskInfo { | |||||
Status Distribute() override; | Status Distribute() override; | ||||
Status UpdateArgs() override; | |||||
Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) override; | |||||
Status Release() override; | Status Release() override; | ||||
cce::ccOpContext *GetCtx() override { return &ctx_; } | cce::ccOpContext *GetCtx() override { return &ctx_; } | ||||
@@ -146,6 +150,7 @@ class KernelTaskInfo : public TaskInfo { | |||||
void *dump_args_; | void *dump_args_; | ||||
OpDescPtr op_desc_; | OpDescPtr op_desc_; | ||||
DavinciModel *davinci_model_; | DavinciModel *davinci_model_; | ||||
uint32_t args_offset_ = 0; | |||||
// For super kernel | // For super kernel | ||||
uint32_t skt_id_; | uint32_t skt_id_; | ||||
@@ -62,6 +62,10 @@ class TaskInfo { | |||||
virtual Status Distribute() = 0; | virtual Status Distribute() = 0; | ||||
virtual Status UpdateArgs() { return SUCCESS; } | |||||
virtual Status CalculateArgs(const domi::TaskDef &task_def, DavinciModel *davinci_model) { return SUCCESS; } | |||||
virtual Status Release() { return SUCCESS; } | virtual Status Release() { return SUCCESS; } | ||||
virtual cce::ccOpContext *GetCtx() { return nullptr; } | virtual cce::ccOpContext *GetCtx() { return nullptr; } | ||||
@@ -0,0 +1,343 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#include "graph/manager/graph_caching_allocator.h" | |||||
#include <set> | |||||
#include <string> | |||||
#include <utility> | |||||
#include "framework/common/debug/ge_log.h" | |||||
#include "graph/manager/graph_mem_allocator.h" | |||||
namespace ge { | |||||
// Upper size bound (bytes) of each free-block bin, in ascending order.
// GetBinIndex() maps a request to the first bin whose range covers it and
// clamps oversized requests into the final bin.
const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize,
                                     8 * kMByteSize,
                                     32 * kMByteSize,
                                     128 * kMByteSize,
                                     kGByteSize,
                                     4 * kGByteSize,
                                     16 * kGByteSize,
                                     26 * kGByteSize};
static bool BlockComparator(const Block *left, const Block *right) { | |||||
if (left->device_id != right->device_id) { | |||||
return left->device_id < right->device_id; | |||||
} | |||||
if (left->size != right->size) { | |||||
return left->size < right->size; | |||||
} | |||||
return reinterpret_cast<uintptr_t>(left->ptr) < reinterpret_cast<uintptr_t>(right->ptr); | |||||
} | |||||
// A neighbour can be absorbed only if it exists, is idle, and was split off
// from a larger allocation; stand-alone blocks are released whole instead.
bool CanMerge(Block *block) {
  return (block != nullptr) && !block->allocated && block->IsSplit();
}
size_t GetBinIndex(size_t size) { | |||||
size_t index = 0; | |||||
for (auto range : bin_ranges) { | |||||
if (size <= range) { | |||||
break; | |||||
} | |||||
++index; | |||||
} | |||||
if (index > kNumBins - 1) { | |||||
index = kNumBins - 1; | |||||
} | |||||
return index; | |||||
} | |||||
size_t GetAllocationSize(size_t size) { | |||||
size_t index = GetBinIndex(size); | |||||
return bin_ranges[index]; | |||||
} | |||||
/// | |||||
/// @ingroup ge_graph | |||||
/// @brief block size based on alignment | |||||
/// @param [in] original malloc size | |||||
/// @return allocation size | |||||
/// | |||||
size_t GetBlockSize(size_t size) { | |||||
if (size == 0) { | |||||
return kRoundBlockSize; | |||||
} | |||||
return kRoundBlockSize * ((size + kRoundBlockSize - 1) / kRoundBlockSize); | |||||
} | |||||
bool ShouldSplit(const Block *block, size_t size) { | |||||
return static_cast<double>(size) <= (static_cast<double>(block->size) * kSplitThreshold); | |||||
} | |||||
// Bins are created lazily in Initialize(); start with an empty bin table.
CachingAllocator::CachingAllocator(rtMemType_t memory_type) : memory_type_(memory_type), memory_allocator_(nullptr) {
  for (uint32_t idx = 0; idx < kNumBins; ++idx) {
    free_block_bins_[idx] = nullptr;
  }
}
// Prepare the allocator for use on the given device: (re)create one free
// block bin per size class and bind the underlying device memory allocator.
// Safe to call repeatedly; memory cached by a previous run is released first.
Status CachingAllocator::Initialize(uint32_t device_id) {
  GELOGI("Device id %u", device_id);
  // when redo Initialize free old memory
  FreeBlocks();
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  for (uint32_t i = 0; i < kNumBins; ++i) {
    if (free_block_bins_[i] != nullptr) {
      continue;  // keep bins that survived a previous Initialize
    }
    auto bin_ptr = new (std::nothrow) BlockBin(BlockComparator);
    if (bin_ptr == nullptr) {
      GELOGE(ge::FAILED, "Alloc BlockBin failed.");
      return ge::FAILED;
    }
    free_block_bins_[i] = bin_ptr;
  }
  memory_allocator_ = MemManager::Instance(memory_type_);
  if (memory_allocator_ == nullptr) {
    return ge::FAILED;
  }
  return ge::SUCCESS;
}
// Tear down the allocator: recycle outstanding blocks, release cached
// device memory, then destroy the bin table.
void CachingAllocator::Finalize(uint32_t device_id) {
  GELOGI("Device id %u", device_id);
  FreeBlocks();
  FreeBlockBins();
}
// Allocate `size` bytes of cached device memory. The size is rounded up to
// block granularity, a cached free block is preferred, and the cache is
// extended from the device allocator only on a miss. org_ptr, when given,
// biases the lookup toward reusing the same address as a prior allocation.
// Returns nullptr on failure.
uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device_id) {
  uint8_t *ptr = nullptr;
  size = GetBlockSize(size);
  Block *block = FindFreeBlock(size, org_ptr, device_id);
  if (block != nullptr) {
    ptr = block->ptr;
  } else {
    // cache miss: grow the cache, then retry the lookup once
    if (ge::SUCCESS == TryExtendCache(size, device_id)) {
      block = FindFreeBlock(size, org_ptr, device_id);
      if (block != nullptr) {
        ptr = block->ptr;
      }
    }
  }
  if (ptr == nullptr) {
    GELOGE(FAILED, "Malloc failed device id = %u, size= %zu", device_id, size);
  } else {
    std::lock_guard<std::recursive_mutex> lock(mutex_);
    // record ownership so Free() can map the raw pointer back to its block
    block->allocated = true;
    allocated_blocks_[block->ptr] = block;
    GELOGI("Malloc device id = %u, size= %zu", device_id, size);
  }
  return ptr;
}
// Return a user pointer to the cache. The owning block is looked up in the
// allocated table, coalesced with idle neighbours, and parked in its size
// bin for reuse; the device memory itself is not released here.
Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) {
  GELOGI("Free device id = %u", device_id);
  if (ptr == nullptr) {
    GELOGE(PARAM_INVALID, "Invalid memory pointer");
    return ge::PARAM_INVALID;
  }
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  auto found = allocated_blocks_.find(ptr);
  if (found == allocated_blocks_.end()) {
    GELOGE(PARAM_INVALID, "Invalid memory pointer");
    return ge::PARAM_INVALID;
  }
  Block *owned = found->second;
  allocated_blocks_.erase(found);
  FreeBlock(owned);
  return ge::SUCCESS;
}
// Return an allocated block to its bin, first coalescing it with idle
// neighbouring blocks that were split from the same device allocation.
void CachingAllocator::FreeBlock(Block *block) {
  if (block == nullptr || !block->allocated) {
    return;
  }
  GELOGI("Free block size = %zu", block->size);
  // recursive mutex: callers such as FreeBlocks() may already hold the lock
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  block->allocated = false;
  auto &bin = *block->bin;
  // try to absorb both direct neighbours before re-inserting into the bin
  Block *merge_blocks[] = {block->prev, block->next};
  for (Block *merge_block : merge_blocks) {
    MergeBlocks(block, merge_block, bin);
  }
  bin.insert(block);
}
// Absorb `src` into `dst`. `src` must be dst's direct prev or next
// neighbour; both must be mergeable (idle and split from a larger
// allocation). On merge, dst takes over src's span, the neighbour chain is
// relinked, and src is removed from its bin and destroyed.
void CachingAllocator::MergeBlocks(Block *dst, Block *src, BlockBin &bin) {
  if (!CanMerge(dst) || !CanMerge(src)) {
    return;
  }
  if (dst->prev == src) {
    // merging with the left neighbour: dst's span now starts at src->ptr
    dst->ptr = src->ptr;
    dst->prev = src->prev;
    if (dst->prev != nullptr) {
      dst->prev->next = dst;
    }
  } else {
    // merging with the right neighbour: dst keeps its start address
    dst->next = src->next;
    if (dst->next != nullptr) {
      dst->next->prev = dst;
    }
  }
  dst->size += src->size;
  bin.erase(src);
  delete src;
}
// Locate the free-block bin responsible for the given size.
BlockBin *CachingAllocator::GetBlockBin(size_t size) {
  return free_block_bins_[GetBinIndex(size)];
}
// Find (and remove from its bin) the best-fitting cached block of at least
// `size` bytes for the device. The search key uses org_ptr - 1 so that
// lower_bound lands on a block whose address equals org_ptr when one
// exists. Oversized hits are split when the request is small enough
// relative to the block. Returns nullptr on a cache miss.
Block *CachingAllocator::FindFreeBlock(size_t size, uint8_t *org_ptr, uint32_t device_id) {
  // org_ptr - 1, try to find ptr same as org_ptr
  Block key(device_id, size, (org_ptr == nullptr ? nullptr : org_ptr - 1));
  BlockBin *bin = GetBlockBin(size);
  if (bin == nullptr) {
    GELOGE(ge::FAILED, "Get block bin failed size = %zu", size);
    return nullptr;
  }
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  auto it = bin->lower_bound(&key);
  if (it != bin->end()) {
    Block *block = *it;
    bin->erase(it);  // block leaves the free list whether or not it is split
    if (block != nullptr) {
      GELOGI("Find block size = %zu", block->size);
      if (ShouldSplit(block, size)) {
        return SplitBlock(block, size, *bin, device_id);
      }
    }
    return block;
  }
  return nullptr;
}
// Carve the first `size` bytes off `block` into a new block and return it.
// The remainder keeps the tail of the span, stays linked as the new
// block's right neighbour, and goes back into the bin for reuse. If the
// bookkeeping node cannot be allocated, the original (unsplit) block is
// returned instead.
Block *CachingAllocator::SplitBlock(Block *block, size_t size, BlockBin &bin, uint32_t device_id) {
  // block has been checked, should not be nullptr
  Block *remaining = block;
  Block *new_block = new (std::nothrow) Block(device_id, size, &bin, block->ptr);
  if (new_block == nullptr) {
    GELOGE(ge::FAILED, "Alloc block failed size = %zu", size);
    return block;
  }
  // splice new_block into the neighbour chain just before the remainder
  new_block->prev = remaining->prev;
  if (new_block->prev != nullptr) {
    new_block->prev->next = new_block;
  }
  new_block->next = remaining;
  remaining->prev = new_block;
  remaining->ptr = remaining->ptr + size;
  remaining->size -= size;
  bin.insert(remaining);
  return new_block;
}
// Grow the cache by one device allocation large enough for `size` bytes.
// If device memory is exhausted, unused cached blocks are released and the
// allocation is retried once before giving up.
Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) {
  const auto memory_size = GetAllocationSize(size);
  const std::string purpose = "Memory for caching.";
  auto memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id);
  if (memory_addr == nullptr) {
    // try to free caches and malloc again when malloc memory failed
    FreeCachedBlocks();
    memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id);
  }
  if (memory_addr == nullptr) {
    GELOGE(ge::FAILED, "TryExtendCache failed, no enough memory for size = %zu, device_id = %u", memory_size,
           device_id);
    return ge::FAILED;
  }
  // Hand the fresh allocation to a bin; release it on bookkeeping failure
  // so the device memory is never leaked.
  if (AddToBlockBin(memory_addr, memory_size) != ge::SUCCESS) {
    (void)memory_allocator_->FreeMemory(memory_addr);
    return ge::FAILED;
  }
  return ge::SUCCESS;
}
// Wrap a fresh device allocation in a Block and park it in the bin that
// matches its size, making it available to FindFreeBlock().
// Cleanup: the block is now constructed fully initialized instead of being
// built with a null ptr and having ptr/size redundantly reassigned.
Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size) {
  BlockBin *bin = GetBlockBin(size);
  if (bin == nullptr) {
    GELOGE(ge::FAILED, "Get block bin failed size = %zu", size);
    return ge::FAILED;
  }
  // device id 0 matches the original behaviour; bins are not segregated
  // per device at insertion time.
  Block *block = new (std::nothrow) Block(0, size, bin, ptr);
  if (block == nullptr) {
    GELOGE(ge::FAILED, "Alloc block failed size = %zu", size);
    return ge::FAILED;
  }
  GELOGI("Block size = %zu", size);
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  bin->insert(block);
  return ge::SUCCESS;
}
// Release cached-but-unused device memory back to the underlying
// allocator. Only whole, never-split blocks are released: split blocks
// share one device allocation with their neighbours, so their memory
// cannot be freed piecemeal.
void CachingAllocator::FreeCachedBlocks() {
  GELOGI("Free cached blocks");
  std::lock_guard<std::recursive_mutex> lock(mutex_);
  for (uint32_t i = 0; i < kNumBins; ++i) {
    auto pool = free_block_bins_[i];
    if (pool == nullptr) {
      continue;
    }
    for (auto it = pool->begin(); it != pool->end();) {
      Block *block = *it;
      // free block memory that has not been split
      if ((block != nullptr) && (block->ptr != nullptr) && (block->prev == nullptr) && (block->next == nullptr) &&
          (memory_allocator_->FreeMemory(block->ptr) == ge::SUCCESS)) {
        pool->erase(it++);  // post-increment keeps the iterator valid across erase
        delete block;
        continue;
      }
      ++it;
    }
  }
}
void CachingAllocator::FreeBlocks() { | |||||
GELOGI("Free blocks"); | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
// free allocated blocks and put to cache | |||||
for (auto &it : allocated_blocks_) { | |||||
FreeBlock(it.second); | |||||
} | |||||
allocated_blocks_.clear(); | |||||
FreeCachedBlocks(); | |||||
} | |||||
void CachingAllocator::FreeBlockBins() { | |||||
GELOGI("Free block bins"); | |||||
std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
for (uint32_t i = 0; i < kNumBins; ++i) { | |||||
if (free_block_bins_[i] != nullptr) { | |||||
delete free_block_bins_[i]; | |||||
free_block_bins_[i] = nullptr; | |||||
} | |||||
} | |||||
} | |||||
} // namespace ge |
@@ -0,0 +1,212 @@ | |||||
/** | |||||
* Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
* | |||||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||||
* you may not use this file except in compliance with the License. | |||||
* You may obtain a copy of the License at | |||||
* | |||||
* http://www.apache.org/licenses/LICENSE-2.0 | |||||
* | |||||
* Unless required by applicable law or agreed to in writing, software | |||||
* distributed under the License is distributed on an "AS IS" BASIS, | |||||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
* See the License for the specific language governing permissions and | |||||
* limitations under the License. | |||||
*/ | |||||
#ifndef GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ | |||||
#define GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ | |||||
#include <iostream> | |||||
#include <map> | |||||
#include <memory> | |||||
#include <mutex> | |||||
#include <string> | |||||
#include <vector> | |||||
#include <set> | |||||
#include <unordered_map> | |||||
#include <unordered_set> | |||||
#include "framework/common/ge_inner_error_codes.h" | |||||
#include "graph/node.h" | |||||
#include "runtime/mem.h" | |||||
namespace ge { | |||||
constexpr size_t kRoundBlockSize = 512; // all block sizes are rounded to at least 512 bytes | |||||
constexpr double kSplitThreshold = 0.75; // split when malloc size <= small block size * kSpliThreshold | |||||
constexpr size_t kKByteSize = 1024; | |||||
constexpr size_t kMByteSize = 1024 * 1024; | |||||
constexpr size_t kGByteSize = 1024 * 1024 * 1024; | |||||
struct Block; | |||||
typedef bool (*Comparison)(const Block *, const Block *); | |||||
using BlockBin = std::set<Block *, Comparison>; | |||||
static const uint32_t kNumBins = 8; | |||||
struct Block { | |||||
uint32_t device_id; // npu device id | |||||
size_t size; // block size in bytes | |||||
BlockBin *bin; // owning block bin | |||||
uint8_t *ptr; // memory address | |||||
bool allocated; // in-use flag | |||||
Block *prev; // prev block if split from a larger allocation | |||||
Block *next; // next block if split from a larger allocation | |||||
Block(uint32_t device, size_t size, BlockBin *bin, uint8_t *ptr) | |||||
: device_id(device), size(size), bin(bin), ptr(ptr), allocated(0), prev(nullptr), next(nullptr) {} | |||||
// constructor for search key | |||||
Block(uint32_t device, size_t size, uint8_t *ptr) | |||||
: device_id(device), size(size), bin(nullptr), ptr(ptr), allocated(0), prev(nullptr), next(nullptr) {} | |||||
bool IsSplit() const { return (prev != nullptr) || (next != nullptr); } | |||||
}; | |||||
class MemoryAllocator; | |||||
// Device-memory allocator that caches freed blocks in size-sorted bins and
// reuses/splits/merges them to avoid repeated runtime malloc/free calls.
// Thread-safe: all public operations lock mutex_ (recursive).
class CachingAllocator {
 public:
  explicit CachingAllocator(rtMemType_t memory_type);

  virtual ~CachingAllocator() = default;

  ///
  /// @ingroup ge_graph
  /// @brief caching allocator init
  /// @param [in] device_id device id
  /// @return Status of init
  ///
  Status Initialize(uint32_t device_id = 0);

  ///
  /// @ingroup ge_graph
  /// @brief memory allocator finalize, release cached memory
  /// @param [in] device_id device id
  /// @return void
  ///
  void Finalize(uint32_t device_id = 0);

  ///
  /// @ingroup ge_graph
  /// @brief malloc memory
  /// @param [in] size memory size
  /// @param [in] org_ptr origin memory address, try to reuse the same memory when not null
  /// @param [in] device_id device id
  /// @return memory address
  ///
  uint8_t *Malloc(size_t size, uint8_t *org_ptr = nullptr, uint32_t device_id = 0);

  ///
  /// @ingroup ge_graph
  /// @brief free memory
  /// @param [in] memory_addr memory address to free
  /// @param [in] device_id device id
  /// @return Status result of function
  ///
  Status Free(uint8_t *memory_addr, uint32_t device_id = 0);

 private:
  ///
  /// @ingroup ge_graph
  /// @brief extend cache by size when no suitable cached block exists
  /// @param [in] size memory size
  /// @param [in] device_id device id
  /// @return Status result of function
  ///
  Status TryExtendCache(size_t size, uint32_t device_id);

  ///
  /// @ingroup ge_graph
  /// @brief find free block by size
  /// @param [in] size memory size
  /// @param [in] org_ptr origin memory address preferred for reuse
  /// @param [in] device_id device id
  /// @return block ptr
  ///
  Block *FindFreeBlock(size_t size, uint8_t *org_ptr, uint32_t device_id);

  ///
  /// @ingroup ge_graph
  /// @brief get the right bin based on size
  /// @param [in] size original malloc size
  /// @return block bin
  ///
  BlockBin *GetBlockBin(size_t size);

  ///
  /// @ingroup ge_graph
  /// @brief add memory to right bin based on size
  /// @param [in] ptr memory ptr
  /// @param [in] size memory size
  /// @return Status result of function
  ///
  Status AddToBlockBin(uint8_t *ptr, size_t size);

  ///
  /// @ingroup ge_graph
  /// @brief free block to right bin
  /// @param [in] block block ptr
  /// @return void
  ///
  void FreeBlock(Block *block);

  ///
  /// @ingroup ge_graph
  /// @brief free all cached blocks to right bin and release the memory when memory is not enough
  /// @return void
  ///
  void FreeCachedBlocks();

  ///
  /// @ingroup ge_graph
  /// @brief free allocated and cached blocks and release the memory when process exit
  /// @return void
  ///
  void FreeBlocks();

  ///
  /// @ingroup ge_graph
  /// @brief free block bins when process exit
  /// @return void
  ///
  void FreeBlockBins();

  ///
  /// @ingroup ge_graph
  /// @brief If a split block is freed, try merging with the original block
  /// @param [inout] dst dest block ptr, absorbs src
  /// @param [in] src src block ptr
  /// @param [in] bin block bin holding the blocks
  /// @return void
  ///
  void MergeBlocks(Block *dst, Block *src, BlockBin &bin);

  ///
  /// @ingroup ge_graph
  /// @brief If the allocated memory size is too much smaller than the memory block, try to split the memory block
  /// @param [in] block original block ptr
  /// @param [in] size allocated memory size
  /// @param [in] bin block bin
  /// @param [in] device_id device id
  /// @return split block ptr
  ///
  Block *SplitBlock(Block *block, size_t size, BlockBin &bin, uint32_t device_id);

 private:
  rtMemType_t memory_type_;
  // device memory allocator
  MemoryAllocator *memory_allocator_;
  // lock around all operations
  mutable std::recursive_mutex mutex_;
  // allocated blocks by memory pointer
  std::unordered_map<uint8_t *, Block *> allocated_blocks_;
  // block bins by different block size
  BlockBin *free_block_bins_[kNumBins];
};
}; // namespace ge | |||||
#endif // GE_GRAPH_MANAGER_GRAPH_CACHING_ALLOCATOR_H_ |
@@ -99,7 +99,7 @@ class GraphManager { | |||||
/// @param [out] models build result | /// @param [out] models build result | ||||
/// @return Status result of function | /// @return Status result of function | ||||
/// | /// | ||||
Status BuildGraph(const GraphId &graph_id, const std::vector<GeTensor> &inputs, vector<GeModelPtr> &models); | |||||
ge::Status BuildGraph(const GraphId &graph_id, const std::vector<GeTensor> &inputs, GeRootModelPtr &models); | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
@@ -153,6 +153,8 @@ class GraphManager { | |||||
const std::map<std::string, std::string> *GetGraphOptions(uint32_t graph_id); | const std::map<std::string, std::string> *GetGraphOptions(uint32_t graph_id); | ||||
void SetOptionsRunGraphFlag(bool run_graph_flag); | |||||
private: | private: | ||||
struct PreRunArgs { | struct PreRunArgs { | ||||
GraphId graph_id; | GraphId graph_id; | ||||
@@ -166,7 +168,7 @@ class GraphManager { | |||||
GraphNodePtr graph_node; | GraphNodePtr graph_node; | ||||
GraphId graph_id; | GraphId graph_id; | ||||
std::vector<ge::InputTensorInfo> input_tensor; | std::vector<ge::InputTensorInfo> input_tensor; | ||||
GeModelPtr ge_model; | |||||
GeRootModelPtr ge_root_model; | |||||
GEThreadLocalContext context; | GEThreadLocalContext context; | ||||
RunAsyncCallback callback; | RunAsyncCallback callback; | ||||
}; | }; | ||||
@@ -177,19 +179,16 @@ class GraphManager { | |||||
static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, const SubGraphInfoPtr &sub_graph_info_ptr, | static Status ProcessSubGraphWithMultiThreads(GraphManager *graph_manager, const SubGraphInfoPtr &sub_graph_info_ptr, | ||||
uint64_t session_id, const GEThreadLocalContext &ge_context); | uint64_t session_id, const GEThreadLocalContext &ge_context); | ||||
Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, vector<GeModelPtr> &ge_models, | |||||
GeModelPtr &ge_model, uint64_t session_id = INVALID_SESSION_ID); | |||||
Status PreRunDynShape(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | |||||
vector<GeModelPtr> &ge_models, GeModelPtr &ge_model, uint64_t session_id = INVALID_SESSION_ID); | |||||
Status PreRun(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, GeRootModelPtr &ge_root_model, | |||||
uint64_t session_id = INVALID_SESSION_ID); | |||||
Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); | Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); | ||||
Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, vector<GeModelPtr> &ge_models, | |||||
GeModelPtr &ge_model, uint64_t session_id); | |||||
Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, GeRootModelPtr &ge_root_model, | |||||
uint64_t session_id); | |||||
Status StartForRunGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | Status StartForRunGraph(const GraphNodePtr &graph_node, const std::vector<GeTensor> &inputs, | ||||
vector<GeModelPtr> &ge_models, uint64_t session_id = INVALID_SESSION_ID); | |||||
GeRootModelPtr &ge_root_model, uint64_t session_id = INVALID_SESSION_ID); | |||||
Status InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector<GeTensor> &inputs, | Status InnerRunGraph(GraphNodePtr &graph_node, const GraphId &graph_id, const std::vector<GeTensor> &inputs, | ||||
std::vector<GeTensor> &outputs); | std::vector<GeTensor> &outputs); | ||||
@@ -240,6 +239,8 @@ class GraphManager { | |||||
Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph); | Status SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph); | ||||
void SetAttrForHcomBroadCastOp(ge::ComputeGraphPtr &compute_graph); | |||||
bool IsBroadCastOpData(const ge::NodePtr &var_node); | bool IsBroadCastOpData(const ge::NodePtr &var_node); | ||||
void AdjustBroadCastOpData(const ge::NodePtr &var_node); | void AdjustBroadCastOpData(const ge::NodePtr &var_node); | ||||
@@ -258,6 +259,7 @@ class GraphManager { | |||||
std::shared_ptr<GraphContext> GetGraphContext() const { return graph_context_; } | std::shared_ptr<GraphContext> GetGraphContext() const { return graph_context_; } | ||||
Status RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph); | Status RemoveIsolatedConst(ge::ComputeGraphPtr &compute_graph); | ||||
Status RemoveIsolatedConstInThisGraph(ge::ComputeGraphPtr &compute_graph); | |||||
Status OptimizeStage1(ComputeGraphPtr &compute_graph); | Status OptimizeStage1(ComputeGraphPtr &compute_graph); | ||||
Status OptimizeStage2(ComputeGraphPtr &compute_graph); | Status OptimizeStage2(ComputeGraphPtr &compute_graph); | ||||
@@ -265,13 +267,13 @@ class GraphManager { | |||||
Status NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); | Status NewOptimizeAfterMergeSubGraph(ge::ComputeGraphPtr &compute_graph); | ||||
Status LoadGraphAsync(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); | |||||
Status LoadGraphAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); | |||||
Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); | Status CheckAndReleaseMemory(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); | ||||
bool CheckModelLoad(const GeModelPtr &ge_model, bool load_flag); | |||||
bool CheckModelLoad(const GeRootModelPtr &ge_model, bool load_flag); | |||||
Status LoadGraph(const GeModelPtr &ge_model, const GraphNodePtr &graph_node); | |||||
Status LoadGraph(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); | |||||
bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | bool IsGraphNeedBuild(const GraphNodePtr &graph_node); | ||||
@@ -287,6 +289,8 @@ class GraphManager { | |||||
static void StopQueue(GraphManager *graph_manager); | static void StopQueue(GraphManager *graph_manager); | ||||
static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log); | static void ReturnError(GraphManager *graph_manager, RunAsyncCallback callback, Status ret, const string &log); | ||||
void ChangeConstTypeWhenTraining(const ComputeGraphPtr &compute_graph); | |||||
std::atomic_bool thread_run_flag_; | std::atomic_bool thread_run_flag_; | ||||
BlockingQueue<PreRunArgs> prerun_args_q_{}; | BlockingQueue<PreRunArgs> prerun_args_q_{}; | ||||
BlockingQueue<RunArgs> run_args_q_{}; | BlockingQueue<RunArgs> run_args_q_{}; | ||||
@@ -36,6 +36,7 @@ | |||||
#include "graph/graph.h" | #include "graph/graph.h" | ||||
#include "graph/model.h" | #include "graph/model.h" | ||||
#include "model/ge_model.h" | #include "model/ge_model.h" | ||||
#include "model/ge_root_model.h" | |||||
#include "register/register_fmk_types.h" | #include "register/register_fmk_types.h" | ||||
#include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
@@ -160,6 +161,8 @@ class GraphNode { | |||||
void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } | void SetLoadFlag(bool load_flag) { load_flag_ = load_flag; } | ||||
void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } | void SetGeModel(const GeModelPtr &ge_model) { ge_model_ = ge_model; } | ||||
GeModelPtr GetGeModel() const { return ge_model_; } | GeModelPtr GetGeModel() const { return ge_model_; } | ||||
void SetGeRootModel(const GeRootModelPtr &ge_root_model) { ge_root_model_ = ge_root_model; } | |||||
GeRootModelPtr GetGeRootModel() const { return ge_root_model_; } | |||||
const std::map<std::string, std::string> &GetOptions() const { return options_; } | const std::map<std::string, std::string> &GetOptions() const { return options_; } | ||||
void SetOptions(const std::map<std::string, std::string> &options) { options_ = options; } | void SetOptions(const std::map<std::string, std::string> &options) { options_ = options; } | ||||
void Lock(); | void Lock(); | ||||
@@ -179,6 +182,7 @@ class GraphNode { | |||||
bool build_flag_; | bool build_flag_; | ||||
bool load_flag_; | bool load_flag_; | ||||
GeModelPtr ge_model_; | GeModelPtr ge_model_; | ||||
GeRootModelPtr ge_root_model_; | |||||
BlockingQueue<uint8_t> sem_; | BlockingQueue<uint8_t> sem_; | ||||
}; | }; | ||||
@@ -15,6 +15,7 @@ | |||||
*/ | */ | ||||
#include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
#include "graph/manager/graph_caching_allocator.h" | |||||
#include <set> | #include <set> | ||||
#include <string> | #include <string> | ||||
@@ -47,7 +48,7 @@ void MemoryAllocator::Finalize(uint32_t device_id) { | |||||
memory_base_map_.clear(); | memory_base_map_.clear(); | ||||
} | } | ||||
uint8_t *MemoryAllocator::MallocMemory(const string &purpose, uint64_t memory_size, uint32_t device_id) const { | |||||
uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id) const { | |||||
uint8_t *memory_addr = nullptr; | uint8_t *memory_addr = nullptr; | ||||
if (rtMalloc(reinterpret_cast<void **>(&memory_addr), memory_size, memory_type_) != RT_ERROR_NONE) { | if (rtMalloc(reinterpret_cast<void **>(&memory_addr), memory_size, memory_type_) != RT_ERROR_NONE) { | ||||
@@ -74,7 +75,7 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con | |||||
return ge::SUCCESS; | return ge::SUCCESS; | ||||
} | } | ||||
uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, uint64_t memory_size, | |||||
uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, | |||||
uint32_t device_id) { | uint32_t device_id) { | ||||
auto it = memory_base_map_.find(memory_key); | auto it = memory_base_map_.find(memory_key); | ||||
if (it != memory_base_map_.end()) { | if (it != memory_base_map_.end()) { | ||||
@@ -147,7 +148,7 @@ uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t devic | |||||
return it->second.memory_addr_; | return it->second.memory_addr_; | ||||
} | } | ||||
MemManager::MemManager() : default_memory_allocator_(nullptr) {} | |||||
MemManager::MemManager() {} | |||||
MemManager::~MemManager() { Finalize(); } | MemManager::~MemManager() { Finalize(); } | ||||
@@ -159,7 +160,7 @@ MemManager &MemManager::Instance() { | |||||
MemoryAllocator *MemManager::Instance(rtMemType_t memory_type) { return Instance().GetMemoryAllocator(memory_type); } | MemoryAllocator *MemManager::Instance(rtMemType_t memory_type) { return Instance().GetMemoryAllocator(memory_type); } | ||||
Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | ||||
std::lock_guard<std::mutex> lock(allocator_mutex_); | |||||
std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
MemoryAllocator *memory_allocator = nullptr; | MemoryAllocator *memory_allocator = nullptr; | ||||
for (unsigned int index : memory_type) { | for (unsigned int index : memory_type) { | ||||
auto it = memory_allocator_map_.find(index); | auto it = memory_allocator_map_.find(index); | ||||
@@ -184,34 +185,34 @@ Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | |||||
} | } | ||||
} | } | ||||
default_memory_allocator_ = new (std::nothrow) MemoryAllocator(RT_MEMORY_RESERVED); | |||||
if (default_memory_allocator_ == nullptr) { | |||||
GELOGE(ge::INTERNAL_ERROR, "Create MemoryAllocator failed."); | |||||
return ge::INTERNAL_ERROR; | |||||
} | |||||
return ge::SUCCESS; | |||||
return InitCachingAllocator(memory_type); | |||||
} | } | ||||
void MemManager::Finalize() noexcept { | void MemManager::Finalize() noexcept { | ||||
GELOGI("Finalize."); | GELOGI("Finalize."); | ||||
std::lock_guard<std::mutex> lock(allocator_mutex_); | |||||
std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
// caching allocator use memory allocator, so finalize it first | |||||
for (auto &caching_allocator : caching_allocator_map_) { | |||||
if (caching_allocator.second != nullptr) { | |||||
caching_allocator.second->Finalize(); | |||||
delete caching_allocator.second; | |||||
caching_allocator.second = nullptr; | |||||
} | |||||
} | |||||
caching_allocator_map_.clear(); | |||||
for (auto &memory_allocator : memory_allocator_map_) { | for (auto &memory_allocator : memory_allocator_map_) { | ||||
if (memory_allocator.second != nullptr) { | if (memory_allocator.second != nullptr) { | ||||
memory_allocator.second->Finalize(0); | |||||
memory_allocator.second->Finalize(); | |||||
delete memory_allocator.second; | delete memory_allocator.second; | ||||
memory_allocator.second = nullptr; | memory_allocator.second = nullptr; | ||||
} | } | ||||
} | } | ||||
if (default_memory_allocator_ != nullptr) { | |||||
delete default_memory_allocator_; | |||||
default_memory_allocator_ = nullptr; | |||||
} | |||||
memory_allocator_map_.clear(); | memory_allocator_map_.clear(); | ||||
} | } | ||||
MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | ||||
std::lock_guard<std::mutex> lock(allocator_mutex_); | |||||
std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
MemoryAllocator *memory_allocator = nullptr; | MemoryAllocator *memory_allocator = nullptr; | ||||
auto it = memory_allocator_map_.find(memory_type); | auto it = memory_allocator_map_.find(memory_type); | ||||
if (it != memory_allocator_map_.end()) { | if (it != memory_allocator_map_.end()) { | ||||
@@ -221,9 +222,60 @@ MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | |||||
// Usually impossible | // Usually impossible | ||||
if (memory_allocator == nullptr) { | if (memory_allocator == nullptr) { | ||||
GELOGE(ge::INTERNAL_ERROR, "GetMemoryAllocator failed, memory type is %u.", memory_type); | GELOGE(ge::INTERNAL_ERROR, "GetMemoryAllocator failed, memory type is %u.", memory_type); | ||||
return default_memory_allocator_; | |||||
static MemoryAllocator default_memory_allocator(RT_MEMORY_RESERVED); | |||||
return &default_memory_allocator; | |||||
} | } | ||||
return memory_allocator; | return memory_allocator; | ||||
} | } | ||||
// Creates (on first use) and initializes a CachingAllocator for every
// requested memory type. Existing allocators are re-initialized, not
// recreated.
// @param [in] memory_type memory types to set up
// @return SUCCESS, or INTERNAL_ERROR on allocation/initialization failure
Status MemManager::InitCachingAllocator(const std::vector<rtMemType_t> &memory_type) {
  CachingAllocator *caching_allocator = nullptr;
  for (unsigned int index : memory_type) {
    auto it = caching_allocator_map_.find(index);
    if (it == caching_allocator_map_.end()) {
      // First request for this memory type: create and register an allocator.
      caching_allocator = new (std::nothrow) CachingAllocator(index);
      if (caching_allocator != nullptr) {
        caching_allocator_map_[index] = caching_allocator;
        GELOGI("Create CachingAllocator memory type[%u] success.", index);
      } else {
        GELOGE(ge::INTERNAL_ERROR, "Alloc CachingAllocator failed.");
      }
    } else {
      caching_allocator = it->second;
    }
    if (caching_allocator == nullptr) {
      GELOGE(ge::INTERNAL_ERROR, "Create CachingAllocator failed.");
      return ge::INTERNAL_ERROR;
    }
    // Fix: the original returned silently here, hiding initialization
    // failures from the log.
    if (caching_allocator->Initialize() != ge::SUCCESS) {
      GELOGE(ge::INTERNAL_ERROR, "Initialize CachingAllocator failed, memory type is %u.", index);
      return ge::INTERNAL_ERROR;
    }
  }
  return ge::SUCCESS;
}
// Returns the caching allocator registered for memory_type. If none exists
// (which should not happen after Initialize), logs an error and returns a
// process-wide default allocator so the caller always gets a usable reference.
CachingAllocator &MemManager::GetCachingAllocator(rtMemType_t memory_type) {
  std::lock_guard<std::recursive_mutex> lock(allocator_mutex_);
  CachingAllocator *caching_allocator = nullptr;
  auto it = caching_allocator_map_.find(memory_type);
  if (it != caching_allocator_map_.end()) {
    caching_allocator = it->second;
  }
  // Usually impossible
  if (caching_allocator == nullptr) {
    GELOGE(ge::INTERNAL_ERROR, "GetCachingAllocator failed, memory type is %u.", memory_type);
    static CachingAllocator default_caching_allocator(RT_MEMORY_RESERVED);
    // Fix: removed the stray empty statement (';') that followed this return.
    return default_caching_allocator;
  }
  return *caching_allocator;
}
// Convenience accessor: fetch the singleton's caching allocator for the
// given memory type.
CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) {
  MemManager &manager = Instance();
  return manager.GetCachingAllocator(memory_type);
}
} // namespace ge | } // namespace ge |
@@ -88,7 +88,7 @@ class MemoryAllocator { | |||||
/// @param [in] device_id device id | /// @param [in] device_id device id | ||||
/// @return memory address | /// @return memory address | ||||
/// | /// | ||||
uint8_t *MallocMemory(const string &purpose, uint64_t memory_size, uint32_t device_id = 0) const; | |||||
uint8_t *MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id = 0) const; | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
@@ -108,7 +108,7 @@ class MemoryAllocator { | |||||
/// @param [in] device_id device id | /// @param [in] device_id device id | ||||
/// @return memory address | /// @return memory address | ||||
/// | /// | ||||
uint8_t *MallocMemory(const string &purpose, const string &memory_key, uint64_t memory_size, uint32_t device_id = 0); | |||||
uint8_t *MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, uint32_t device_id = 0); | |||||
/// | /// | ||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
@@ -135,6 +135,7 @@ class MemoryAllocator { | |||||
}; | }; | ||||
using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | ||||
class CachingAllocator; | |||||
class MemManager { | class MemManager { | ||||
public: | public: | ||||
@@ -142,6 +143,7 @@ class MemManager { | |||||
virtual ~MemManager(); | virtual ~MemManager(); | ||||
static MemManager &Instance(); | static MemManager &Instance(); | ||||
static MemoryAllocator *Instance(rtMemType_t memory_type); | static MemoryAllocator *Instance(rtMemType_t memory_type); | ||||
static CachingAllocator &CachingInstance(rtMemType_t memory_type); | |||||
MemManager(const MemManager &) = delete; | MemManager(const MemManager &) = delete; | ||||
MemManager &operator=(const MemManager &) = delete; | MemManager &operator=(const MemManager &) = delete; | ||||
/// | /// | ||||
@@ -164,13 +166,29 @@ class MemManager { | |||||
/// @ingroup ge_graph | /// @ingroup ge_graph | ||||
/// @brief ge memory allocator | /// @brief ge memory allocator | ||||
/// @param [in] memory_type memory type | /// @param [in] memory_type memory type | ||||
/// @return Status result of function | |||||
/// @return MemoryAllocator ptr | |||||
/// | /// | ||||
MemoryAllocator *GetMemoryAllocator(rtMemType_t memory_type); | MemoryAllocator *GetMemoryAllocator(rtMemType_t memory_type); | ||||
/// | |||||
/// @ingroup ge_graph | |||||
/// @brief ge caching allocator | |||||
/// @param [in] memory_type memory type | |||||
/// @return CachingAllocator ptr | |||||
/// | |||||
CachingAllocator &GetCachingAllocator(rtMemType_t memory_type); | |||||
/// | |||||
/// @ingroup ge_graph | |||||
/// @brief ge create caching allocator | |||||
/// @param [in] memory_type memory type | |||||
/// @return Status result of function | |||||
/// | |||||
Status InitCachingAllocator(const std::vector<rtMemType_t> &memory_type); | |||||
std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | ||||
MemoryAllocator *default_memory_allocator_; | |||||
std::mutex allocator_mutex_; | |||||
std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | |||||
std::recursive_mutex allocator_mutex_; | |||||
}; | }; | ||||
}; // namespace ge | }; // namespace ge | ||||
@@ -25,8 +25,343 @@ | |||||
#include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
#include "graph/types.h" | #include "graph/types.h" | ||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
#include "common/thread_pool.h" | |||||
#include <algorithm> | |||||
namespace ge { | namespace ge { | ||||
namespace { | |||||
// RAII guard that creates a fresh rt context for `device_id`, makes it
// current, and on destruction destroys it and restores whichever context was
// current before construction.
// NOTE(review): constructor failures only log and return — callers have no
// way to detect that the switch did not happen; presumably acceptable for the
// current call sites — confirm.
class RtContextSwitchGuard {
 public:
  RtContextSwitchGuard(rtCtxMode_t mode, uint32_t device_id) : last_(nullptr), current_(nullptr) {
    // Save the currently-active context so the destructor can restore it.
    auto ret = rtCtxGetCurrent(&last_);
    if (ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret);
      return;
    }

    ret = rtCtxCreate(&current_, mode, static_cast<int32_t>(device_id));
    if (ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret);
      return;
    }

    ret = rtCtxSetCurrent(current_);
    if (ret != RT_ERROR_NONE) {
      GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id);
      return;
    }
    GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_);
  }

  ~RtContextSwitchGuard() {
    // Destroy our context first, then reinstate the saved one; each step is
    // skipped if the corresponding pointer was never set in the constructor.
    if (current_ != nullptr) {
      auto ret = rtCtxDestroy(current_);
      GELOGD("Destory current context %p result %d", current_, ret);
    }
    if (last_ != nullptr) {
      auto ret = rtCtxSetCurrent(last_);
      GELOGD("Recovery last context %p result %d.", last_, ret);
    }
  }

 private:
  rtContext_t last_;     // context active before this guard was constructed
  rtContext_t current_;  // context created and owned by this guard
};
int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { | |||||
int64_t var_size = GetSizeByDataType(desc.GetDataType()); | |||||
if (var_size <= 0) { | |||||
GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s", | |||||
TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str()); | |||||
return -1; | |||||
} | |||||
auto shape = desc.GetShape(); | |||||
auto dim_num = shape.GetDimNum(); | |||||
for (size_t dim_index = 0; dim_index < dim_num; ++dim_index) { | |||||
var_size *= shape.GetDim(dim_index); | |||||
} | |||||
return var_size; | |||||
} | |||||
// Copies host-side transformed variable data (`trans_result`) to the device
// address `var_addr`. Returns RT_FAILED when the runtime copy fails.
Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_result, void *var_addr) {
  GELOGD("Copy var %s from host to device, size %zu", var->GetName().c_str(), trans_result.length);
  const auto copy_ret = rtMemcpy(var_addr, trans_result.length, reinterpret_cast<void *>(trans_result.data.get()),
                                 trans_result.length, RT_MEMCPY_HOST_TO_DEVICE);
  if (copy_ret == RT_ERROR_NONE) {
    return SUCCESS;
  }
  GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length);
  return RT_FAILED;
}
// Copies the current device-side contents of variable `var` into a freshly
// allocated host buffer returned through `var_data`. `input_desc` identifies
// which logical address/layout of the variable to read.
// @return SUCCESS, or INTERNAL_ERROR / OUT_OF_MEMORY / RT_FAILED on the
//         corresponding failure.
Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_ptr<uint8_t[]> &var_data,
                         const GeTensorDesc &input_desc) {
  uint8_t *var_logic = nullptr;
  GE_CHECK_NOTNULL(var);
  // Look up the variable's logic address registered with the var manager.
  auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR,
           "Failed to copy var %s from device, can not find it"
           " from var manager %u",
           var->GetName().c_str(), ret);
    return INTERNAL_ERROR;
  }

  // Translate the logic address into a real device (HBM) address.
  uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  if (var_addr == nullptr) {
    GELOGE(INTERNAL_ERROR,
           "Failed to copy var %s from device, cant not get "
           "var addr from logic addr %p",
           var->GetName().c_str(), var_logic);
    return INTERNAL_ERROR;
  }

  int64_t var_size_bytes = CalcVarSizeInBytes(input_desc);
  if (var_size_bytes <= 0) {
    return INTERNAL_ERROR;
  }

  // Allocate a host buffer large enough for the whole variable.
  std::unique_ptr<uint8_t[]> var_host(new (std::nothrow) uint8_t[var_size_bytes]);
  if (var_host == nullptr) {
    GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes);
    return OUT_OF_MEMORY;
  }

  ret = rtMemcpy(reinterpret_cast<void *>(var_host.get()), var_size_bytes, reinterpret_cast<void *>(var_addr),
                 var_size_bytes, RT_MEMCPY_DEVICE_TO_HOST);
  if (ret != RT_ERROR_NONE) {
    GELOGE(RT_FAILED,
           "Failed to copy var memory from device, var %s, size %ld,"
           " rt-error-code %u",
           var->GetName().c_str(), var_size_bytes, ret);
    return RT_FAILED;
  }

  GELOGD("Copy var %s from device to host, size %ld", var->GetName().c_str(), var_size_bytes);
  // Hand the buffer to the caller only after the copy fully succeeded.
  var_data.swap(var_host);
  GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  return SUCCESS;
}
// Applies each step of `trans_road` to the host copy of a variable's data,
// feeding the output of one step into the next. RESHAPE/REFORMAT steps are
// skipped (they do not change the bytes); TRANSDATA/TRANSPOSED change format,
// CAST changes data type; any other node type is rejected.
// @param [in]  var_data   initial host data (input of the first real step)
// @param [in]  trans_road ordered list of transformations to apply
// @param [out] result     output of the final applied step
// @return SUCCESS, INTERNAL_ERROR on a failed transformation, or UNSUPPORTED
//         for an unknown node type.
Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats::TransResult &result) {
  formats::TransResult result_last_time{};
  // The very first real step reads from var_data; subsequent steps read the
  // previous step's output.
  bool use_init_data = true;
  for (const auto &trans_info : trans_road) {
    if (trans_info.node_type == RESHAPE || trans_info.node_type == REFORMAT) {
      GELOGD("Skip to trans variable data on the reshape/reformat node");
      continue;
    }
    uint8_t *src_data = nullptr;
    if (use_init_data) {
      src_data = var_data;
      use_init_data = false;
    } else {
      src_data = result_last_time.data.get();
    }

    formats::TransResult tmp_result{};
    if (trans_info.node_type == TRANSDATA || trans_info.node_type == TRANSPOSED) {
      // Format transformation: shapes/formats come from the step's input and
      // output tensor descs.
      auto src_format = trans_info.input.GetFormat();
      auto src_shape = trans_info.input.GetShape().GetDims();
      auto dst_format = trans_info.output.GetFormat();
      auto dst_shape = trans_info.output.GetShape().GetDims();
      auto data_type = trans_info.input.GetDataType();
      GELOGD("Trans format from %s to %s, shape %s to %s, data-type %s",
             TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
             formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
             TypeUtils::DataTypeToSerialString(data_type).c_str());
      auto ret = formats::TransFormat({src_data, src_format, dst_format, src_shape, dst_shape, data_type}, tmp_result);
      if (ret != SUCCESS) {
        GELOGE(INTERNAL_ERROR,
               "Failed to trans format from %s to %s, shape %s to %s, "
               "data type %s error code %u",
               TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(),
               formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(),
               TypeUtils::DataTypeToSerialString(data_type).c_str(), ret);
        return ret;
      }
    } else if (trans_info.node_type == CAST) {
      // Data-type cast: a scalar (shape size 0) is treated as one element.
      auto input_shape = trans_info.input.GetShape();
      auto src_data_size = input_shape.GetShapeSize() == 0 ? 1 : input_shape.GetShapeSize();
      auto src_data_type = trans_info.input.GetDataType();
      auto dst_data_type = trans_info.output.GetDataType();
      GELOGD("Trans data type from %s to %s, input shape %s, data size %ld",
             TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
             TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
             src_data_size);
      auto ret = formats::TransDataType({src_data, static_cast<size_t>(src_data_size), src_data_type, dst_data_type},
                                        tmp_result);
      if (ret != SUCCESS) {
        GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u",
               TypeUtils::DataTypeToSerialString(src_data_type).c_str(),
               TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(),
               src_data_size, ret);
        return ret;
      }
    } else {
      GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported",
             trans_info.node_type.c_str());
      return UNSUPPORTED;
    }
    result_last_time = tmp_result;
  }

  result = result_last_time;
  return SUCCESS;
}
/// re-alloc var memory on device using var-manager
/// free origin var memory(var manager does not support now)
/// @param [in]  session_id  session the variable belongs to
/// @param [in]  var_name    variable name registered in the var manager
/// @param [in]  tensor_desc tensor desc used to resolve the variable address
/// @param [out] var_device  real device address of the variable
/// @return SUCCESS or INTERNAL_ERROR
Status ReAssignVarAddr(uint64_t session_id, const std::string &var_name, const GeTensorDesc &tensor_desc,
                       void **var_device) {
  uint8_t *var_logic = nullptr;
  // Resolve the logic address registered for this variable/desc.
  Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic);
  if (ret != SUCCESS) {
    GELOGE(INTERNAL_ERROR,
           "Failed to get var %s device addr, can not find it"
           " from var manager %u",
           var_name.c_str(), ret);
    return INTERNAL_ERROR;
  }

  // Translate the logic address into a real device (HBM) address.
  uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM);
  if (var_addr == nullptr) {
    GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str());
    return INTERNAL_ERROR;
  }
  *var_device = var_addr;

  GELOGI("var_logic:%p, var_addr:%p", var_logic, var_addr);
  return SUCCESS;
}
Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t session_id) { | |||||
// do not need to do anything if only all reshape/reformat node on the trans_road | |||||
GE_CHECK_NOTNULL(var); | |||||
bool need_trans = false; | |||||
for (auto &road : trans_road) { | |||||
if (road.node_type != RESHAPE && road.node_type != REFORMAT) { | |||||
need_trans = true; | |||||
break; | |||||
} | |||||
} | |||||
if (!need_trans) { | |||||
return SUCCESS; | |||||
} | |||||
// Sync var data from device | |||||
std::unique_ptr<uint8_t[]> var_data; | |||||
if (trans_road.empty()) { | |||||
GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty."); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
const GeTensorDesc &input_desc = trans_road.begin()->input; | |||||
auto ret = CopyVarFromDevice(session_id, var, var_data, input_desc); | |||||
if (ret != SUCCESS) { | |||||
return ret; | |||||
} | |||||
formats::TransResult trans_result{}; | |||||
ret = TransVarOnHost(var_data.get(), trans_road, trans_result); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to trans var data on host, error code %u", ret); | |||||
return ret; | |||||
} | |||||
void *var_device = nullptr; | |||||
/// It is a temporary solution to use the last GeTensorDesc to assign variable memory because the variable manager | |||||
/// depends on TensorDesc and it is difficult to be modified. The correct solution is to assign memory based on the | |||||
/// size of the converted variable. To complete the final solution, the dependency of the variable manager on | |||||
/// TensorDesc needs to be removed. This change is large and needs to be performed step by step. | |||||
ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length); | |||||
return ret; | |||||
} | |||||
// sync new data to device | |||||
ret = CopyVarToDevice(var, trans_result, var_device); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "Failed to send var data to device"); | |||||
return ret; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var_dst, formats::TransResult &result) { | |||||
GE_CHECK_NOTNULL(var_src); | |||||
GE_CHECK_NOTNULL(var_src->GetOpDesc()); | |||||
GE_CHECK_NOTNULL(var_dst); | |||||
GE_CHECK_NOTNULL(var_dst->GetOpDesc()); | |||||
auto src_data_shape_size = var_src->GetOpDesc()->GetOutputDesc(0).GetShape().GetShapeSize(); | |||||
auto src_data_datatype = var_src->GetOpDesc()->GetOutputDesc(0).GetDataType(); | |||||
auto dst_data_datatype = var_dst->GetOpDesc()->GetOutputDesc(0).GetDataType(); | |||||
GE_IF_BOOL_EXEC( | |||||
src_data_datatype != dst_data_datatype, | |||||
auto ret = formats::TransDataType( | |||||
{var_data, static_cast<size_t>(src_data_shape_size), src_data_datatype, dst_data_datatype}, result); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "trans var data on host failed"); | |||||
return ret; | |||||
}); | |||||
return SUCCESS; | |||||
} | |||||
Status CopyTensorFromSrcVarNode(const NodePtr &var_src, const NodePtr &var_dst, uint64_t session_id, | |||||
uint32_t device_id) { | |||||
/// after FE fusion pass, input num of applymomentum op was changed, 0th input is var_fp32, 6th input is | |||||
/// var_fp16(new). | |||||
/// unlink edges between var_fp32 and "dst_node" (need fp16) of var_fp32, add edge between var_fp16 and dst_node. | |||||
/// need copy value from var_fp32 to var_fp16. | |||||
/// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr] | |||||
GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, GELOGE(FAILED, "node var is nullptr"); return FAILED); | |||||
// src_node output_desc (fp32) | |||||
GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0); | |||||
auto src_data_type = output_desc.GetDataType(); | |||||
auto src_shape = output_desc.GetShape(); | |||||
auto src_format = output_desc.GetFormat(); | |||||
GELOGI("src_node %s, src_format %s, src_shape %s, src_type %s", var_src->GetName().c_str(), | |||||
TypeUtils::FormatToSerialString(src_format).c_str(), formats::ShapeToString(src_shape).c_str(), | |||||
TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
// dst_node output_desc (fp16) | |||||
GeTensorDesc dst_tensor_desc = var_dst->GetOpDesc()->GetOutputDesc(0); | |||||
auto data_type = dst_tensor_desc.GetDataType(); | |||||
auto data_shape = dst_tensor_desc.GetShape(); | |||||
auto data_format = dst_tensor_desc.GetFormat(); | |||||
GELOGI("dst_node %s, src_format %s, src_shape %s, src_type %s", var_dst->GetName().c_str(), | |||||
TypeUtils::FormatToSerialString(data_format).c_str(), formats::ShapeToString(data_shape).c_str(), | |||||
TypeUtils::DataTypeToSerialString(data_type).c_str()); | |||||
// Sync var data from device | |||||
std::unique_ptr<uint8_t[]> var_src_data; | |||||
RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id); | |||||
// copy from src_node | |||||
auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret); | |||||
// trans dtype | |||||
formats::TransResult trans_result{}; | |||||
ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret); | |||||
// reset src value. | |||||
void *var_device = nullptr; | |||||
ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret); | |||||
// copy to device | |||||
ret = CopyVarToDevice(var_dst, trans_result, var_device); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret); | |||||
return SUCCESS; | |||||
} | |||||
} // namespace | |||||
Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | ||||
uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | ||||
GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null. "); | GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null. "); | ||||
@@ -88,4 +423,101 @@ Status TransVarDataUtils::SyncTensorToDevice(const string &var_name, const uint8 | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, uint64_t session_id, | |||||
rtContext_t context, uint32_t graph_id, uint32_t thread_num) { | |||||
ThreadPool executor(thread_num); | |||||
std::vector<std::future<Status>> vector_future; | |||||
for (auto &node : variable_nodes) { | |||||
if (node == nullptr) { | |||||
continue; | |||||
} | |||||
if (node->GetType() != VARIABLE) { | |||||
continue; | |||||
} | |||||
std::future<Status> f = executor.commit( | |||||
[](const ge::NodePtr &node, uint64_t session_id, rtContext_t ctx, uint32_t graph_id) -> Status { | |||||
rtError_t rt_ret = rtCtxSetCurrent(ctx); | |||||
if (rt_ret != RT_ERROR_NONE) { | |||||
GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret); | |||||
return RT_FAILED; | |||||
} | |||||
uint32_t allocated_graph_id = 0; | |||||
Status ret = VarManager::Instance(session_id)->GetAllocatedGraphId(node->GetName(), allocated_graph_id); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(), | |||||
graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
uint32_t changed_graph_id = 0; | |||||
ret = VarManager::Instance(session_id)->GetChangedGraphId(node->GetName(), changed_graph_id); | |||||
bool call_trans_var = | |||||
(ret == SUCCESS && changed_graph_id == graph_id && changed_graph_id != allocated_graph_id); | |||||
if (call_trans_var) { | |||||
GELOGI("VarManager::GetChangedGraphId() success, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); | |||||
VarTransRoad *trans_road = VarManager::Instance(session_id)->GetTransRoad(node->GetName()); | |||||
if (trans_road == nullptr) { | |||||
GELOGI("The variable %s does not have any trans road", node->GetName().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
ret = TransVarData(node, *trans_road, session_id); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); | |||||
return INTERNAL_ERROR; | |||||
} | |||||
VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName()); | |||||
} | |||||
return SUCCESS; | |||||
}, | |||||
node, session_id, context, graph_id); | |||||
if (!f.valid()) { | |||||
GELOGE(FAILED, "Future is invalid"); | |||||
return FAILED; | |||||
} | |||||
vector_future.push_back(std::move(f)); | |||||
} | |||||
Status ret_status; | |||||
for (size_t i = 0; i < vector_future.size(); ++i) { | |||||
ret_status = vector_future[i].get(); | |||||
if (ret_status != SUCCESS) { | |||||
GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i); | |||||
return ret_status; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id) { | |||||
GELOGI("CopyVarData start: session_id:%lu.", session_id); | |||||
if (compute_graph == nullptr) { | |||||
GELOGE(FAILED, "compute_graph is nullptr"); | |||||
return FAILED; | |||||
} | |||||
string cp_from_node; | |||||
bool copy_value = false; | |||||
for (auto &node : compute_graph->GetAllNodes()) { | |||||
GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr || node->GetOpDesc()->GetType() != VARIABLE, continue); | |||||
GE_IF_BOOL_EXEC(ge::AttrUtils::GetStr(node->GetOpDesc(), "_copy_from_var_node", cp_from_node), | |||||
GELOGI("Get original type of cp_from_node")); | |||||
if (cp_from_node.length() != 0) { | |||||
(void)ge::AttrUtils::GetBool(node->GetOpDesc(), "_copy_value", copy_value); // no need to check value | |||||
if (!copy_value) { | |||||
auto src_node = compute_graph->FindNode(cp_from_node); | |||||
GE_CHECK_NOTNULL(src_node); | |||||
GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(), | |||||
src_node->GetName().c_str()); | |||||
auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id); | |||||
GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED); | |||||
// only copy once | |||||
(void)ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value | |||||
} | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
} // namespace ge | } // namespace ge |
@@ -22,6 +22,9 @@ | |||||
#include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
#include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
#include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
#include "graph/node.h" | |||||
#include "runtime/context.h" | |||||
#include "graph_var_manager.h" | |||||
namespace ge { | namespace ge { | ||||
class TransVarDataUtils { | class TransVarDataUtils { | ||||
@@ -31,6 +34,11 @@ class TransVarDataUtils { | |||||
static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | static ge::Status SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | ||||
const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id_); | ||||
static ge::Status TransAllVarData(const std::vector<NodePtr> &variable_nodes, uint64_t session_id, | |||||
rtContext_t context, uint32_t graph_id, uint32_t thread_num = 16); | |||||
static ge::Status CopyVarData(const ComputeGraphPtr &compute_graph, uint64_t session_id, uint32_t device_id); | |||||
private: | private: | ||||
static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | static ge::Status SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | ||||
uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); | uint8_t **host_addr, int64_t &addr_size, uint64_t session_id_); | ||||
@@ -24,35 +24,49 @@ | |||||
#include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
namespace ge { | namespace ge { | ||||
Status HcomOmeUtil::GetHcomDataType(const ge::ConstOpDescPtr &op_desc, hcclDataType_t &data_type) { | |||||
Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
GELOGI("GetHcclDataType start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
if (op_desc->GetType() == HVDWAIT) { | |||||
return SUCCESS; | |||||
} | |||||
ge::DataType src_data_type = ge::DT_FLOAT; | ge::DataType src_data_type = ge::DT_FLOAT; | ||||
if (op_desc->GetType() == HCOMRECEIVE) { | |||||
bool ret = ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, src_data_type); | |||||
if (ret == false) { | |||||
GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: dtype."); | |||||
for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { | |||||
if (op_desc->GetType() == HCOMRECEIVE) { | |||||
bool ret = ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, src_data_type); | |||||
if (ret == false) { | |||||
GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: dtype."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} else { | |||||
auto input_desc_ptr = op_desc->GetInputDescPtr(i); | |||||
GE_CHECK_NOTNULL(input_desc_ptr); | |||||
src_data_type = input_desc_ptr->GetDataType(); | |||||
} | |||||
auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type)); | |||||
if (iter == kConstOpHcclDataType.end()) { | |||||
GELOGE(PARAM_INVALID, | |||||
"HcomOmeUtil:: Node: %s Optype: %s HcomDataType cann't support! Current Davinci Data Type : %s", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
} else { | |||||
auto input_desc_ptr = op_desc->GetInputDescPtr(0); | |||||
GE_CHECK_NOTNULL(input_desc_ptr); | |||||
src_data_type = input_desc_ptr->GetDataType(); | |||||
} | |||||
auto iter = kConstOpHcomDataType.find(static_cast<int64_t>(src_data_type)); | |||||
if (iter == kConstOpHcomDataType.end()) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: HcomDataType cann't support! Current Davinci Data Type : %s", | |||||
ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
return PARAM_INVALID; | |||||
kernel_hccl_infos[i].dataType = iter->second; | |||||
} | } | ||||
data_type = iter->second; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HcomOmeUtil::GetHcomTypeSize(hcclDataType_t data_type, int32_t &size) { | |||||
auto iter = kConstOpHcomDataTypeSize.find(data_type); | |||||
GE_CHK_BOOL_EXEC(iter != kConstOpHcomDataTypeSize.end(), return PARAM_INVALID, | |||||
Status HcomOmeUtil::GetHcclTypeSize(hcclDataType_t data_type, int32_t &size) { | |||||
auto iter = kConstOpHcclDataTypeSize.find(data_type); | |||||
GE_CHK_BOOL_EXEC(iter != kConstOpHcclDataTypeSize.end(), return PARAM_INVALID, | |||||
"HcomOmeUtil::HcomDataTypeSize , No DataTypeSize!"); | "HcomOmeUtil::HcomDataTypeSize , No DataTypeSize!"); | ||||
size = iter->second; | size = iter->second; | ||||
@@ -62,10 +76,14 @@ Status HcomOmeUtil::GetHcomTypeSize(hcclDataType_t data_type, int32_t &size) { | |||||
Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, | Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, | ||||
int &count) { | int &count) { | ||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if (!IsHCOMOp(op_desc->GetType())) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: operator is not Hcom operator."); | |||||
return PARAM_INVALID; | |||||
} | |||||
int64_t total_size = 0; | int64_t total_size = 0; | ||||
int64_t align_size = 512; | int64_t align_size = 512; | ||||
int32_t size = 0; | int32_t size = 0; | ||||
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcomTypeSize(data_type, size), "GetHcomCount: GetHcomTypeSize fail!"); | |||||
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!"); | |||||
if (op_desc->GetType() == HCOMRECEIVE) { | if (op_desc->GetType() == HCOMRECEIVE) { | ||||
vector<int64_t> shape_dims; | vector<int64_t> shape_dims; | ||||
bool ret = ge::AttrUtils::GetListInt(op_desc, HCOM_ATTR_SHAPE, shape_dims); | bool ret = ge::AttrUtils::GetListInt(op_desc, HCOM_ATTR_SHAPE, shape_dims); | ||||
@@ -114,34 +132,207 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HcomOmeUtil::GetHcomOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type) { | |||||
Status HcomOmeUtil::GetHorovodCount(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
if (!IsHorovodOp(op_desc->GetType())) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: operator is not Horovod operator."); | |||||
return PARAM_INVALID; | |||||
} | |||||
int64_t align_size = 512; | |||||
int32_t size = 0; | |||||
for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | |||||
GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(static_cast<tagHcclDataType>(kernel_hccl_infos[i].dataType), size), | |||||
"GetHorovodCount: GetHcclTypeSize fail!"); | |||||
int64_t input_size = 0; | |||||
int64_t block_size = 0; | |||||
GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); | |||||
GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), | |||||
"get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | |||||
std::string hcom_op_type; | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_REDUCE_TYPE, hcom_op_type), return PARAM_INVALID, | |||||
"HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, not support!"); | |||||
if (hcom_op_type == "min") { | |||||
op_type = HCCL_REP_OP_MIN; | |||||
} else if (hcom_op_type == "max") { | |||||
op_type = HCCL_REP_OP_MAX; | |||||
} else if (hcom_op_type == "prod") { | |||||
op_type = HCCL_REP_OP_PROD; | |||||
} else if (hcom_op_type == "sum") { | |||||
op_type = HCCL_REP_OP_SUM; | |||||
} else { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [%s] not support!", hcom_op_type.c_str()); | |||||
int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); | |||||
GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), | |||||
"Product of shape size and size beyond INT64_MAX"); | |||||
if (kernel_hccl_infos[0].hccl_type == HVDCALLBACKALLGATHER) { | |||||
block_size = shape_size * size; | |||||
} else { | |||||
block_size = (input_size + align_size - 1) / align_size * align_size; | |||||
} | |||||
GE_CHK_BOOL_RET_STATUS(size != 0, PARAM_INVALID, "Size is zero"); | |||||
GE_CHK_BOOL_EXEC(block_size % size == 0, return PARAM_INVALID, "block_size:%ld is not divisiable by size:%d.", | |||||
block_size, size); | |||||
kernel_hccl_infos[i].count = static_cast<int>(block_size / size); | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status HcomOmeUtil::GetHcclCount(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
Status ret; | |||||
ret = CheckKernelHcclInfo(op_desc, kernel_hccl_infos); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); | |||||
return PARAM_INVALID; | return PARAM_INVALID; | ||||
} | } | ||||
GELOGI("GetHcclCount start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
if (IsHCOMOp(op_desc->GetType())) { | |||||
int32_t count = 0; | |||||
ret = GetHcomCount(op_desc, static_cast<tagHcclDataType>(kernel_hccl_infos[0].dataType), | |||||
kernel_hccl_infos[0].hccl_type == HCOMALLGATHER, count); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "HcomOmeUtil:: Node: %s Optype: %s get the Hcom operator hccl count fail.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
kernel_hccl_infos[0].count = count; | |||||
} | |||||
if (IsHorovodOp(op_desc->GetType())) { | |||||
ret = GetHorovodCount(op_desc, kernel_hccl_infos); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s get the Horovod hccl operator count fail.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status HcomOmeUtil::GetHcomRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id) { | |||||
Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (IsHCOMOp(op_desc->GetType())) { | |||||
std::string hcom_op_type; | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_REDUCE_TYPE, hcom_op_type), return PARAM_INVALID, | |||||
"HcomOmeUtil:: Node: %s Optype: %s Get HCOM_ATTR_REDUCE_TYPE fail, not support!", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
if (hcom_op_type == "min") { | |||||
op_type = HCCL_REP_OP_MIN; | |||||
} else if (hcom_op_type == "max") { | |||||
op_type = HCCL_REP_OP_MAX; | |||||
} else if (hcom_op_type == "prod") { | |||||
op_type = HCCL_REP_OP_PROD; | |||||
} else if (hcom_op_type == "sum") { | |||||
op_type = HCCL_REP_OP_SUM; | |||||
} else { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [%s] not support!", hcom_op_type.c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
if (IsHorovodOp(op_desc->GetType())) { | |||||
int64_t horovod_op_type; | |||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::GetInt(op_desc, ATTR_HOROVOD_ATTR_REDUCE_TYPE, horovod_op_type), | |||||
return PARAM_INVALID, | |||||
"HcomOmeUtil:: Node: %s Optype: %s Get ATTR_HOROVOD_ATTR_REDUCE_TYPE fail, not support!", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
auto iter = kHorovodRedOpToHcclRedOp.find(static_cast<horovodRedOp_t>(horovod_op_type)); | |||||
if (iter == kHorovodRedOpToHcclRedOp.end()) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s HcomOpType cann't support! Current HcomOpType : %ld", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | |||||
return PARAM_INVALID; | |||||
} | |||||
op_type = iter->second; | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id) { | |||||
GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
GE_CHK_BOOL_EXEC(ge::AttrUtils::GetInt(op_desc, HCOM_ATTR_ROOT_RANK, root_id), return PARAM_INVALID, | GE_CHK_BOOL_EXEC(ge::AttrUtils::GetInt(op_desc, HCOM_ATTR_ROOT_RANK, root_id), return PARAM_INVALID, | ||||
"HcomOmeUtil::Get HCOM_ATTR_ROOT_INDEX fail, not support!"); | |||||
"HcomOmeUtil::Node %s Optype: %s Get HCOM_ATTR_ROOT_INDEX fail, not support!", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
return SUCCESS; | |||||
} | |||||
Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (op_desc->GetType() == HCOMBROADCAST || op_desc->GetType() == HVDCALLBACKBROADCAST) { | |||||
GELOGI("GetAllRootId Node[%s] opType[%s] get hccl rootId.", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
int64_t root_id = 0; | |||||
Status dmrt = GetHcclRootId(op_desc, root_id); | |||||
if (dmrt != SUCCESS) { | |||||
GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); | |||||
return FAILED; | |||||
} | |||||
for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { | |||||
kernel_hccl_infos[i].rootId = root_id; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
bool HcomOmeUtil::IsHCOMOp(const string &op_type) { | |||||
return (op_type == HCOMALLREDUCE) || (op_type == HCOMALLGATHER) || (op_type == HCOMBROADCAST) || | |||||
(op_type == HCOMSEND) || (op_type == HCOMRECEIVE) || (op_type == HCOMREDUCESCATTER); | |||||
} | |||||
bool HcomOmeUtil::IsHorovodOp(const string &op_type) { | |||||
return (op_type == HVDCALLBACKALLREDUCE) || (op_type == HVDCALLBACKALLGATHER) || (op_type == HVDCALLBACKBROADCAST) || | |||||
(op_type == HVDWAIT); | |||||
} | |||||
Status HcomOmeUtil::CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (IsHCOMOp(op_desc->GetType()) && kernel_hccl_infos.size() != 1) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: in Hcom scenario, the number of GETaskKernelHcclInfo is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (IsHorovodOp(op_desc->GetType())) { | |||||
if (op_desc->GetType() == HVDWAIT) { | |||||
return SUCCESS; | |||||
} | |||||
if (kernel_hccl_infos.empty() || op_desc->GetInputsSize() != kernel_hccl_infos.size()) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: in Horovod scenario, the number of GETaskKernelHcclInfo is invalid."); | |||||
return PARAM_INVALID; | |||||
} | |||||
} | |||||
return SUCCESS; | |||||
} | |||||
void HcomOmeUtil::GetHcclType(const domi::TaskDef &task_def, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
auto hccl_def = task_def.kernel_hccl(); | |||||
std::string hccl_type = hccl_def.hccl_type(); | |||||
for (size_t i = 0; i < kernel_hccl_infos.size(); i++) { | |||||
kernel_hccl_infos[i].hccl_type = hccl_type; | |||||
} | |||||
} | |||||
Status HcomOmeUtil::GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | |||||
GE_CHECK_NOTNULL(op_desc); | |||||
if (!IsHorovodOp(op_desc->GetType())) { | |||||
return SUCCESS; | |||||
} | |||||
if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | |||||
GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s the number of GETaskKernelHcclInfo is invalid.", | |||||
op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
return PARAM_INVALID; | |||||
} | |||||
if (op_desc->GetType() == HVDWAIT) { | |||||
return SUCCESS; | |||||
} | |||||
for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | |||||
ConstGeTensorDescPtr input_desc = op_desc->GetInputDescPtr(i); | |||||
GETaskKernelHcclInfo &kernel_hccl_info = kernel_hccl_infos.at(i); | |||||
kernel_hccl_info.input_name = op_desc->GetInputNameByIndex(i); | |||||
kernel_hccl_info.dims = input_desc->GetShape().GetDims(); | |||||
} | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} // namespace ge | } // namespace ge |
@@ -22,72 +22,146 @@ | |||||
#include <vector> | #include <vector> | ||||
#include "common/debug/log.h" | #include "common/debug/log.h" | ||||
#include "common/opskernel/ge_task_info.h" | |||||
#include "common/string_util.h" | #include "common/string_util.h" | ||||
#include "common/types.h" | #include "common/types.h" | ||||
#include "common/util.h" | #include "common/util.h" | ||||
#include "graph/op_desc.h" | #include "graph/op_desc.h" | ||||
#include "hccl/hcom.h" | #include "hccl/hcom.h" | ||||
#include "proto/task.pb.h" | |||||
namespace ge { | namespace ge { | ||||
using std::string; | using std::string; | ||||
using std::vector; | using std::vector; | ||||
static std::map<int64_t, hcclDataType_t> kConstOpHcomDataType = { | |||||
{ge::DT_FLOAT, HCCL_DATA_TYPE_FLOAT}, | |||||
{ge::DT_FLOAT16, HCCL_DATA_TYPE_HALF}, | |||||
{ge::DT_INT8, HCCL_DATA_TYPE_INT8}, | |||||
{ge::DT_INT32, HCCL_DATA_TYPE_INT}, | |||||
static std::map<int64_t, hcclDataType_t> kConstOpHcclDataType = { | |||||
{ge::DT_FLOAT, HCCL_DATA_TYPE_FLOAT}, | |||||
{ge::DT_FLOAT16, HCCL_DATA_TYPE_HALF}, | |||||
{ge::DT_INT8, HCCL_DATA_TYPE_INT8}, | |||||
{ge::DT_INT32, HCCL_DATA_TYPE_INT}, | |||||
}; | }; | ||||
static std::map<hcclDataType_t, int32_t> kConstOpHcomDataTypeSize = { | |||||
{HCCL_DATA_TYPE_FLOAT, sizeof(float)}, | |||||
{HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, | |||||
{HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, | |||||
{HCCL_DATA_TYPE_INT, sizeof(int32_t)}, | |||||
static std::map<hcclDataType_t, int32_t> kConstOpHcclDataTypeSize = { | |||||
{HCCL_DATA_TYPE_FLOAT, sizeof(float)}, | |||||
{HCCL_DATA_TYPE_HALF, sizeof(float) / 2}, | |||||
{HCCL_DATA_TYPE_INT8, sizeof(int8_t)}, | |||||
{HCCL_DATA_TYPE_INT, sizeof(int32_t)}, | |||||
}; | |||||
static std::map<horovodRedOp_t, hcclRedOp_t> kHorovodRedOpToHcclRedOp = { | |||||
{HOROVOD_REP_OP_SUM, HCCL_REP_OP_SUM}, {HOROVOD_REP_OP_MIN, HCCL_REP_OP_MIN}, | |||||
{HOROVOD_REP_OP_MAX, HCCL_REP_OP_MAX}, {HOROVOD_REP_OP_PROD, HCCL_REP_OP_PROD}, | |||||
{HOROVOD_REP_OP_RESERVED, HCCL_REP_OP_RESERVED}, | |||||
}; | }; | ||||
class HcomOmeUtil { | class HcomOmeUtil { | ||||
public: | public: | ||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief GetHcomDataType | |||||
/// @brief GetHcclDataType | |||||
/// @return SUCCESS | /// @return SUCCESS | ||||
/// @return FAIL | /// @return FAIL | ||||
/// | /// | ||||
static Status GetHcomDataType(const ge::ConstOpDescPtr &op_desc, hcclDataType_t &data_type); | |||||
static Status GetHcclDataType(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief GetHcomTypeSize | |||||
/// @brief GetHcclTypeSize | |||||
/// @return SUCCESS | /// @return SUCCESS | ||||
/// @return FAIL | /// @return FAIL | ||||
/// | /// | ||||
static Status GetHcomTypeSize(hcclDataType_t data_type, int32_t &size); | |||||
static Status GetHcclTypeSize(hcclDataType_t data_type, int32_t &size); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief GetHcomCount | |||||
/// @brief GetHcclCount | |||||
/// @return SUCCESS | /// @return SUCCESS | ||||
/// @return FAIL | /// @return FAIL | ||||
/// | /// | ||||
static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, | |||||
int &count); | |||||
static Status GetHcclCount(const ge::ConstOpDescPtr &op_desc, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief GetHcclOperationType | |||||
/// @return SUCCESS | |||||
/// @return FAIL | |||||
/// | |||||
static Status GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief GetHcclRootId | |||||
/// @return SUCCESS | |||||
/// @return FAIL | |||||
/// | |||||
static Status GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief GetAllRootId | |||||
/// @return SUCCESS | |||||
/// @return FAIL | |||||
/// | |||||
static Status GetAllRootId(const ge::ConstOpDescPtr &op_desc, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief check the op_type whether is hcom operator or not | |||||
/// @return true | |||||
/// @return false | |||||
/// | |||||
static bool IsHCOMOp(const string &op_type); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief check the op_type whether is horovod operator or not | |||||
/// @return true | |||||
/// @return false | |||||
/// | |||||
static bool IsHorovodOp(const string &op_type); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief GetHcclType | |||||
/// @return void | |||||
/// | |||||
static void GetHcclType(const domi::TaskDef &task_def, std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief GetHcomOperationType | |||||
/// @brief CheckKernelHcclInfo | |||||
/// @return SUCCESS | |||||
/// @return FAIL | |||||
/// | |||||
static Status CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief GetHorovodInputs | |||||
/// @return SUCCESS | /// @return SUCCESS | ||||
/// @return FAIL | /// @return FAIL | ||||
/// | /// | ||||
static Status GetHcomOperationType(const ge::ConstOpDescPtr &op_desc, hcclRedOp_t &op_type); | |||||
static Status GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
private: | |||||
/// | |||||
/// @ingroup domi_ome | |||||
/// @brief GetHcomCount | |||||
/// @return SUCCESS | |||||
/// @return FAIL | |||||
/// | |||||
static Status GetHcomCount(const ge::ConstOpDescPtr &op_desc, hcclDataType_t data_type, bool is_allgather, | |||||
int &count); | |||||
/// | /// | ||||
/// @ingroup domi_ome | /// @ingroup domi_ome | ||||
/// @brief GetHcomRootId | |||||
/// @brief GetHorovodCount | |||||
/// @return SUCCESS | /// @return SUCCESS | ||||
/// @return FAIL | /// @return FAIL | ||||
/// | /// | ||||
static Status GetHcomRootId(const ge::ConstOpDescPtr &op_desc, int64_t &root_id); | |||||
static Status GetHorovodCount(const ge::ConstOpDescPtr &op_desc, | |||||
std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos); | |||||
}; | }; | ||||
} // namespace ge | } // namespace ge | ||||
#endif // GE_GRAPH_MANAGER_UTIL_HCOM_UTIL_H_ | #endif // GE_GRAPH_MANAGER_UTIL_HCOM_UTIL_H_ |
@@ -134,7 +134,7 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | |||||
return GE_CLI_GE_NOT_INITIALIZED; | return GE_CLI_GE_NOT_INITIALIZED; | ||||
} | } | ||||
std::map<string, GraphOptimizerPtr> graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); | |||||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||||
GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | ||||
graph_optimizer.size()); | graph_optimizer.size()); | ||||
string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | ||||
@@ -154,6 +154,37 @@ Status GraphOptimize::OptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | |||||
return ret; | return ret; | ||||
} | } | ||||
Status GraphOptimize::OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph) { | |||||
GELOGD("OptimizeOriginalGraphJudgeInsert in"); | |||||
GE_CHECK_NOTNULL(compute_graph); | |||||
Status ret = SUCCESS; | |||||
std::shared_ptr<GELib> instance_ptr = ge::GELib::GetInstance(); | |||||
if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
GELOGE(GE_CLI_GE_NOT_INITIALIZED, "OptimizeOriginalGraph failed."); | |||||
return GE_CLI_GE_NOT_INITIALIZED; | |||||
} | |||||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||||
GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | |||||
graph_optimizer.size()); | |||||
string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | |||||
if (graph_optimizer.size() != 0) { | |||||
for (auto iter = graph_optimizer.begin(); iter != graph_optimizer.end(); ++iter) { | |||||
if (iter->first == exclude_core_Type) { | |||||
GELOGI("[OptimizeOriginalGraphJudgeInsert]: engine type will exclude: %s", exclude_core_Type.c_str()); | |||||
continue; | |||||
} | |||||
GELOGI("Begin to refine running format by engine %s", iter->first.c_str()); | |||||
ret = (iter->second)->OptimizeOriginalGraphJudgeInsert(*compute_graph); | |||||
if (ret != SUCCESS) { | |||||
GELOGE(ret, "[OptimizeOriginalGraphJudgeInsert]: graph optimize failed, ret:%d", ret); | |||||
return ret; | |||||
} | |||||
} | |||||
} | |||||
return ret; | |||||
} | |||||
Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | ||||
GELOGD("NewOptimizeOriginalGraph in"); | GELOGD("NewOptimizeOriginalGraph in"); | ||||
if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
@@ -168,7 +199,7 @@ Status GraphOptimize::NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph) { | |||||
return GE_CLI_GE_NOT_INITIALIZED; | return GE_CLI_GE_NOT_INITIALIZED; | ||||
} | } | ||||
std::map<string, GraphOptimizerPtr> graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); | |||||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||||
GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | GELOGI("optimize by opskernel in original graph optimize phase. num of graph_optimizer is %lu.", | ||||
graph_optimizer.size()); | graph_optimizer.size()); | ||||
string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | string exclude_core_Type = (core_type_ == kVectorCore) ? kAicoreEngine : kVectorEngine; | ||||
@@ -207,7 +238,7 @@ Status GraphOptimize::OptimizeOriginalGraphForQuantize(ComputeGraphPtr &compute_ | |||||
return GE_CLI_GE_NOT_INITIALIZED; | return GE_CLI_GE_NOT_INITIALIZED; | ||||
} | } | ||||
std::map<string, GraphOptimizerPtr> graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjs(); | |||||
auto graph_optimizer = instance_ptr->OpsKernelManagerObj().GetAllGraphOptimizerObjsByPriority(); | |||||
GELOGI("optimize by opskernel in original graph optimize quantize phase. num of graph_optimizer is %zu.", | GELOGI("optimize by opskernel in original graph optimize quantize phase. num of graph_optimizer is %zu.", | ||||
graph_optimizer.size()); | graph_optimizer.size()); | ||||
Status ret = SUCCESS; | Status ret = SUCCESS; | ||||
@@ -47,6 +47,8 @@ class GraphOptimize { | |||||
// original graph optimize | // original graph optimize | ||||
Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | Status OptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | ||||
Status OptimizeOriginalGraphJudgeInsert(ComputeGraphPtr &compute_graph); | |||||
// new original graph optimize | // new original graph optimize | ||||
Status NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | Status NewOptimizeOriginalGraph(ComputeGraphPtr &compute_graph); | ||||
@@ -43,39 +43,44 @@ | |||||
#define REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) | #define REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) | ||||
#define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | ||||
namespace { | |||||
const bool kDebugging = (std::getenv("DEBUG_DYNAMIC_PARTITION") != nullptr); | |||||
} // namespace | |||||
bool IsExperimental() { | |||||
const static bool kIsExperimental = (std::getenv("EXPERIMENTAL_DYNAMIC_PARTITION") != nullptr); | |||||
return kIsExperimental; | |||||
} | |||||
#define DLOG() \ | |||||
if (kDebugging) std::cerr | |||||
namespace ge { | namespace ge { | ||||
using Cluster = DynamicShapePartitioner::Cluster; | using Cluster = DynamicShapePartitioner::Cluster; | ||||
using ClusterPtr = std::shared_ptr<Cluster>; | using ClusterPtr = std::shared_ptr<Cluster>; | ||||
Status DynamicShapePartitioner::Partition() { | Status DynamicShapePartitioner::Partition() { | ||||
REQUIRE_NOT_NULL(root_graph_, "Graph is nullptr."); | REQUIRE_NOT_NULL(root_graph_, "Graph is nullptr."); | ||||
DLOG() << "Start dynamic shape partition graph " << root_graph_->GetName() << std::endl; | |||||
REQUIRE_SUCCESS(MarkUnknowShapeNodes(), "Failed mark unknow shape nodes."); | |||||
if (!IsExperimental()) { | |||||
GELOGD("Skip dynamic shape partition as not in experimental mode."); | |||||
REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), | |||||
"Failed set dynamic shape partitioned flag on root graph."); | |||||
return SUCCESS; | |||||
} | |||||
GELOGD("Start dynamic shape partition graph %s.", root_graph_->GetName().c_str()); | |||||
REQUIRE_SUCCESS(MarkUnknownShapeNodes(), "Failed mark unknown shape nodes."); | |||||
if (unknown_shape_nodes_.empty()) { | if (unknown_shape_nodes_.empty()) { | ||||
DLOG() << "Skip dynamic shape partition of graph " << root_graph_->GetName() << " as all nodes are known shape." | |||||
<< std::endl; | |||||
REQUIRE(AttrUtils::SetBool(*root_graph_, "_dynamic_shape_partitioned", false), | |||||
GELOGD("Skip dynamic shape partition of graph %s as all nodes are known shape.", root_graph_->GetName().c_str()); | |||||
REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, false), | |||||
"Failed set dynamic shape partitioned flag on root graph."); | "Failed set dynamic shape partitioned flag on root graph."); | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
REQUIRE(AttrUtils::SetBool(*root_graph_, "_dynamic_shape_partitioned", true), | |||||
REQUIRE(AttrUtils::SetBool(*root_graph_, ATTR_NAME_DYNAMIC_SHAPE_PARTITIONED, true), | |||||
"Failed set dynamic shape partitioned flag on root graph."); | "Failed set dynamic shape partitioned flag on root graph."); | ||||
DumpGraph("_Before_DSP"); | DumpGraph("_Before_DSP"); | ||||
auto status = PartitionImpl(); | auto status = PartitionImpl(); | ||||
DLOG() << DebugString() << std::endl; | |||||
GELOGD("%s.", DebugString().c_str()); | |||||
if (status != SUCCESS) { | if (status != SUCCESS) { | ||||
GELOGE(status, "Failed dynamic shape partition graph: %s, status:\n %s", root_graph_->GetName().c_str(), | GELOGE(status, "Failed dynamic shape partition graph: %s, status:\n %s", root_graph_->GetName().c_str(), | ||||
DebugString().c_str()); | DebugString().c_str()); | ||||
} | } | ||||
DumpGraph("_After_DSP"); | DumpGraph("_After_DSP"); | ||||
DLOG() << (status == SUCCESS ? "Succeed" : "Failed") << " dynamic shape partition graph " << root_graph_->GetName() | |||||
<< std::endl; | |||||
GELOGD("Finish dynamic shape partition graph %s.", root_graph_->GetName().c_str()); | |||||
ClearResource(); | ClearResource(); | ||||
return status; | return status; | ||||
} | } | ||||
@@ -122,36 +127,36 @@ Status DynamicShapePartitioner::BuildPartitionSubgraph() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
std::string DynamicShapePartitioner::DebugString() { | |||||
size_t unknow = 0; | |||||
size_t know = 0; | |||||
std::string DynamicShapePartitioner::DebugString() const { | |||||
size_t unknown = 0; | |||||
size_t known = 0; | |||||
size_t data = 0; | size_t data = 0; | ||||
size_t netoutput = 0; | size_t netoutput = 0; | ||||
std::stringstream ss; | std::stringstream ss; | ||||
ss << "All unknow shape nodes:" << std::endl; | |||||
ss << "All unknown shape nodes:" << std::endl; | |||||
for (auto node : unknown_shape_nodes_) { | for (auto node : unknown_shape_nodes_) { | ||||
ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; | ss << " [" << node->GetName() << "](" << node->GetType() << ")" << std::endl; | ||||
} | } | ||||
for (auto cluster : unique_clusters_) { | for (auto cluster : unique_clusters_) { | ||||
if (cluster->IsUnknowShape()) { | |||||
unknow++; | |||||
} else if (cluster->IsKnowShape()) { | |||||
know++; | |||||
if (cluster->IsUnknownShape()) { | |||||
unknown++; | |||||
} else if (cluster->IsKnownShape()) { | |||||
known++; | |||||
} else if (cluster->IsData()) { | } else if (cluster->IsData()) { | ||||
data++; | data++; | ||||
} else if (cluster->IsNetOutput()) { | } else if (cluster->IsNetOutput()) { | ||||
netoutput++; | netoutput++; | ||||
} | } | ||||
} | } | ||||
ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", know:" << know << ", unknow:" << unknow | |||||
<< ", netoutput:" << netoutput << std::endl; | |||||
ss << "All clusters:" << unique_clusters_.size() << ", data:" << data << ", known:" << known | |||||
<< ", unknown:" << unknown << ", netoutput:" << netoutput << std::endl; | |||||
for (auto cluster : unique_clusters_) { | for (auto cluster : unique_clusters_) { | ||||
ss << " " << cluster->DebugString() << std::endl; | ss << " " << cluster->DebugString() << std::endl; | ||||
} | } | ||||
return ss.str(); | return ss.str(); | ||||
} | } | ||||
void DynamicShapePartitioner::DumpGraph(std::string suffix) { | |||||
void DynamicShapePartitioner::DumpGraph(const std::string &suffix) { | |||||
GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); | GraphUtils::DumpGEGraphToOnnx(*root_graph_, root_graph_->GetName() + suffix); | ||||
for (auto sub_graph : root_graph_->GetAllSubgraphs()) { | for (auto sub_graph : root_graph_->GetAllSubgraphs()) { | ||||
GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); | GraphUtils::DumpGEGraphToOnnx(*sub_graph, sub_graph->GetName() + suffix); | ||||
@@ -169,10 +174,10 @@ void DynamicShapePartitioner::ClearResource() { | |||||
root_graph_.reset(); | root_graph_.reset(); | ||||
} | } | ||||
Status DynamicShapePartitioner::MarkUnknowShapeNodes() { | |||||
Status DynamicShapePartitioner::MarkUnknownShapeNodes() { | |||||
auto graph = root_graph_; | auto graph = root_graph_; | ||||
for (auto &node : graph->GetDirectNode()) { | for (auto &node : graph->GetDirectNode()) { | ||||
REQUIRE_SUCCESS(CollectSpreadUnknowShapeNodes(node), "Failed collect spread unknow shape nodes %s.", | |||||
REQUIRE_SUCCESS(CollectSpreadUnknownShapeNodes(node), "Failed collect spread unknown shape nodes %s.", | |||||
node->GetName().c_str()); | node->GetName().c_str()); | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -188,14 +193,14 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
} else if (node->GetType() == NETOUTPUT) { | } else if (node->GetType() == NETOUTPUT) { | ||||
type = Cluster::NETOUTPUT; | type = Cluster::NETOUTPUT; | ||||
} else if (unknown_shape_nodes_.count(node) > 0) { | } else if (unknown_shape_nodes_.count(node) > 0) { | ||||
type = Cluster::UNKNOW_SHAPE; | |||||
type = Cluster::UNKNOWN_SHAPE; | |||||
} else { | } else { | ||||
type = Cluster::KNOW_SHAPE; | |||||
type = Cluster::KNOWN_SHAPE; | |||||
} | } | ||||
auto cluster = MakeShared<Cluster>(rank++, type, node, this); | auto cluster = MakeShared<Cluster>(rank++, type, node, this); | ||||
REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | ||||
node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
if (cluster->IsUnknowShape()) { | |||||
if (cluster->IsUnknownShape()) { | |||||
ordered_cluster_.push_back(cluster); | ordered_cluster_.push_back(cluster); | ||||
} | } | ||||
// Already sorted topologically, so access to the parent cluster is safe | // Already sorted topologically, so access to the parent cluster is safe | ||||
@@ -203,18 +208,15 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
cluster->AddInput(node_2_cluster_[parent]); | cluster->AddInput(node_2_cluster_[parent]); | ||||
} | } | ||||
} | } | ||||
if (kDebugging) { | |||||
for (const auto node : graph->GetDirectNode()) { | |||||
DLOG() << "Make cluster for node :" << node->GetName() << ":" << node_2_cluster_[node]->DebugString() | |||||
<< std::endl; | |||||
} | |||||
for (const auto node : graph->GetDirectNode()) { | |||||
GELOGD("Make cluster for node %s : %s.", node->GetName().c_str(), node_2_cluster_[node]->DebugString().c_str()); | |||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DynamicShapePartitioner::TopologicalSortClusters() { | Status DynamicShapePartitioner::TopologicalSortClusters() { | ||||
ordered_cluster_.clear(); | ordered_cluster_.clear(); | ||||
// BFS topological sort clusters for know shape cluster | |||||
// BFS topological sort clusters for known shape cluster | |||||
std::queue<ClusterPtr> ready_clusters; | std::queue<ClusterPtr> ready_clusters; | ||||
std::unordered_map<ClusterPtr, size_t> cluster_pending_count; | std::unordered_map<ClusterPtr, size_t> cluster_pending_count; | ||||
std::unordered_set<ClusterPtr> seen_clusters; | std::unordered_set<ClusterPtr> seen_clusters; | ||||
@@ -231,16 +233,17 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { | |||||
cluster_pending_count[cluster] = pending_count; | cluster_pending_count[cluster] = pending_count; | ||||
} | } | ||||
} | } | ||||
size_t rank = 0; | size_t rank = 0; | ||||
while (!ready_clusters.empty()) { | while (!ready_clusters.empty()) { | ||||
auto cluster = ready_clusters.front(); | auto cluster = ready_clusters.front(); | ||||
ready_clusters.pop(); | ready_clusters.pop(); | ||||
cluster->UpdateRank(rank++); | cluster->UpdateRank(rank++); | ||||
if (cluster->IsKnowShape()) { | |||||
if (cluster->IsKnownShape()) { | |||||
ordered_cluster_.push_back(cluster); | ordered_cluster_.push_back(cluster); | ||||
} | } | ||||
for (auto out_cluster : cluster->Outputs()) { | for (auto out_cluster : cluster->Outputs()) { | ||||
if (--cluster_pending_count[out_cluster] == 0) { | |||||
if (cluster_pending_count[out_cluster] > 0 && --cluster_pending_count[out_cluster] == 0) { | |||||
ready_clusters.push(out_cluster); | ready_clusters.push(out_cluster); | ||||
} | } | ||||
} | } | ||||
@@ -252,49 +255,58 @@ Status DynamicShapePartitioner::TopologicalSortClusters() { | |||||
} | } | ||||
namespace { | namespace { | ||||
template <typename T> | |||||
static std::string ToString(T vec) { | |||||
if (vec.empty()) { | |||||
static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||||
if (clusters.empty()) { | |||||
return "()"; | return "()"; | ||||
} | } | ||||
std::stringstream ss; | std::stringstream ss; | ||||
ss << "("; | ss << "("; | ||||
auto iter = vec.begin(); | |||||
for (size_t i = 0; i < vec.size() - 1; i++) { | |||||
ss << (*iter++)->Id() << ","; | |||||
auto iter = clusters.begin(); | |||||
for (size_t i = 0; i < clusters.size() - 1; i++) { | |||||
ss << (*iter)->Id() << ","; | |||||
iter++; | |||||
} | } | ||||
ss << (*iter++)->Id() << ")."; | |||||
ss << (*iter)->Id() << ")."; | |||||
return ss.str(); | return ss.str(); | ||||
} | } | ||||
} // namespace | } // namespace | ||||
Status DynamicShapePartitioner::MergeClusters() { | Status DynamicShapePartitioner::MergeClusters() { | ||||
// Merge unknow shape clusters | |||||
// Merge unknown shape clusters | |||||
for (auto cluster : ordered_cluster_) { | for (auto cluster : ordered_cluster_) { | ||||
for (auto in_cluster : cluster->Inputs()) { | for (auto in_cluster : cluster->Inputs()) { | ||||
if (in_cluster->IsUnknowShape()) { | |||||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | |||||
DLOG() << "Merge all path cluster from " << in_cluster->Id() << " to " << cluster->Id() | |||||
<< ToString(merged_clusters) << std::endl; | |||||
for (auto merged_cluster : merged_clusters) { | |||||
for (auto node : merged_cluster->Nodes()) { | |||||
node_2_cluster_[node] = cluster; | |||||
} | |||||
if (!in_cluster->IsUnknownShape()) { | |||||
continue; | |||||
} | |||||
auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | |||||
GELOGD("Merge all path cluster from %lu to %lu %s.", in_cluster->Id(), cluster->Id(), | |||||
ToString(merged_clusters).c_str()); | |||||
for (auto merged_cluster : merged_clusters) { | |||||
for (auto node : merged_cluster->Nodes()) { | |||||
node_2_cluster_[node] = cluster; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
} | } | ||||
REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknow shape clusters."); | |||||
// Merge know shape clusters | |||||
REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); | |||||
// Merge known shape clusters | |||||
for (auto cluster : ordered_cluster_) { | for (auto cluster : ordered_cluster_) { | ||||
if (cluster->IsRefVariable() && cluster->Inputs().size() == 1) { | |||||
auto in_cluster = *(cluster->Inputs().begin()); | |||||
in_cluster->Merge(cluster); | |||||
node_2_cluster_[*(cluster->Nodes().begin())] = in_cluster; | |||||
continue; | |||||
} | |||||
for (auto in_cluster : cluster->Inputs()) { | for (auto in_cluster : cluster->Inputs()) { | ||||
if (in_cluster->IsKnowShape()) { | |||||
if (cluster->TryMerge(in_cluster)) { | |||||
DLOG() << "Success merge known shape cluster " << in_cluster->Id() << " to " << cluster->Id() << "." | |||||
<< std::endl; | |||||
for (auto node : in_cluster->Nodes()) { | |||||
node_2_cluster_[node] = cluster; | |||||
} | |||||
if (!in_cluster->IsKnownShape()) { | |||||
continue; | |||||
} | |||||
if (cluster->TryMerge(in_cluster)) { | |||||
GELOGD("Success merge known shape cluster from %lu to %lu.", in_cluster->Id(), cluster->Id()); | |||||
for (auto node : in_cluster->Nodes()) { | |||||
node_2_cluster_[node] = cluster; | |||||
} | } | ||||
} | } | ||||
} | } | ||||
@@ -302,23 +314,30 @@ Status DynamicShapePartitioner::MergeClusters() { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { | |||||
Status DynamicShapePartitioner::CollectSpreadUnknownShapeNodes(NodePtr node) { | |||||
if (unknown_shape_nodes_.count(node) > 0) { | if (unknown_shape_nodes_.count(node) > 0) { | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
auto opdesc = node->GetOpDesc(); | auto opdesc = node->GetOpDesc(); | ||||
// One can set 'ATTR_NAME_IS_UNKNOWN_SHAPE=true' on node so as to forcing the node flow into the unknown subgraph, | |||||
// ignore the actual shape. | |||||
bool is_forced_unknown = false; | |||||
if (AttrUtils::GetBool(opdesc, ATTR_NAME_IS_UNKNOWN_SHAPE, is_forced_unknown) && is_forced_unknown) { | |||||
GELOGD("Collect node %s as unknown as it was marked unknown forcibly.", node->GetName().c_str()); | |||||
unknown_shape_nodes_.insert(node); | |||||
return SUCCESS; | |||||
} | |||||
size_t anchor_index = 0; | size_t anchor_index = 0; | ||||
bool is_unknow = false; | |||||
bool is_unknown = false; | |||||
for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { | for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { | ||||
if (IsUnknowShapeTensor(out_tensor)) { | |||||
DLOG() << "Collect node " << node->GetName() << " as unknown as output " << anchor_index << " is unknown" | |||||
<< std::endl; | |||||
is_unknow = true; | |||||
if (IsUnknownShapeTensor(out_tensor)) { | |||||
GELOGD("Collect node %s as unknown as output %lu is unknown.", node->GetName().c_str(), anchor_index); | |||||
is_unknown = true; | |||||
auto anchor = node->GetOutDataAnchor(anchor_index); | auto anchor = node->GetOutDataAnchor(anchor_index); | ||||
for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { | for (const auto peer_anchor : anchor->GetPeerInDataAnchors()) { | ||||
if (peer_anchor != nullptr) { | if (peer_anchor != nullptr) { | ||||
DLOG() << "Collect node " << peer_anchor->GetOwnerNode()->GetName() << " as has unknown input from " | |||||
<< node->GetName() << ":" << anchor_index << std::endl; | |||||
GELOGD("Collect node %s as has unknown input from %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), | |||||
node->GetName().c_str(), anchor_index); | |||||
unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); | unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); | ||||
} | } | ||||
} | } | ||||
@@ -327,21 +346,20 @@ Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { | |||||
} | } | ||||
anchor_index = 0; | anchor_index = 0; | ||||
for (auto &in_tensor : opdesc->GetAllInputsDesc()) { | for (auto &in_tensor : opdesc->GetAllInputsDesc()) { | ||||
if (IsUnknowShapeTensor(in_tensor)) { | |||||
DLOG() << "Collect node " << node->GetName() << " as unknown as input " << anchor_index << " is unknown" | |||||
<< std::endl; | |||||
is_unknow = true; | |||||
if (IsUnknownShapeTensor(in_tensor)) { | |||||
GELOGD("Collect node %s as unknown as input %lu is unknown.", node->GetName().c_str(), anchor_index); | |||||
is_unknown = true; | |||||
auto anchor = node->GetInDataAnchor(anchor_index); | auto anchor = node->GetInDataAnchor(anchor_index); | ||||
const auto peer_anchor = anchor->GetPeerOutAnchor(); | const auto peer_anchor = anchor->GetPeerOutAnchor(); | ||||
if (peer_anchor != nullptr) { | if (peer_anchor != nullptr) { | ||||
DLOG() << "Collect node " << peer_anchor->GetOwnerNode()->GetName() << " as has unknown output to " | |||||
<< node->GetName() << ":" << anchor_index << std::endl; | |||||
GELOGD("Collect node %s as has unknown output to %s:%lu.", peer_anchor->GetOwnerNode()->GetName().c_str(), | |||||
node->GetName().c_str(), anchor_index); | |||||
unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); | unknown_shape_nodes_.insert(peer_anchor->GetOwnerNode()); | ||||
} | } | ||||
} | } | ||||
anchor_index++; | anchor_index++; | ||||
} | } | ||||
if (is_unknow) { | |||||
if (is_unknown) { | |||||
unknown_shape_nodes_.insert(node); | unknown_shape_nodes_.insert(node); | ||||
} else { | } else { | ||||
auto graph = root_graph_; | auto graph = root_graph_; | ||||
@@ -350,11 +368,10 @@ Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { | |||||
REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), | REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), | ||||
node->GetName().c_str()); | node->GetName().c_str()); | ||||
bool is_graph_unknow = false; | bool is_graph_unknow = false; | ||||
REQUIRE_SUCCESS(IsUnknowShapeGraph(subgraph, is_graph_unknow), "Failed check subgraph %s shape of node %s.", | |||||
REQUIRE_SUCCESS(IsUnknownShapeGraph(subgraph, is_graph_unknow), "Failed check subgraph %s shape of node %s.", | |||||
subgraph_name.c_str(), node->GetName().c_str()); | subgraph_name.c_str(), node->GetName().c_str()); | ||||
if (is_graph_unknow) { | if (is_graph_unknow) { | ||||
DLOG() << "Collect node " << node->GetName() << " as its subgraph " << subgraph->GetName() << " is unknown." | |||||
<< std::endl; | |||||
GELOGD("Collect node %s as its subgraph %s is unknown.", node->GetName().c_str(), subgraph->GetName().c_str()); | |||||
unknown_shape_nodes_.insert(node); | unknown_shape_nodes_.insert(node); | ||||
break; | break; | ||||
} | } | ||||
@@ -363,20 +380,20 @@ Status DynamicShapePartitioner::CollectSpreadUnknowShapeNodes(NodePtr node) { | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DynamicShapePartitioner::IsUnknowShapeNode(NodePtr node, bool &is_unknow) { | |||||
Status DynamicShapePartitioner::IsUnknownShapeNode(NodePtr node, bool &is_unknown) { | |||||
auto opdesc = node->GetOpDesc(); | auto opdesc = node->GetOpDesc(); | ||||
auto graph = root_graph_; | auto graph = root_graph_; | ||||
for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { | for (auto &out_tensor : opdesc->GetAllOutputsDesc()) { | ||||
if (IsUnknowShapeTensor(out_tensor)) { | |||||
DLOG() << "Mark node " << node->GetName() << " unknown because unknown output " << std::endl; | |||||
is_unknow = true; | |||||
if (IsUnknownShapeTensor(out_tensor)) { | |||||
GELOGD("Mark node %s unknown as unknown output.", node->GetName().c_str()); | |||||
is_unknown = true; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} | } | ||||
for (auto &in_tensor : opdesc->GetAllInputsDesc()) { | for (auto &in_tensor : opdesc->GetAllInputsDesc()) { | ||||
if (IsUnknowShapeTensor(in_tensor)) { | |||||
DLOG() << "Mark node " << node->GetName() << " unknown because unknown intput " << std::endl; | |||||
is_unknow = true; | |||||
if (IsUnknownShapeTensor(in_tensor)) { | |||||
GELOGD("Mark node %s unknown as unknown intput.", node->GetName().c_str()); | |||||
is_unknown = true; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} | } | ||||
@@ -384,30 +401,30 @@ Status DynamicShapePartitioner::IsUnknowShapeNode(NodePtr node, bool &is_unknow) | |||||
auto subgraph = graph->GetSubgraph(subgraph_name); | auto subgraph = graph->GetSubgraph(subgraph_name); | ||||
REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), | REQUIRE_NOT_NULL(subgraph, "Failed get subgraph %s of node %s on root graph.", subgraph_name.c_str(), | ||||
node->GetName().c_str()); | node->GetName().c_str()); | ||||
REQUIRE_SUCCESS(IsUnknowShapeGraph(subgraph, is_unknow), "Failed check subgraph %s shape of node %s.", | |||||
REQUIRE_SUCCESS(IsUnknownShapeGraph(subgraph, is_unknown), "Failed check subgraph %s shape of node %s.", | |||||
subgraph_name.c_str(), node->GetName().c_str()); | subgraph_name.c_str(), node->GetName().c_str()); | ||||
if (is_unknow) { | |||||
DLOG() << "Mark node " << node->GetName() << " unknown because unknown subgraph " << std::endl; | |||||
if (is_unknown) { | |||||
GELOGD("Mark node %s unknown as unknown subgraph.", node->GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} | } | ||||
is_unknow = false; | |||||
is_unknown = false; | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
Status DynamicShapePartitioner::IsUnknowShapeGraph(ComputeGraphPtr graph, bool &is_unknow) { | |||||
Status DynamicShapePartitioner::IsUnknownShapeGraph(ComputeGraphPtr graph, bool &is_unknown) { | |||||
for (auto &node : graph->GetDirectNode()) { | for (auto &node : graph->GetDirectNode()) { | ||||
REQUIRE_SUCCESS(IsUnknowShapeNode(node, is_unknow), "Failed check node %s shape on graph %s.", | |||||
REQUIRE_SUCCESS(IsUnknownShapeNode(node, is_unknown), "Failed check node %s shape on graph %s.", | |||||
node->GetName().c_str(), graph->GetName().c_str()); | node->GetName().c_str(), graph->GetName().c_str()); | ||||
if (is_unknow) { | |||||
DLOG() << "Mark graph " << graph->GetName() << " unknown because unknown node " << node->GetName() << std::endl; | |||||
if (is_unknown) { | |||||
GELOGD("Mark graph %s unknown as contains unknown node %s.", graph->GetName().c_str(), node->GetName().c_str()); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
} | } | ||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
bool DynamicShapePartitioner::IsUnknowShapeTensor(GeTensorDesc &tensor) { | |||||
bool DynamicShapePartitioner::IsUnknownShapeTensor(const GeTensorDesc &tensor) { | |||||
const static int kUnknowShape = -1; | const static int kUnknowShape = -1; | ||||
const static int kUnknowRank = -2; | const static int kUnknowRank = -2; | ||||
for (auto dim_size : tensor.GetShape().GetDims()) { | for (auto dim_size : tensor.GetShape().GetDims()) { | ||||
@@ -418,7 +435,7 @@ bool DynamicShapePartitioner::IsUnknowShapeTensor(GeTensorDesc &tensor) { | |||||
return false; | return false; | ||||
} | } | ||||
std::string Cluster::DebugString() { | |||||
std::string Cluster::DebugString() const { | |||||
std::stringstream ss; | std::stringstream ss; | ||||
switch (type_) { | switch (type_) { | ||||
case DATA: | case DATA: | ||||
@@ -427,10 +444,10 @@ std::string Cluster::DebugString() { | |||||
case NETOUTPUT: | case NETOUTPUT: | ||||
ss << "NETOUTPUT"; | ss << "NETOUTPUT"; | ||||
break; | break; | ||||
case UNKNOW_SHAPE: | |||||
case UNKNOWN_SHAPE: | |||||
ss << "UNKNOW"; | ss << "UNKNOW"; | ||||
break; | break; | ||||
case KNOW_SHAPE: | |||||
case KNOWN_SHAPE: | |||||
ss << "KNOW"; | ss << "KNOW"; | ||||
break; | break; | ||||
} | } | ||||
@@ -450,18 +467,22 @@ std::string Cluster::DebugString() { | |||||
return ss.str(); | return ss.str(); | ||||
} | } | ||||
size_t Cluster::Id() { return id_; } | |||||
size_t Cluster::Id() const { return id_; } | |||||
void Cluster::UpdateRank(size_t rank) { | void Cluster::UpdateRank(size_t rank) { | ||||
max_ = rank; | max_ = rank; | ||||
min_ = rank; | min_ = rank; | ||||
}; | }; | ||||
bool Cluster::IsData() { return type_ == DATA; }; | |||||
bool Cluster::IsKnowShape() { return type_ == KNOW_SHAPE; }; | |||||
bool Cluster::IsUnknowShape() { return type_ == UNKNOW_SHAPE; }; | |||||
bool Cluster::IsNetOutput() { return type_ == NETOUTPUT; }; | |||||
bool Cluster::IsolatedConstant() { | |||||
return ((nodes_.size() == 1) && (nodes_[0]->GetType() == CONSTANTOP) && (out_clusters_.size() == 1) && | |||||
(*out_clusters_.begin())->IsUnknowShape() && in_clusters_.empty()); | |||||
bool Cluster::IsData() const { return type_ == DATA; }; | |||||
bool Cluster::IsKnownShape() const { return type_ == KNOWN_SHAPE; }; | |||||
bool Cluster::IsUnknownShape() const { return type_ == UNKNOWN_SHAPE; }; | |||||
bool Cluster::IsNetOutput() const { return type_ == NETOUTPUT; }; | |||||
bool Cluster::IsRefVariable() const { | |||||
if ((nodes_.size() == 1) && ((nodes_[0]->GetType() == VARIABLE) || (nodes_[0]->GetType() == VARIABLEV2))) { | |||||
std::string ref_variable_name; | |||||
return (AttrUtils::GetStr(nodes_[0]->GetOpDesc(), REF_VAR_SRC_VAR_NAME, ref_variable_name) && | |||||
!ref_variable_name.empty()); | |||||
} | |||||
return false; | |||||
} | } | ||||
void Cluster::AddInput(ClusterPtr in) { | void Cluster::AddInput(ClusterPtr in) { | ||||
in_clusters_.insert(in); | in_clusters_.insert(in); | ||||
@@ -562,9 +583,9 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) { | |||||
} | } | ||||
return path_clusters; | return path_clusters; | ||||
} | } | ||||
std::unordered_set<ClusterPtr> Cluster::Inputs() { return in_clusters_; }; | |||||
std::unordered_set<ClusterPtr> Cluster::Outputs() { return out_clusters_; }; | |||||
std::vector<NodePtr> Cluster::Nodes() { return nodes_; }; | |||||
std::unordered_set<ClusterPtr> Cluster::Inputs() const { return in_clusters_; }; | |||||
std::unordered_set<ClusterPtr> Cluster::Outputs() const { return out_clusters_; }; | |||||
std::vector<NodePtr> Cluster::Nodes() const { return nodes_; }; | |||||
void Cluster::AddFrameInput(InDataAnchorPtr anchor) { | void Cluster::AddFrameInput(InDataAnchorPtr anchor) { | ||||
inputs_index_[anchor] = inputs_.size(); | inputs_index_[anchor] = inputs_.size(); | ||||
@@ -589,7 +610,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> | |||||
OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | ||||
Status Cluster::BuildFrame() { | Status Cluster::BuildFrame() { | ||||
if (IsUnknowShape() || IsKnowShape()) { | |||||
if (IsUnknownShape() || IsKnownShape()) { | |||||
return BuildPartitionFrame(); | return BuildPartitionFrame(); | ||||
} else { | } else { | ||||
auto node = nodes_.front(); | auto node = nodes_.front(); | ||||
@@ -621,7 +642,7 @@ Status Cluster::BuildFrame() { | |||||
Status Cluster::BuildPartitionFrame() { | Status Cluster::BuildPartitionFrame() { | ||||
auto graph = partitioner_->root_graph_; | auto graph = partitioner_->root_graph_; | ||||
bool is_unknown_shape = IsUnknowShape(); | |||||
bool is_unknown_shape = IsUnknownShape(); | |||||
std::string sub_graph_name = | std::string sub_graph_name = | ||||
graph->GetName() + "_sub_" + std::to_string(unique_id_) + (is_unknown_shape ? "_unknow" : "_know"); | graph->GetName() + "_sub_" + std::to_string(unique_id_) + (is_unknown_shape ? "_unknow" : "_know"); | ||||
subgraph_ = MakeShared<ComputeGraph>(sub_graph_name); | subgraph_ = MakeShared<ComputeGraph>(sub_graph_name); | ||||
@@ -727,6 +748,7 @@ Status Cluster::BuildPartitionSubgraph() { | |||||
auto data_op = MakeShared<OpDesc>(std::string("Data_") + std::to_string(parent_node_index), ge::DATA); | auto data_op = MakeShared<OpDesc>(std::string("Data_") + std::to_string(parent_node_index), ge::DATA); | ||||
REQUIRE_NOT_NULL(data_op, "Failed new memory for data op."); | REQUIRE_NOT_NULL(data_op, "Failed new memory for data op."); | ||||
auto input_desc = anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(anchor->GetIdx()); | auto input_desc = anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(anchor->GetIdx()); | ||||
REQUIRE_GRAPH_SUCCESS(data_op->AddInputDesc(input_desc), "Failed add input desc."); | |||||
REQUIRE_GRAPH_SUCCESS(data_op->AddOutputDesc(input_desc), "Failed add output desc."); | REQUIRE_GRAPH_SUCCESS(data_op->AddOutputDesc(input_desc), "Failed add output desc."); | ||||
REQUIRE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), | REQUIRE(AttrUtils::SetInt(data_op, ATTR_NAME_PARENT_NODE_INDEX, parent_node_index), | ||||
"Failed set parent_node_index on subgraph data node."); | "Failed set parent_node_index on subgraph data node."); | ||||
@@ -29,27 +29,27 @@ class DynamicShapePartitioner { | |||||
public: | public: | ||||
// An cluster means set of nodes that can be merged in same partition, | // An cluster means set of nodes that can be merged in same partition, | ||||
// Corresponding relationship between cluster type and node: | // Corresponding relationship between cluster type and node: | ||||
// DATA:DATA, UNKNOW_SHAPE:unknowshape, KNOW_SHAPE:knowshape, NETOUTPUT:NETOUTPUT. | |||||
// DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT. | |||||
class Cluster : public std::enable_shared_from_this<Cluster> { | class Cluster : public std::enable_shared_from_this<Cluster> { | ||||
public: | public: | ||||
enum Type { DATA, NETOUTPUT, KNOW_SHAPE, UNKNOW_SHAPE }; | |||||
explicit Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner) | |||||
enum Type { DATA, NETOUTPUT, KNOWN_SHAPE, UNKNOWN_SHAPE }; | |||||
Cluster(size_t rank, Type type, NodePtr node, DynamicShapePartitioner *partitioner) | |||||
: id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) { | : id_(rank), min_(rank), max_(rank), type_(type), partitioner_(partitioner) { | ||||
nodes_.push_back(node); | nodes_.push_back(node); | ||||
} | } | ||||
~Cluster() = default; | ~Cluster() = default; | ||||
std::string DebugString(); | |||||
std::string DebugString() const; | |||||
// Basic bean functions | // Basic bean functions | ||||
size_t Id(); | |||||
size_t Id() const; | |||||
void UpdateRank(size_t rank); | void UpdateRank(size_t rank); | ||||
bool IsData(); | |||||
bool IsKnowShape(); | |||||
bool IsUnknowShape(); | |||||
bool IsNetOutput(); | |||||
std::unordered_set<std::shared_ptr<Cluster>> Inputs(); | |||||
std::unordered_set<std::shared_ptr<Cluster>> Outputs(); | |||||
std::vector<NodePtr> Nodes(); | |||||
bool IsolatedConstant(); | |||||
bool IsData() const; | |||||
bool IsKnownShape() const; | |||||
bool IsUnknownShape() const; | |||||
bool IsNetOutput() const; | |||||
std::unordered_set<std::shared_ptr<Cluster>> Inputs() const; | |||||
std::unordered_set<std::shared_ptr<Cluster>> Outputs() const; | |||||
std::vector<NodePtr> Nodes() const; | |||||
bool IsRefVariable() const; | |||||
// Cluster modify functions | // Cluster modify functions | ||||
void AddInput(std::shared_ptr<Cluster> in); | void AddInput(std::shared_ptr<Cluster> in); | ||||
void RemoveInput(std::shared_ptr<Cluster> in); | void RemoveInput(std::shared_ptr<Cluster> in); | ||||
@@ -110,16 +110,16 @@ class DynamicShapePartitioner { | |||||
// Collect nodes that satisfy the unknowshape rules: | // Collect nodes that satisfy the unknowshape rules: | ||||
// 1) The Tensor shape of any input or output is unknow shape(dim_size = -1) or unknow rank(dim_size=-2) | // 1) The Tensor shape of any input or output is unknow shape(dim_size = -1) or unknow rank(dim_size=-2) | ||||
// 2) Subgraphs of the node has an operator that satisfies rule 1) | // 2) Subgraphs of the node has an operator that satisfies rule 1) | ||||
Status MarkUnknowShapeNodes(); | |||||
Status MarkUnknownShapeNodes(); | |||||
// For each node a Cluster structure, and connected according to the connection relationship of the nodes | // For each node a Cluster structure, and connected according to the connection relationship of the nodes | ||||
// An cluster means set of nodes that can be merged in same partition, | // An cluster means set of nodes that can be merged in same partition, | ||||
// Corresponding relationship between cluster type and node: | // Corresponding relationship between cluster type and node: | ||||
// DATA:DATA, UNKNOW_SHAPE:unknowshape, KNOW_SHAPE:knowshape, NETOUTPUT:NETOUTPUT | |||||
// DATA:DATA, UNKNOWN_SHAPE:unknowshape, KNOWN_SHAPE:knowshape, NETOUTPUT:NETOUTPUT | |||||
Status InitClusters(); | Status InitClusters(); | ||||
// Merge clusters according to the following rules: | // Merge clusters according to the following rules: | ||||
// 1) Iterate through the UNKNOW_SHAPE clusters, if the input is UNKNOW_SHAPE, | |||||
// 1) Iterate through the UNKNOWN_SHAPE clusters, if the input is UNKNOWN_SHAPE, | |||||
// merge all the clusters in the path(s) between the two clusters | // merge all the clusters in the path(s) between the two clusters | ||||
// 2) Iterate through the KNOW_SHAPE clusters, if the input is KNOW_SHAPE, and | |||||
// 2) Iterate through the KNOWN_SHAPE clusters, if the input is KNOWN_SHAPE, and | |||||
// and there's only one path between the two clusters , merge the two clusters | // and there's only one path between the two clusters , merge the two clusters | ||||
Status MergeClusters(); | Status MergeClusters(); | ||||
// Topological sort clusters after merge unknow shape clusters. | // Topological sort clusters after merge unknow shape clusters. | ||||
@@ -135,18 +135,18 @@ class DynamicShapePartitioner { | |||||
// Clear resource and break circular dependency | // Clear resource and break circular dependency | ||||
void ClearResource(); | void ClearResource(); | ||||
// Debug functions | // Debug functions | ||||
void DumpGraph(std::string suffix); | |||||
std::string DebugString(); | |||||
void DumpGraph(const std::string &suffix); | |||||
std::string DebugString() const; | |||||
// Util functions | // Util functions | ||||
Status CollectSpreadUnknowShapeNodes(NodePtr node); | |||||
Status IsUnknowShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow); | |||||
Status IsUnknowShapeNode(ge::NodePtr node, bool &is_unknow); | |||||
bool IsUnknowShapeTensor(ge::GeTensorDesc &tensor); | |||||
Status CollectSpreadUnknownShapeNodes(NodePtr node); | |||||
Status IsUnknownShapeGraph(ge::ComputeGraphPtr graph, bool &is_unknow); | |||||
Status IsUnknownShapeNode(ge::NodePtr node, bool &is_unknow); | |||||
bool IsUnknownShapeTensor(const ge::GeTensorDesc &tensor); | |||||
ge::ComputeGraphPtr root_graph_; // The original graph to partition | ge::ComputeGraphPtr root_graph_; // The original graph to partition | ||||
std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | ||||
// topological sorted clusters, this field will change with the splitting. | // topological sorted clusters, this field will change with the splitting. | ||||
// When partitioning UNKNOW_SHAPE cluster, it is a collection of all topological sorted UNKNOW_SHAPE clusters | |||||
// When partitioning KNOW_SHAPE cluster, it is a collection of all topological sorted KNOW_SHAPE clusters | |||||
// When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | |||||
// When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | |||||
std::vector<std::shared_ptr<Cluster>> ordered_cluster_; | std::vector<std::shared_ptr<Cluster>> ordered_cluster_; | ||||
// Unique clusters left after merged clusters | // Unique clusters left after merged clusters | ||||
std::unordered_set<std::shared_ptr<Cluster>> unique_clusters_; | std::unordered_set<std::shared_ptr<Cluster>> unique_clusters_; | ||||