From: @shenwei41 Reviewed-by: @lilongfei15,@ljl0711 Signed-off-by: @ljl0711 tags/v1.2.0
| @@ -189,7 +189,6 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
| "graph/passes/mark_same_addr_pass.cc" | "graph/passes/mark_same_addr_pass.cc" | ||||
| "graph/passes/mark_graph_unknown_status_pass.cc" | "graph/passes/mark_graph_unknown_status_pass.cc" | ||||
| "graph/passes/dynamic_single_op_reset_shape_pass.cc" | |||||
| "graph/passes/mark_agnostic_pass.cc" | "graph/passes/mark_agnostic_pass.cc" | ||||
| "graph/partition/dynamic_shape_partition.cc" | "graph/partition/dynamic_shape_partition.cc" | ||||
| "graph/partition/stage_partition.cc" | "graph/partition/stage_partition.cc" | ||||
| @@ -351,6 +350,7 @@ set(TRAIN_SRC_LIST | |||||
| "hybrid/executor/node_done_manager.cc" | "hybrid/executor/node_done_manager.cc" | ||||
| "hybrid/executor/hybrid_profiler.cc" | "hybrid/executor/hybrid_profiler.cc" | ||||
| "hybrid/executor/hybrid_model_executor.cc" | "hybrid/executor/hybrid_model_executor.cc" | ||||
| "hybrid/executor/hybrid_model_pipeline_executor.cc" | |||||
| "hybrid/executor/hybrid_model_async_executor.cc" | "hybrid/executor/hybrid_model_async_executor.cc" | ||||
| "hybrid/executor/hybrid_execution_context.cc" | "hybrid/executor/hybrid_execution_context.cc" | ||||
| "hybrid/executor/subgraph_context.cc" | "hybrid/executor/subgraph_context.cc" | ||||
| @@ -388,6 +388,9 @@ set(TRAIN_SRC_LIST | |||||
| "client/ge_api.cc" | "client/ge_api.cc" | ||||
| "analyzer/analyzer.cc" | "analyzer/analyzer.cc" | ||||
| "ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
| "ir_build/attr_options/utils.cc" | |||||
| "ir_build/attr_options/keep_dtype_option.cc" | |||||
| "ir_build/attr_options/weight_compress_option.cc" | |||||
| "ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
| "graph/build/memory/memory_assigner.cc" | "graph/build/memory/memory_assigner.cc" | ||||
| "graph/build/memory/graph_mem_assigner.cc" | "graph/build/memory/graph_mem_assigner.cc" | ||||
| @@ -495,7 +498,6 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/atomic_addr_clean_pass.cc" | "graph/passes/atomic_addr_clean_pass.cc" | ||||
| "graph/passes/mark_same_addr_pass.cc" | "graph/passes/mark_same_addr_pass.cc" | ||||
| "graph/passes/mark_graph_unknown_status_pass.cc" | "graph/passes/mark_graph_unknown_status_pass.cc" | ||||
| "graph/passes/dynamic_single_op_reset_shape_pass.cc" | |||||
| "graph/passes/mark_agnostic_pass.cc" | "graph/passes/mark_agnostic_pass.cc" | ||||
| "graph/common/omg_util.cc" | "graph/common/omg_util.cc" | ||||
| "graph/common/bcast.cc" | "graph/common/bcast.cc" | ||||
| @@ -641,6 +643,9 @@ set(INFER_SRC_LIST | |||||
| "graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | "graph/load/model_manager/task_info/super_kernel/super_kernel.cc" | ||||
| "hybrid/hybrid_davinci_model_stub.cc" | "hybrid/hybrid_davinci_model_stub.cc" | ||||
| "ir_build/ge_ir_build.cc" | "ir_build/ge_ir_build.cc" | ||||
| "ir_build/attr_options/utils.cc" | |||||
| "ir_build/attr_options/keep_dtype_option.cc" | |||||
| "ir_build/attr_options/weight_compress_option.cc" | |||||
| "ir_build/atc_ir_common.cc" | "ir_build/atc_ir_common.cc" | ||||
| "graph/preprocess/insert_op/ge_aipp_op.cc" | "graph/preprocess/insert_op/ge_aipp_op.cc" | ||||
| "graph/preprocess/insert_op/util_insert_aipp_op.cc" | "graph/preprocess/insert_op/util_insert_aipp_op.cc" | ||||
| @@ -512,8 +512,8 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadModel(c | |||||
| FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { | FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelHelper::LoadRootModel(const ge::ModelData &model_data) { | ||||
| if (model_data.model_data == nullptr || model_data.model_len == 0) { | if (model_data.model_data == nullptr || model_data.model_len == 0) { | ||||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0"); | |||||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "Model_data is nullptr, or model_data_size is 0"); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| } | } | ||||
| if (is_assign_model_) { | if (is_assign_model_) { | ||||
| @@ -207,9 +207,9 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t m | |||||
| "ModelFileHeader length :%zu, ModelPartitionTable length :%zu", | "ModelFileHeader length :%zu, ModelPartitionTable length :%zu", | ||||
| index, partition_table->num, sizeof(ModelFileHeader), partition_table_size); | index, partition_table->num, sizeof(ModelFileHeader), partition_table_size); | ||||
| if (model_data_size <= cur_offset) { | if (model_data_size <= cur_offset) { | ||||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u", | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, "invalid model data, partition_table->num:%u, model data size %u", | |||||
| partition_table->num, model_data_size); | partition_table->num, model_data_size); | ||||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| } | } | ||||
| for (uint32_t i = 0; i < partition_table->num; i++) { | for (uint32_t i = 0; i < partition_table->num; i++) { | ||||
| @@ -231,9 +231,10 @@ Status OmFileLoadHelper::LoadModelPartitionTable(uint8_t *model_data, uint32_t m | |||||
| } | } | ||||
| if (partition.size > model_data_size || cur_offset > model_data_size - partition.size) { | if (partition.size > model_data_size || cur_offset > model_data_size - partition.size) { | ||||
| GELOGE(GE_EXEC_MODEL_DATA_SIZE_INVALID, "The partition size %u is greater than the model data size %u.", | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "The partition size %u is greater than the model data size %u.", | |||||
| partition.size + cur_offset, model_data_size); | partition.size + cur_offset, model_data_size); | ||||
| return GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID; | |||||
| } | } | ||||
| cur_offset += partition.size; | cur_offset += partition.size; | ||||
| GELOGD("Partition, type:%d, size:%u, model_index:%u", static_cast<int>(partition.type), partition.size, index); | GELOGD("Partition, type:%d, size:%u, model_index:%u", static_cast<int>(partition.type), partition.size, index); | ||||
| @@ -81,6 +81,7 @@ set(SRC_LIST | |||||
| "../hybrid/executor/node_done_manager.cc" | "../hybrid/executor/node_done_manager.cc" | ||||
| "../hybrid/executor/hybrid_profiler.cc" | "../hybrid/executor/hybrid_profiler.cc" | ||||
| "../hybrid/executor/hybrid_model_executor.cc" | "../hybrid/executor/hybrid_model_executor.cc" | ||||
| "../hybrid/executor/hybrid_model_pipeline_executor.cc" | |||||
| "../hybrid/executor/hybrid_model_async_executor.cc" | "../hybrid/executor/hybrid_model_async_executor.cc" | ||||
| "../hybrid/executor/hybrid_execution_context.cc" | "../hybrid/executor/hybrid_execution_context.cc" | ||||
| "../hybrid/executor/subgraph_context.cc" | "../hybrid/executor/subgraph_context.cc" | ||||
| @@ -175,14 +175,14 @@ bool IsDynamicImageSizeMatchModel(uint64_t image_height, uint64_t image_width, | |||||
| bool IsDynmaicDimsSizeMatchModel(const vector<uint64_t> cur_dynamic_dims, | bool IsDynmaicDimsSizeMatchModel(const vector<uint64_t> cur_dynamic_dims, | ||||
| const vector<vector<int64_t>> &batch_info) { | const vector<vector<int64_t>> &batch_info) { | ||||
| if (batch_info.empty()) { | if (batch_info.empty()) { | ||||
| GELOGE(ge::FAILED, "Dynamic batch info is empty."); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Dynamic batch info is empty."); | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool find_match = false; | bool find_match = false; | ||||
| for (auto resolution : batch_info) { | for (auto resolution : batch_info) { | ||||
| if (cur_dynamic_dims.size() != resolution.size()) { | if (cur_dynamic_dims.size() != resolution.size()) { | ||||
| GELOGE(ge::FAILED, "Cur dynamic dims param num is %zu, current resolution size is %zu.", | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Cur dynamic dims param num is %zu, current resolution size is %zu.", | |||||
| cur_dynamic_dims.size(), resolution.size()); | cur_dynamic_dims.size(), resolution.size()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -199,7 +199,7 @@ bool IsDynmaicDimsSizeMatchModel(const vector<uint64_t> cur_dynamic_dims, | |||||
| } | } | ||||
| } | } | ||||
| if (!find_match) { | if (!find_match) { | ||||
| GELOGE(ge::FAILED, "choose dynamic dims can not match the gear of model."); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "choose dynamic dims can not match the gear of model."); | |||||
| } | } | ||||
| return find_match; | return find_match; | ||||
| } | } | ||||
| @@ -70,6 +70,9 @@ GRAPH_MANAGER_LOCAL_SRC_FILES := \ | |||||
| BUILER_SRC_FILES := \ | BUILER_SRC_FILES := \ | ||||
| ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
| ir_build/attr_options/utils.cc \ | |||||
| ir_build/attr_options/keep_dtype_option.cc \ | |||||
| ir_build/attr_options/weight_compress_option.cc \ | |||||
| ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
| ANALYZER_SRC_FILES:= \ | ANALYZER_SRC_FILES:= \ | ||||
| @@ -111,7 +114,6 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/atomic_addr_clean_pass.cc \ | graph/passes/atomic_addr_clean_pass.cc \ | ||||
| graph/passes/mark_same_addr_pass.cc \ | graph/passes/mark_same_addr_pass.cc \ | ||||
| graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
| graph/passes/dynamic_single_op_reset_shape_pass.cc \ | |||||
| graph/passes/mark_agnostic_pass.cc \ | graph/passes/mark_agnostic_pass.cc \ | ||||
| graph/common/omg_util.cc \ | graph/common/omg_util.cc \ | ||||
| graph/common/bcast.cc \ | graph/common/bcast.cc \ | ||||
| @@ -114,7 +114,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/passes/atomic_addr_clean_pass.cc \ | graph/passes/atomic_addr_clean_pass.cc \ | ||||
| graph/passes/mark_same_addr_pass.cc \ | graph/passes/mark_same_addr_pass.cc \ | ||||
| graph/passes/mark_graph_unknown_status_pass.cc \ | graph/passes/mark_graph_unknown_status_pass.cc \ | ||||
| graph/passes/dynamic_single_op_reset_shape_pass.cc \ | |||||
| graph/passes/mark_agnostic_pass.cc \ | graph/passes/mark_agnostic_pass.cc \ | ||||
| graph/partition/dynamic_shape_partition.cc \ | graph/partition/dynamic_shape_partition.cc \ | ||||
| graph/partition/stage_partition.cc \ | graph/partition/stage_partition.cc \ | ||||
| @@ -312,6 +311,9 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| executor/ge_executor.cc \ | executor/ge_executor.cc \ | ||||
| analyzer/analyzer.cc \ | analyzer/analyzer.cc \ | ||||
| ir_build/ge_ir_build.cc \ | ir_build/ge_ir_build.cc \ | ||||
| ir_build/attr_options/utils.cc \ | |||||
| ir_build/attr_options/keep_dtype_option.cc \ | |||||
| ir_build/attr_options/weight_compress_option.cc \ | |||||
| ir_build/atc_ir_common.cc \ | ir_build/atc_ir_common.cc \ | ||||
| LIBCLIENT_LOCAL_SRC_FILES := \ | LIBCLIENT_LOCAL_SRC_FILES := \ | ||||
| @@ -48,7 +48,7 @@ const char *const kVectorEngine = "VectorEngine"; | |||||
| const char *const kAIcoreEngine = "AIcoreEngine"; | const char *const kAIcoreEngine = "AIcoreEngine"; | ||||
| const char *const kFileNameSuffix = "online"; | const char *const kFileNameSuffix = "online"; | ||||
| const char *const kAicpuAllshape = "_AllShape"; | const char *const kAicpuAllshape = "_AllShape"; | ||||
| const size_t kDynamicDimSize = 1; | |||||
| constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | |||||
| const int64_t kDynamicDimValue = -2; | const int64_t kDynamicDimValue = -2; | ||||
| std::map<ge::OpEngineType, std::string> engine_type_map{ | std::map<ge::OpEngineType, std::string> engine_type_map{ | ||||
| @@ -251,30 +251,6 @@ static void GetOpsProtoPath(string &opsproto_path) { | |||||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | ||||
| } | } | ||||
| static Status CheckShapeReset(const OpDescPtr &op_desc, bool &change_shape_flag) { | |||||
| GE_CHECK_NOTNULL_EXEC(op_desc, return PARAM_INVALID); | |||||
| change_shape_flag = false; | |||||
| for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { | |||||
| auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i)); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| // pass scalar input desc | |||||
| auto dims = input_desc->GetShape().GetDims(); | |||||
| if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) { | |||||
| change_shape_flag = true; | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) { | |||||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i)); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| // pass scalar output desc | |||||
| auto dims = output_desc->GetShape().GetDims(); | |||||
| if (dims.size() == kDynamicDimSize && dims[0] == kDynamicDimValue) { | |||||
| change_shape_flag = true; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) { | static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTensor> &inputs_dynamic) { | ||||
| for (auto input : inputs) { | for (auto input : inputs) { | ||||
| auto input_desc = input.GetTensorDesc(); | auto input_desc = input.GetTensorDesc(); | ||||
| @@ -289,7 +265,7 @@ static Status ResetTensorVecShape(const vector<GeTensor> &inputs, vector<GeTenso | |||||
| bool is_const = false; | bool is_const = false; | ||||
| (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); | (void)AttrUtils::GetBool(input_desc, CONST_ATTR_NAME_INPUT, is_const); | ||||
| if (!is_const && shape_ori.GetDims().size() > 0) { | |||||
| if (!is_const) { | |||||
| int64_t storage_format = FORMAT_NCHW; | int64_t storage_format = FORMAT_NCHW; | ||||
| if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) && | if (ge::AttrUtils::GetInt(desc, ge::ATTR_NAME_STORAGE_FORMAT, storage_format) && | ||||
| !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) { | !ge::AttrUtils::SetListInt(desc, ge::ATTR_NAME_STORAGE_SHAPE, dynamic_shape_dims)) { | ||||
| @@ -645,6 +621,32 @@ namespace { | |||||
| } | } | ||||
| return is_need; | return is_need; | ||||
| } | } | ||||
| Status CheckDynamicSupport(GeModelPtr &ge_model, const ComputeGraphPtr &graph) { | |||||
| bool support_dynamic = true; | |||||
| bool is_dynamic = false; | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP || | |||||
| node->GetType() == NETOUTPUT) { | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (AttrUtils::HasAttr(op_desc, kAttrSupportDynamicShape)) { | |||||
| is_dynamic = true; | |||||
| (void) AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic); | |||||
| if (!support_dynamic) { | |||||
| GELOGW("Node[%s] doesn't support dynamic shape.", node->GetName().c_str()); | |||||
| break; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (is_dynamic) { | |||||
| (void) AttrUtils::SetBool(ge_model, kAttrSupportDynamicShape, support_dynamic); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } | } | ||||
| Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, | Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, | ||||
| @@ -719,14 +721,14 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty."); | GELOGE(PARAM_INVALID, "GetSubgraphInstanceNameToModel is empty."); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| const ComputeGraphPtr root_graph = ge_root_model->GetRootGraph(); | |||||
| GeModelPtr &ge_model = name_to_ge_model.begin()->second; | GeModelPtr &ge_model = name_to_ge_model.begin()->second; | ||||
| GE_CHK_STATUS_RET_NOLOG(CheckDynamicSupport(ge_model, root_graph)); | |||||
| GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str()); | GELOGD("The opType in op_desc_tmp is [%s]", op_desc_tmp->GetType().c_str()); | ||||
| bool all_shape = false; | bool all_shape = false; | ||||
| bool dynamic_flag = false; | |||||
| (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape); | (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape); | ||||
| CheckShapeReset(op_desc, dynamic_flag); | |||||
| if (dynamic_flag || all_shape) { | |||||
| if (all_shape) { | |||||
| GELOGD("Get aicpu all_shape kernel!"); | GELOGD("Get aicpu all_shape kernel!"); | ||||
| vector<GeTensor> inputs_dynamic; | vector<GeTensor> inputs_dynamic; | ||||
| vector<GeTensor> outputs_dynamic; | vector<GeTensor> outputs_dynamic; | ||||
| @@ -374,63 +374,43 @@ bool IsContinuousInputConflict(const ge::NodePtr &node, const OpDescPtr &peer_op | |||||
| // If GetBool fail, is_peer_reference is false. | // If GetBool fail, is_peer_reference is false. | ||||
| (void) AttrUtils::GetBool(peer_op_desc, ATTR_NAME_REFERENCE, is_peer_reference); | (void) AttrUtils::GetBool(peer_op_desc, ATTR_NAME_REFERENCE, is_peer_reference); | ||||
| GE_IF_BOOL_EXEC(is_peer_reference, | GE_IF_BOOL_EXEC(is_peer_reference, | ||||
| std::string error = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | |||||
| std::string warning = "Current op" + FmtToStr(node->GetOpDesc()->GetName()) + | |||||
| " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | " requires continuous input, while the previous op" + FmtToStr(peer_op_desc->GetName()) + | ||||
| " requires continuous output. There may be conflict between the two." + | |||||
| "This node is not supported now."; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | |||||
| return true;); | |||||
| " is ref. There may be conflict between the two."; | |||||
| GELOGW("%s", warning.c_str()); | |||||
| return false;); | |||||
| return false; | return false; | ||||
| } | } | ||||
| Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | ||||
| Status ret; | Status ret; | ||||
| // Stored nodes which need assign continuous input memory in `reverse topo order` | |||||
| std::vector<NodePtr> nodes_stack; | |||||
| std::map<NodePtr, uint32_t> node_2_continuous_type; | |||||
| // Traverse nodes | |||||
| for (auto &node : compute_graph_->GetAllNodes()) { | for (auto &node : compute_graph_->GetAllNodes()) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| auto continuous_type = GetContinuousMemoryType(node->GetOpDesc()); | |||||
| uint32_t continuous_type; | |||||
| auto iter = node_2_continuous_type.find(node); | |||||
| if (iter == node_2_continuous_type.end()) { | |||||
| continuous_type = GetContinuousMemoryType(node->GetOpDesc()); | |||||
| node_2_continuous_type.emplace(node, continuous_type); | |||||
| } else { | |||||
| continuous_type = iter->second; | |||||
| } | |||||
| // Assign continuous input memory | // Assign continuous input memory | ||||
| bool continuous_input = ((continuous_type & kTypeInput) != 0) || ((continuous_type & kTypeInputNoPadding) != 0); | bool continuous_input = ((continuous_type & kTypeInput) != 0) || ((continuous_type & kTypeInputNoPadding) != 0); | ||||
| int64_t memory_type = RT_MEMORY_HBM; | |||||
| if (continuous_input) { | if (continuous_input) { | ||||
| int64_t mem_clean_start = 0; | |||||
| int64_t mem_clean_size = 0; | |||||
| GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "input"), "Get node memory type failed."); | |||||
| ret = AssignContinuousInputMemory(node, mem_clean_start, mem_clean_size, memory_type, continuous_type); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "Assign continuous input memory failed!"); | |||||
| return ret; | |||||
| } | |||||
| // Clean up atomic address, eg, hcom node | |||||
| vector<int32_t> input_indexes; | |||||
| // If GetListInt fail, input_indexes is empty. | |||||
| (void) ge::AttrUtils::GetListInt(node->GetOpDesc(), ATOMIC_ATTR_INPUT_INDEX, input_indexes); | |||||
| if (!input_indexes.empty() && input_indexes[0] == kAllInputAddrIsAtomic) { | |||||
| // check whether there is an atomic conflict between the current node and the peer out node | |||||
| if (!CheckInputIsSupportAtomic(node)) { | |||||
| GELOGE(ge::FAILED, | |||||
| "There is an atomic conflict between the current node and the peer out node, not supported!"); | |||||
| return ge::FAILED; | |||||
| } | |||||
| const auto &in_control_anchor = node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_control_anchor); | |||||
| for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { | |||||
| GE_CHECK_NOTNULL(peer_out_control_anchor); | |||||
| auto peer_out_node = peer_out_control_anchor->GetOwnerNode(); | |||||
| if (peer_out_node->GetType() == ATOMICADDRCLEAN) { | |||||
| ret = SetAtomicCleanAttr(peer_out_node, {mem_clean_start}, {mem_clean_size}, memory_type); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to set attr for atomic addr clean node %s.", peer_out_node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| if (AssignContinuousInputMemoryWithAtomicProcessDirectly(node, node_2_continuous_type)) { | |||||
| GE_CHK_STATUS_RET(AssignContinuousInputMemoryWithAtomicProcess(node, continuous_type), | |||||
| "Assign node %s continuous input memory failed.", node->GetName().c_str()) | |||||
| } else { | |||||
| nodes_stack.push_back(node); | |||||
| } | } | ||||
| } | } | ||||
| // Assign continuous output memory | // Assign continuous output memory | ||||
| int64_t memory_type = RT_MEMORY_HBM; | |||||
| bool continuous_output = ((continuous_type & kTypeOutput) != 0) || ((continuous_type & kTypeOutputNoPadding) != 0); | bool continuous_output = ((continuous_type & kTypeOutput) != 0) || ((continuous_type & kTypeOutputNoPadding) != 0); | ||||
| if (continuous_output) { | if (continuous_output) { | ||||
| GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"), "Get node memory type failed."); | GE_CHK_STATUS_RET(GetNodeMemoryType(node, memory_type, "output"), "Get node memory type failed."); | ||||
| @@ -441,6 +421,18 @@ Status GraphMemoryAssigner::ReAssignContinuousMemory(bool is_loop_graph) { | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| // Assign continuous input memory in `reverse topo order` which stored before | |||||
| while (!nodes_stack.empty()){ | |||||
| auto node = nodes_stack.back(); | |||||
| nodes_stack.pop_back(); | |||||
| auto iter = node_2_continuous_type.find(node); | |||||
| if (iter == node_2_continuous_type.end()) { | |||||
| GELOGE(FAILED, "node %s has no continuous type!", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHK_STATUS_RET(AssignContinuousInputMemoryWithAtomicProcess(node, iter->second), | |||||
| "Assign node %s continuous input memory failed.", node->GetName().c_str()) | |||||
| } | |||||
| for (auto pair : memory_offset_) { | for (auto pair : memory_offset_) { | ||||
| GELOGD("After reassign continuous memory, memory type = %ld, memoffset = %zu.", pair.first, | GELOGD("After reassign continuous memory, memory type = %ld, memoffset = %zu.", pair.first, | ||||
| pair.second.mem_offset_); | pair.second.mem_offset_); | ||||
| @@ -463,7 +455,15 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| int64_t mem_offset = iter->second.mem_offset_; | int64_t mem_offset = iter->second.mem_offset_; | ||||
| int64_t extra_memory_size = 0; | int64_t extra_memory_size = 0; | ||||
| bool is_continuous_input_allocated = false; | bool is_continuous_input_allocated = false; | ||||
| (void) ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CONTINUOUS_INPUT_ALLOC, is_continuous_input_allocated); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| vector<int64_t> output_list_this = op_desc->GetOutputOffset(); | |||||
| if (output_list_this.empty()) { | |||||
| std::string error = "node:" + FmtToStr(op_desc->GetName()) + "has no output offset"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| (void) ge::AttrUtils::GetBool(op_desc, ATTR_NAME_CONTINUOUS_INPUT_ALLOC, is_continuous_input_allocated); | |||||
| for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | for (auto &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| GE_IF_BOOL_EXEC(in_data_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(in_data_anchor == nullptr, continue); | ||||
| auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | auto peer_out_data_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
| @@ -505,6 +505,17 @@ Status GraphMemoryAssigner::AssignContinuousInputMemory(const ge::NodePtr &node, | |||||
| // when continuous input has been allocated first input is beginning offset | // when continuous input has been allocated first input is beginning offset | ||||
| bool is_allocated_first_input = is_continuous_input_allocated && (in_data_anchor->GetIdx() == 0); | bool is_allocated_first_input = is_continuous_input_allocated && (in_data_anchor->GetIdx() == 0); | ||||
| if (is_allocated_first_input) { | if (is_allocated_first_input) { | ||||
| std::map<int32_t, int32_t> out2ins; | |||||
| GE_CHK_STATUS_RET(GetAllRef(node, out2ins), "Node: %s get all ref failed", node->GetName().c_str()); | |||||
| // output is beginning offset, set offset for input; only support this case now | |||||
| if (out2ins.size() == 1 && out2ins.begin()->second == 0) { | |||||
| output_list.at(peer_out_data_anchor->GetIdx()) = output_list_this.at(out2ins.begin()->first); | |||||
| peer_op_desc->SetOutputOffset(output_list); | |||||
| } else { | |||||
| GELOGW("Node %s out %d ref in %d with total ref numbers %zu", node->GetName().c_str(), out2ins.begin()->first, | |||||
| out2ins.begin()->second, out2ins.size()); | |||||
| } | |||||
| // first input is beginning offset | |||||
| mem_offset = output_list.at(peer_out_data_anchor->GetIdx()); | mem_offset = output_list.at(peer_out_data_anchor->GetIdx()); | ||||
| continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); | continuous_mem_start = output_list.at(peer_out_data_anchor->GetIdx()); | ||||
| } else { | } else { | ||||
| @@ -882,7 +893,7 @@ bool GraphMemoryAssigner::CheckInputIsSupportAtomic(const ge::NodePtr &node) { | |||||
| if ((peer_op_desc->GetType() == CONSTANTOP) || (peer_op_desc->GetType() == AIPP_DATA_TYPE) || | if ((peer_op_desc->GetType() == CONSTANTOP) || (peer_op_desc->GetType() == AIPP_DATA_TYPE) || | ||||
| (peer_op_desc->GetType() == VARIABLE)) { | (peer_op_desc->GetType() == VARIABLE)) { | ||||
| std::string error = "Op" + FmtToStr(node->GetName()) + "'s peer out node" + | std::string error = "Op" + FmtToStr(node->GetName()) + "'s peer out node" + | ||||
| FmtToStr(peer_op_desc->GetName()) + " is invalid, only support Constant/AippData/Variable"; | |||||
| FmtToStr(peer_op_desc->GetName()) + " is invalid, Constant/AippData/Variable is not supported"; | |||||
| GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | GE_ERRORLOG_AND_ERRORMSG(FAILED, error.c_str()); | ||||
| return false; | return false; | ||||
| } | } | ||||
| @@ -948,7 +959,7 @@ Status GraphMemoryAssigner::AssignAtomicOutputMemory(const ge::NodePtr &node, ve | |||||
| output_list[output_index] = iter->second.mem_offset_; | output_list[output_index] = iter->second.mem_offset_; | ||||
| std::string batch_label; | std::string batch_label; | ||||
| (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| GELOGI("[IMAS]Atomic output : Set %s name[%s] optype[%s] output[%ld] offset to [%zu] stream_id[%ld] memtype[%ld] " | |||||
| GELOGI("[IMAS]Atomic output : Set %s name[%s] optype[%s] output[%ld] offset to [%zu] stream_id[%ld] memtype[%u] " | |||||
| "size[%ld] real_size[%ld] batch[%s].", compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), | "size[%ld] real_size[%ld] batch[%s].", compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), | ||||
| node->GetType().c_str(), output_index, iter->second.mem_offset_, op_desc->GetStreamId(), RT_MEMORY_HBM, | node->GetType().c_str(), output_index, iter->second.mem_offset_, op_desc->GetStreamId(), RT_MEMORY_HBM, | ||||
| size, size, batch_label.c_str()); | size, size, batch_label.c_str()); | ||||
| @@ -1028,7 +1039,7 @@ Status GraphMemoryAssigner::AssignOrdinaryAtomicWorkspaceMemory(const ge::OpDesc | |||||
| (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| GELOGI( | GELOGI( | ||||
| "[IMAS]Atomic ordinary workspace : Set %s name[%s] optype[%s] workspace[%lu] offset to [%zu] stream_id[%ld] " | "[IMAS]Atomic ordinary workspace : Set %s name[%s] optype[%s] workspace[%lu] offset to [%zu] stream_id[%ld] " | ||||
| "memtype[%ld] size[%ld] real_size[%ld] batch[%s].", | |||||
| "memtype[%u] size[%ld] real_size[%ld] batch[%s].", | |||||
| compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), workspace_index, | compute_graph_->GetName().c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), workspace_index, | ||||
| mem_type_iter->second.mem_offset_, op_desc->GetStreamId(), RT_MEMORY_HBM, workspace_size, workspace_size, | mem_type_iter->second.mem_offset_, op_desc->GetStreamId(), RT_MEMORY_HBM, workspace_size, workspace_size, | ||||
| batch_label.c_str()); | batch_label.c_str()); | ||||
| @@ -1069,7 +1080,7 @@ Status GraphMemoryAssigner::AssignFusionAtomicWorkspaceMemory(const ge::OpDescPt | |||||
| (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | (void)ge::AttrUtils::GetStr(op_desc, ATTR_NAME_BATCH_LABEL, batch_label); | ||||
| GELOGI( | GELOGI( | ||||
| "[IMAS]Atomic fusion workspace : Set %s name[%s] optype[%s] workspace[%lu] offset to [%zu] stream_id[%ld] " | "[IMAS]Atomic fusion workspace : Set %s name[%s] optype[%s] workspace[%lu] offset to [%zu] stream_id[%ld] " | ||||
| "memtype[%ld] ssize[%ld] real_size[%ld] batch[%s].", compute_graph_->GetName().c_str(), | |||||
| "memtype[%u] ssize[%ld] real_size[%ld] batch[%s].", compute_graph_->GetName().c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), workspace_index, mem_type_iter->second.mem_offset_, | op_desc->GetName().c_str(), op_desc->GetType().c_str(), workspace_index, mem_type_iter->second.mem_offset_, | ||||
| op_desc->GetStreamId(), RT_MEMORY_HBM, workspace_size, workspace_size, batch_label.c_str()); | op_desc->GetStreamId(), RT_MEMORY_HBM, workspace_size, workspace_size, batch_label.c_str()); | ||||
| @@ -1502,4 +1513,92 @@ void GraphMemoryAssigner::PrintMemoryOffset() { | |||||
| pair.first, pair.second.mem_offset_); | pair.first, pair.second.mem_offset_); | ||||
| } | } | ||||
| } | } | ||||
| ge::Status GraphMemoryAssigner::GetAllRef(const NodePtr &node, map<int32_t, int32_t> &out2ins) { | |||||
| for (const auto &out_data_anchor : node->GetAllOutDataAnchors()) { | |||||
| int32_t reuse_in_index = -1; | |||||
| bool reuse_input_flag = GraphUtils::IsRefFromInput(out_data_anchor, reuse_in_index); | |||||
| if (reuse_input_flag) { | |||||
| if (node->GetInDataAnchor(reuse_in_index) != nullptr) { | |||||
| out2ins.emplace(out_data_anchor->GetIdx(), reuse_in_index); | |||||
| } else { | |||||
| GELOGE(FAILED, "Invalid reuse_input value %d on output %d of node %s, please check attr reuse_input", | |||||
| reuse_in_index, out_data_anchor->GetIdx(), node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| bool GraphMemoryAssigner::AssignContinuousInputMemoryWithAtomicProcessDirectly( | |||||
| const NodePtr &input_continuous_node, map<NodePtr, uint32_t> &node_2_continuous_type) { | |||||
| for (const auto &in_node : input_continuous_node->GetInDataNodes()) { | |||||
| auto iter = node_2_continuous_type.find(in_node); | |||||
| // In node's topo order in the front, so function can not be exception | |||||
| auto continuous_type = iter->second; | |||||
| bool continuous_input = ((continuous_type & kTypeInput) != 0) || ((continuous_type & kTypeInputNoPadding) != 0); | |||||
| if (continuous_input) { | |||||
| GELOGI("node %s 's precursor node %s need assign continuous input memory, store node firstly.", | |||||
| input_continuous_node->GetName().c_str(), in_node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| for (const auto &out_node : input_continuous_node->GetOutDataNodes()) { | |||||
| auto continuous_type = GetContinuousMemoryType(out_node->GetOpDesc()); | |||||
| node_2_continuous_type.emplace(out_node, continuous_type); | |||||
| bool continuous_input = ((continuous_type & kTypeInput) != 0) || ((continuous_type & kTypeInputNoPadding) != 0); | |||||
| if (continuous_input) { | |||||
| GELOGI("node %s 's succeed node %s need assign continuous input memory, store node firstly.", | |||||
| input_continuous_node->GetName().c_str(), out_node->GetName().c_str()); | |||||
| return false; | |||||
| } | |||||
| } | |||||
| return true; | |||||
| } | |||||
| ge::Status GraphMemoryAssigner::AssignContinuousInputMemoryWithAtomicProcess(const NodePtr &input_continuous_node, | |||||
| uint32_t continuous_type) { | |||||
| int64_t mem_clean_start = 0; | |||||
| int64_t mem_clean_size = 0; | |||||
| int64_t memory_type = RT_MEMORY_HBM; | |||||
| GE_CHK_STATUS_RET(GetNodeMemoryType(input_continuous_node, memory_type, "input"), "Get node memory type failed."); | |||||
| auto ret = AssignContinuousInputMemory(input_continuous_node, mem_clean_start, mem_clean_size, memory_type, continuous_type); | |||||
| if (ret != ge::SUCCESS) { | |||||
| GELOGE(ret, "Assign continuous input memory failed!"); | |||||
| return ret; | |||||
| } | |||||
| // Clean up atomic address, eg, hcom node | |||||
| vector<int32_t> input_indexes; | |||||
| // If GetListInt fail, input_indexes is empty. | |||||
| (void)ge::AttrUtils::GetListInt(input_continuous_node->GetOpDesc(), ATOMIC_ATTR_INPUT_INDEX, input_indexes); | |||||
| if (!input_indexes.empty() && input_indexes[0] == kAllInputAddrIsAtomic) { | |||||
| // check whether there is an atomic conflict between the current node and the peer out node | |||||
| if (!CheckInputIsSupportAtomic(input_continuous_node)) { | |||||
| GELOGE(ge::FAILED, "There is an atomic conflict between the current node and the peer out node, not supported!"); | |||||
| return ge::FAILED; | |||||
| } | |||||
| const auto &in_control_anchor = input_continuous_node->GetInControlAnchor(); | |||||
| GE_CHECK_NOTNULL(in_control_anchor); | |||||
| for (const auto &peer_out_control_anchor : in_control_anchor->GetPeerOutControlAnchors()) { | |||||
| GE_CHECK_NOTNULL(peer_out_control_anchor); | |||||
| auto peer_out_node = peer_out_control_anchor->GetOwnerNode(); | |||||
| if (peer_out_node->GetType() == ATOMICADDRCLEAN) { | |||||
| ret = SetAtomicCleanAttr(peer_out_node, {mem_clean_start}, {mem_clean_size}, memory_type); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Failed to set attr for atomic addr clean node %s.", peer_out_node->GetName().c_str()); | |||||
| return ret; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -125,6 +125,14 @@ class GraphMemoryAssigner { | |||||
| ge::Status ReAssignAtomicMemory(bool is_loop_graph); | ge::Status ReAssignAtomicMemory(bool is_loop_graph); | ||||
| ge::Status GetAllRef(const NodePtr &node, std::map<int32_t, int32_t> &out2ins); | |||||
| bool AssignContinuousInputMemoryWithAtomicProcessDirectly(const NodePtr &input_continuous_node, | |||||
| std::map<NodePtr, uint32_t> &node_2_continuous_type); | |||||
| ge::Status AssignContinuousInputMemoryWithAtomicProcess(const NodePtr &input_continuous_node, | |||||
| uint32_t continuous_type); | |||||
| ge::Status FilterAtomicNodesForMemoryAssign(map<string, map<NodePtr, vector<NodePtr>>> &normal_atomic_nodes_map, | ge::Status FilterAtomicNodesForMemoryAssign(map<string, map<NodePtr, vector<NodePtr>>> &normal_atomic_nodes_map, | ||||
| map<string, vector<NodePtr>> &connecting_output_atomic_nodes); | map<string, vector<NodePtr>> &connecting_output_atomic_nodes); | ||||
| @@ -35,7 +35,7 @@ using std::vector; | |||||
| namespace { | namespace { | ||||
| const int64_t kTaskNumPerNormalNode = 3; | const int64_t kTaskNumPerNormalNode = 3; | ||||
| const int64_t kTaskNumPerHcclNode = 200; | |||||
| const int64_t kTaskNumPerHcclNode = 245; | |||||
| const char *const kTrueStr = "true"; | const char *const kTrueStr = "true"; | ||||
| const char *const kFalseStr = "false"; | const char *const kFalseStr = "false"; | ||||
| @@ -728,6 +728,7 @@ Status DavinciModel::Init(void *dev_ptr, size_t mem_size, void *weight_ptr, size | |||||
| GE_CHK_RT_RET(rtSetCtxINFMode((fp_ceiling_mode != "0"))); | GE_CHK_RT_RET(rtSetCtxINFMode((fp_ceiling_mode != "0"))); | ||||
| } | } | ||||
| SetProfileTime(MODEL_LOAD_END); | |||||
| // collect profiling for ge | // collect profiling for ge | ||||
| GE_CHK_STATUS_RET(InitModelProfile(), "Init model profile failed"); | GE_CHK_STATUS_RET(InitModelProfile(), "Init model profile failed"); | ||||
| auto &profiling_manager = ProfilingManager::Instance(); | auto &profiling_manager = ProfilingManager::Instance(); | ||||
| @@ -2279,8 +2280,12 @@ Status DavinciModel::SinkModelProfile() { | |||||
| } | } | ||||
| // stream id info | // stream id info | ||||
| uint32_t streamId = profile.fusion_info.stream_id; | |||||
| reporter_data.data = (unsigned char *)&streamId; | |||||
| uint32_t stream_id = 0; | |||||
| auto iter = profiler_report_op_info_.find(fusion_op_name); | |||||
| if (iter != profiler_report_op_info_.end()) { | |||||
| stream_id = iter->second.second; | |||||
| } | |||||
| reporter_data.data = (unsigned char *)&stream_id; | |||||
| reporter_data.dataLen = sizeof(int32_t); | reporter_data.dataLen = sizeof(int32_t); | ||||
| GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | GE_CHK_BOOL_EXEC(prof_mgr.CallMsprofReport(reporter_data) == 0, return FAILED, | ||||
| "Reporter data fail, model id:%u.", this->Id()); | "Reporter data fail, model id:%u.", this->Id()); | ||||
| @@ -3278,8 +3283,8 @@ bool DavinciModel::CheckInputAndModelSize(const int64_t &input_size, const int64 | |||||
| } | } | ||||
| // The input and model input size can not be exactly equal because user input is not definite. | // The input and model input size can not be exactly equal because user input is not definite. | ||||
| if ((input_size + kDataMemAlignSizeCompare) < op_size) { | if ((input_size + kDataMemAlignSizeCompare) < op_size) { | ||||
| GELOGE(FAILED, "Input size [%ld] can not be smaller than op size [%ld] after 64-byte alignment", input_size, | |||||
| op_size); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, | |||||
| "Input size [%ld] can not be smaller than op size [%ld] after 64-byte alignment", input_size, op_size); | |||||
| return false; | return false; | ||||
| } | } | ||||
| return true; | return true; | ||||
| @@ -3329,27 +3334,28 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> & | |||||
| string input_or_output = "input"; | string input_or_output = "input"; | ||||
| is_input ? input_or_output = "input" : input_or_output = "output"; | is_input ? input_or_output = "input" : input_or_output = "output"; | ||||
| if (blobs.size() != data_info.size()) { | if (blobs.size() != data_info.size()) { | ||||
| GELOGE(FAILED, "Verify %s data num failed: model requires %zu, but user actually feeds %zu", | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Verify %s data num failed: model requires %zu, but user actually feeds %zu", | |||||
| input_or_output.c_str(), data_info.size(), blobs.size()); | input_or_output.c_str(), data_info.size(), blobs.size()); | ||||
| return FAILED; | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| for (const auto &data : data_info) { | for (const auto &data : data_info) { | ||||
| if (data.first >= blobs.size()) { // check data index. | if (data.first >= blobs.size()) { // check data index. | ||||
| GELOGE(FAILED, "Verify %s data num failed: can not find No.%u data, because user only feeds %zu", | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "Verify %s data num failed: can not find No.%u data, because user only feeds %zu", | |||||
| input_or_output.c_str(), data.first, blobs.size()); | input_or_output.c_str(), data.first, blobs.size()); | ||||
| return FAILED; | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| const DataBuffer &buffer = blobs[data.first]; // index of data. | const DataBuffer &buffer = blobs[data.first]; // index of data. | ||||
| if (buffer.data == nullptr) { | if (buffer.data == nullptr) { | ||||
| GELOGE(FAILED, "data_buf.data is nullptr, index=%u", data.first); | |||||
| return FAILED; | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "data_buf.data is nullptr, index=%u", data.first); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { | if (!CheckInputAndModelSize(buffer.length, data.second.GetDataSize(), is_dynamic)) { | ||||
| GELOGE(FAILED, "Check input size and model size failed, op[%s]", data.second.GetOpName().c_str()); | |||||
| return FAILED; | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, | |||||
| "Check input size and model size failed, op[%s]", data.second.GetOpName().c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| void *basic_addr = data.second.GetBasicAddr(); | void *basic_addr = data.second.GetBasicAddr(); | ||||
| @@ -3357,9 +3363,10 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> & | |||||
| if (copy_only_addrs_.count(basic_addr) > 0) { | if (copy_only_addrs_.count(basic_addr) > 0) { | ||||
| if (is_input) { | if (is_input) { | ||||
| GELOGI("[IMAS] Find addr %p need direct copy from user malloc input %p", basic_addr, buffer.data); | GELOGI("[IMAS] Find addr %p need direct copy from user malloc input %p", basic_addr, buffer.data); | ||||
| if (rtMemcpy(basic_addr, data_size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE) != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "Non-zero copy data node copy failed"); | |||||
| return FAILED; | |||||
| rtError_t rt_ret = rtMemcpy(basic_addr, data_size, buffer.data, buffer.length, RT_MEMCPY_DEVICE_TO_DEVICE); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(rt_ret, "Non-zero copy data node copy failed"); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | |||||
| } | } | ||||
| } | } | ||||
| GELOGI("No need to exeucte zero copy task because this addr %p need direct copy.", basic_addr); | GELOGI("No need to exeucte zero copy task because this addr %p need direct copy.", basic_addr); | ||||
| @@ -3380,7 +3387,7 @@ Status DavinciModel::UpdateIoTaskArgs(const std::map<uint32_t, ZeroCopyOffset> & | |||||
| } | } | ||||
| uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); | uintptr_t addr_val = reinterpret_cast<uintptr_t>(addr); | ||||
| if (task.UpdateTaskParam(addr_val, buffer_addr) != SUCCESS) { | if (task.UpdateTaskParam(addr_val, buffer_addr) != SUCCESS) { | ||||
| return FAILED; | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -55,16 +55,18 @@ const char *const kDeleteCustOp = "deleteCustOp"; | |||||
| const int kTimeSpecNano = 1000000000; | const int kTimeSpecNano = 1000000000; | ||||
| const int kTimeSpecMiro = 1000000; | const int kTimeSpecMiro = 1000000; | ||||
| const int kOpNameMaxSize = 100; | const int kOpNameMaxSize = 100; | ||||
| #pragma pack(push, 1) | |||||
| struct CustAicpuSoBuf { | struct CustAicpuSoBuf { | ||||
| uint64_t kernelSoBuf; | uint64_t kernelSoBuf; | ||||
| uint32_t kernelSoBufLen; | uint32_t kernelSoBufLen; | ||||
| uint64_t kernelSoName; | uint64_t kernelSoName; | ||||
| uint32_t kernelSoNameLen; | uint32_t kernelSoNameLen; | ||||
| } __attribute__((packed)); | |||||
| }; | |||||
| struct BatchLoadOpFromBufArgs { | struct BatchLoadOpFromBufArgs { | ||||
| uint32_t soNum; | uint32_t soNum; | ||||
| uint64_t args; | uint64_t args; | ||||
| } __attribute__((packed)); | |||||
| }; | |||||
| #pragma pack(pop) | |||||
| } // namespace | } // namespace | ||||
| DumpProperties ModelManager::dump_properties_; | DumpProperties ModelManager::dump_properties_; | ||||
| @@ -328,7 +330,8 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
| GELOGE(FAILED, "davinci_model is nullptr"); | GELOGE(FAILED, "davinci_model is nullptr"); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | |||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
| davinci_model->SetId(model_id); | davinci_model->SetId(model_id); | ||||
| davinci_model->SetDeviceId(GetContext().DeviceId()); | davinci_model->SetDeviceId(GetContext().DeviceId()); | ||||
| @@ -355,10 +358,6 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
| InsertModel(model_id, davinci_model); | InsertModel(model_id, davinci_model); | ||||
| GELOGI("Parse model %u success.", model_id); | GELOGI("Parse model %u success.", model_id); | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | |||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
| davinci_model->SetProfileTime(MODEL_LOAD_END); | |||||
| } while (0); | } while (0); | ||||
| GE_CHK_RT(rtDeviceReset(static_cast<int32_t>(GetContext().DeviceId()))); | GE_CHK_RT(rtDeviceReset(static_cast<int32_t>(GetContext().DeviceId()))); | ||||
| @@ -1085,6 +1084,8 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed since other exception raise"); | GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Make shared failed since other exception raise"); | ||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| } | } | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | |||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
| ret = davinci_model->Assign(ge_model); | ret = davinci_model->Assign(ge_model); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGW("assign model failed."); | GELOGW("assign model failed."); | ||||
| @@ -1121,11 +1122,7 @@ Status ModelManager::LoadModelOffline(uint32_t &model_id, const ModelData &model | |||||
| InsertModel(model_id, davinci_model); | InsertModel(model_id, davinci_model); | ||||
| GELOGI("Parse model %u success.", model_id); | GELOGI("Parse model %u success.", model_id); | ||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | |||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | |||||
| davinci_model->SetProfileTime(MODEL_LOAD_END); | |||||
| GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); | GE_IF_BOOL_EXEC(ret == SUCCESS, device_count++); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } while (0); | } while (0); | ||||
| @@ -29,6 +29,10 @@ | |||||
| #include "hybrid/node_executor/aicpu/aicpu_ext_info.h" | #include "hybrid/node_executor/aicpu/aicpu_ext_info.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| namespace { | |||||
| const char *const kAicpuAllshape = "_AllShape"; | |||||
| } // namespace | |||||
| namespace ge { | namespace ge { | ||||
| Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDescPtr &op_desc) { | Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDescPtr &op_desc) { | ||||
| if (ext_info.empty()) { | if (ext_info.empty()) { | ||||
| @@ -50,6 +54,25 @@ Status KernelExTaskInfo::InitTaskExtInfo(const std::string &ext_info, const OpDe | |||||
| GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed."); | GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed."); | ||||
| GELOGD("Update aicpu_task ext_info bit_map execute mode to 1."); | GELOGD("Update aicpu_task ext_info bit_map execute mode to 1."); | ||||
| bool all_shape = false; | |||||
| (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape); | |||||
| if (all_shape) { | |||||
| GELOGD("Aicpu all_shape kernel need to update io shape."); | |||||
| for (uint32_t i = 0; i < num_inputs; i++) { | |||||
| auto input_desc = op_desc->MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| GE_CHK_STATUS_RET(ext_handle->UpdateInputShapeAndType(i, *input_desc), | |||||
| "Input[%u] update input shape failed.", i); | |||||
| } | |||||
| if (unknown_type != DEPEND_COMPUTE) { | |||||
| for (uint32_t j = 0; j < num_outputs; j++) { | |||||
| auto output_desc = op_desc->MutableOutputDesc(j); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| GE_CHK_STATUS_RET(ext_handle->UpdateOutputShapeAndType(j, *output_desc), | |||||
| "Output[%u] update output shape failed.", j); | |||||
| } | |||||
| } | |||||
| } | |||||
| auto rt_ret = rtMalloc(&ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); | auto rt_ret = rtMalloc(&ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); | ||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, | ||||
| GELOGE(RT_FAILED, "rtMalloc ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); | GELOGE(RT_FAILED, "rtMalloc ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); | ||||
| @@ -43,6 +43,7 @@ constexpr int64_t kInvalidGroupKey = -1; | |||||
| constexpr uint32_t kSKTSingleSize = 1; | constexpr uint32_t kSKTSingleSize = 1; | ||||
| const char *kIsLastNode = "is_last_node"; | const char *kIsLastNode = "is_last_node"; | ||||
| const char *kIsFirstNode = "is_first_node"; | const char *kIsFirstNode = "is_first_node"; | ||||
| const char *const kAicpuAllshape = "_AllShape"; | |||||
| const int64_t kCloseSkt = 100; | const int64_t kCloseSkt = 100; | ||||
| const uint32_t kAddrLen = sizeof(void *); | const uint32_t kAddrLen = sizeof(void *); | ||||
| const int kBaseInt = 10; | const int kBaseInt = 10; | ||||
| @@ -985,6 +986,23 @@ Status KernelTaskInfo::InitAicpuTaskExtInfo(const std::string &ext_info) { | |||||
| GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed."); | GE_CHK_STATUS_RET(ext_handle->UpdateExecuteMode(true), "UpdateExecuteMode failed."); | ||||
| GELOGD("Update aicpu_task ext_info bit_map execute mode to 1."); | GELOGD("Update aicpu_task ext_info bit_map execute mode to 1."); | ||||
| bool all_shape = false; | |||||
| (void)AttrUtils::GetBool(op_desc_, kAicpuAllshape, all_shape); | |||||
| if (all_shape) { | |||||
| GELOGD("Aicpu all_shape kernel need to update io shape."); | |||||
| for (uint32_t i = 0; i < num_inputs; i++) { | |||||
| auto input_desc = op_desc_->MutableInputDesc(i); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| GE_CHK_STATUS_RET(ext_handle->UpdateInputShapeAndType(i, *input_desc), | |||||
| "Input[%u] update input shape failed.", i); | |||||
| } | |||||
| for (uint32_t j = 0; j < num_outputs; j++) { | |||||
| auto output_desc = op_desc_->MutableOutputDesc(j); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| GE_CHK_STATUS_RET(ext_handle->UpdateOutputShapeAndType(j, *output_desc), | |||||
| "Output[%u] update output shape failed.", j); | |||||
| } | |||||
| } | |||||
| auto rt_ret = rtMalloc(&aicpu_ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); | auto rt_ret = rtMalloc(&aicpu_ext_info_addr_, ext_handle->GetExtInfoLen(), RT_MEMORY_HBM); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "rtMalloc ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); | GELOGE(RT_FAILED, "rtMalloc ext_info error: 0x%X, size=%zu", rt_ret, ext_info.size()); | ||||
| @@ -59,7 +59,6 @@ | |||||
| #include "graph/passes/iterator_op_pass.h" | #include "graph/passes/iterator_op_pass.h" | ||||
| #include "graph/passes/link_gen_mask_nodes_pass.h" | #include "graph/passes/link_gen_mask_nodes_pass.h" | ||||
| #include "graph/passes/mark_graph_unknown_status_pass.h" | #include "graph/passes/mark_graph_unknown_status_pass.h" | ||||
| #include "graph/passes/dynamic_single_op_reset_shape_pass.h" | |||||
| #include "graph/passes/merge_pass.h" | #include "graph/passes/merge_pass.h" | ||||
| #include "graph/passes/merge_input_memcpy_pass.h" | #include "graph/passes/merge_input_memcpy_pass.h" | ||||
| #include "graph/passes/merge_to_stream_merge_pass.h" | #include "graph/passes/merge_to_stream_merge_pass.h" | ||||
| @@ -643,22 +642,11 @@ Status GraphManager::ReplaceSubgraphWithOriGraph(const ComputeGraphPtr &compute_ | |||||
| Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) { | Status GraphManager::SetSubgraph(uint64_t session_id, ComputeGraphPtr compute_graph, GraphPartitioner &partitioner) { | ||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| PassManager pass_for_dynamic_shape_reset_optimize; | |||||
| GE_CHK_STATUS_RET(pass_for_dynamic_shape_reset_optimize.AddPass( | |||||
| "SetSubgraph::AfterSetSubgraph::DynamicSingleOpResetShapePass", new (std::nothrow) DynamicSingleOpResetShapePass)) | |||||
| GE_TIMESTAMP_START(pass_for_dynamic_shape_reset_optimize); | |||||
| Status ret = pass_for_dynamic_shape_reset_optimize.Run(compute_graph); | |||||
| GE_TIMESTAMP_END(pass_for_dynamic_shape_reset_optimize, "SetSubgraph::AfterSetSubgraph"); | |||||
| if (ret != SUCCESS && ret != NOT_CHANGED) { | |||||
| GELOGE(ret, "Run passes when optimize subgraph failed"); | |||||
| return ret; | |||||
| } | |||||
| auto sub_graph_map = partitioner.GetSubGraphMap(); | auto sub_graph_map = partitioner.GetSubGraphMap(); | ||||
| GELOGD("Directly optimize subgraph with build mode:%s, and step:%s.", | GELOGD("Directly optimize subgraph with build mode:%s, and step:%s.", | ||||
| options_.build_mode.c_str(), | options_.build_mode.c_str(), | ||||
| options_.build_step.c_str()); | options_.build_step.c_str()); | ||||
| ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| Status ret = OptimizeSubGraphWithMultiThreads(compute_graph, sub_graph_map, session_id); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Multiply optimize subgraph failed"); | GELOGE(ret, "Multiply optimize subgraph failed"); | ||||
| return ret; | return ret; | ||||
| @@ -3032,6 +3020,7 @@ Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGra | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); | GE_TIMESTAMP_EVENT_END(GraphPartitionDynamicShape, "OptimizeSubgraph::GraphPartitionDynamicShape"); | ||||
| GE_DUMP(compute_graph, "AfterDynamicShapePartition"); | |||||
| GE_TIMESTAMP_START(GraphPartition); | GE_TIMESTAMP_START(GraphPartition); | ||||
| GraphPartitioner &partitioner = GetCompilerStages(graph_node->GetGraphId()).partitioner; | GraphPartitioner &partitioner = GetCompilerStages(graph_node->GetGraphId()).partitioner; | ||||
| ret = partitioner.Partition(compute_graph, GraphPartitioner::kPartitioning); | ret = partitioner.Partition(compute_graph, GraphPartitioner::kPartitioning); | ||||
| @@ -84,15 +84,14 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType | |||||
| int32_t size = 0; | int32_t size = 0; | ||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!"); | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!"); | ||||
| if (op_desc->GetType() == HCOMRECEIVE) { | if (op_desc->GetType() == HCOMRECEIVE) { | ||||
| vector<int64_t> shape_dims; | |||||
| bool ret = ge::AttrUtils::GetListInt(op_desc, HCOM_ATTR_SHAPE, shape_dims); | |||||
| if (ret == false) { | |||||
| GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: shape."); | |||||
| return PARAM_INVALID; | |||||
| for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | |||||
| int64_t output_size = 0; | |||||
| GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); | |||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), output_size), | |||||
| "Get size from TensorDesc failed, op: %s, output index: %zu.", op_desc->GetName().c_str(), i); | |||||
| output_size = (output_size + align_size - 1) / align_size * align_size; | |||||
| total_size += output_size; | |||||
| } | } | ||||
| ge::GeShape shape = ge::GeShape(shape_dims); | |||||
| int64_t input_size = shape.GetShapeSize() * size; | |||||
| total_size = (input_size + align_size - 1) / align_size * align_size; | |||||
| } else { | } else { | ||||
| for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | ||||
| int64_t input_size = 0; | int64_t input_size = 0; | ||||
| @@ -742,6 +742,12 @@ Status GraphOptimize::HandleMemoryRWConflict(ComputeGraphPtr &compute_graph) { | |||||
| if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | if (node->GetType() == NETOUTPUT && AttrUtils::HasAttr(node->GetOpDesc(), ATTR_NAME_PARENT_NODE_INDEX)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| bool identity_reserved = false; | |||||
| AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_CANNOT_BE_DELETED, identity_reserved); | |||||
| if (identity_reserved) { | |||||
| GELOGD("Identity [%s] need to be reserved", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) { | if (node->GetType() == IDENTITY || node->GetType() == READVARIABLEOP) { | ||||
| // split identity | // split identity | ||||
| ret = SplitIdentity(node); | ret = SplitIdentity(node); | ||||
| @@ -607,6 +607,9 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto &engine_name = graph_info_.partitions_.at(sub_graph); | auto &engine_name = graph_info_.partitions_.at(sub_graph); | ||||
| (void)AttrUtils::SetStr(sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | |||||
| GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), | |||||
| compute_graph->GetName().c_str()); | |||||
| GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); | GE_DUMP(sub_graph, sub_graph->GetName() + "_" + mode_2_str_[graph_info_.mode_]); | ||||
| if (!session_graph_id.empty()) { | if (!session_graph_id.empty()) { | ||||
| GE_IF_BOOL_EXEC(!AttrUtils::SetStr(sub_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), | GE_IF_BOOL_EXEC(!AttrUtils::SetStr(sub_graph, ATTR_NAME_SESSION_GRAPH_ID, session_graph_id), | ||||
| @@ -614,9 +617,6 @@ Status ge::GraphPartitioner::AddPartitionsToGraphNode(vector<ge::SubGraphInfoPtr | |||||
| } | } | ||||
| // flush parent node of subgraph | // flush parent node of subgraph | ||||
| sub_graph->SetParentNode(compute_graph->GetParentNode()); | sub_graph->SetParentNode(compute_graph->GetParentNode()); | ||||
| (void)AttrUtils::SetStr(*sub_graph, ATTR_NAME_PARENT_GRAPH_NAME, compute_graph->GetName()); | |||||
| GELOGD("set attr success. subgraph(%s) with parent graph(%s)", sub_graph->GetName().c_str(), | |||||
| compute_graph->GetName().c_str()); | |||||
| auto sgi = MakeShared<SubGraphInfo>(); | auto sgi = MakeShared<SubGraphInfo>(); | ||||
| if (sgi == nullptr) { | if (sgi == nullptr) { | ||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "[GraphPartitioner]: MakeShared sub graph info failed."); | ||||
| @@ -805,8 +805,19 @@ Status ge::GraphPartitioner::SplitSubGraphs(ge::ComputeGraphPtr compute_graph) { | |||||
| GELOGD("In anchor index is %d", AnchorUtils::GetIdx(in_anchor)); | GELOGD("In anchor index is %d", AnchorUtils::GetIdx(in_anchor)); | ||||
| for (auto &peer_out_anchor : in_anchor->GetPeerAnchors()) { | for (auto &peer_out_anchor : in_anchor->GetPeerAnchors()) { | ||||
| GELOGD("Peer out anchor index is %d", AnchorUtils::GetIdx(peer_out_anchor)); | GELOGD("Peer out anchor index is %d", AnchorUtils::GetIdx(peer_out_anchor)); | ||||
| // All nodes have a copy in corresponding_node_in_partitions_, so function at can not be execption | |||||
| auto parent_node = graph_info_.corresponding_node_in_partitions_.at(peer_out_anchor->GetOwnerNode()); | |||||
| // Normally every node has a copy in corresponding_node_in_partitions_; use find() and report an error instead of letting at() throw | |||||
| auto iter = graph_info_.corresponding_node_in_partitions_.find(peer_out_anchor->GetOwnerNode()); | |||||
| if (iter == graph_info_.corresponding_node_in_partitions_.end()) { | |||||
| GELOGE(GRAPH_FAILED, | |||||
| "[SpiltSubGraphs]: node[%s]id[%ld]'s parent_node[%s]id[%ld]" | |||||
| "should make corresponding in advance", | |||||
| node->GetOpDesc()->GetName().c_str(), node->GetOpDesc()->GetId(), | |||||
| peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetName().c_str(), | |||||
| peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetId()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| auto parent_node = iter->second; | |||||
| GE_CHECK_NOTNULL(parent_node); | |||||
| GELOGD("Parent node name is %s", parent_node->GetName().c_str()); | GELOGD("Parent node name is %s", parent_node->GetName().c_str()); | ||||
| // add edge | // add edge | ||||
| auto src_anchor = parent_node->GetOutAnchor(AnchorUtils::GetIdx(peer_out_anchor)); | auto src_anchor = parent_node->GetOutAnchor(AnchorUtils::GetIdx(peer_out_anchor)); | ||||
| @@ -52,6 +52,7 @@ Status StagePartitioner::Partition() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GE_DUMP(root_graph_, "BeforeStagePartition"); | |||||
| if (SplitStageLevel() != SUCCESS) { | if (SplitStageLevel() != SUCCESS) { | ||||
| GELOGE(FAILED, "Split graph-stage for graph %s failed.", root_graph_->GetName().c_str()); | GELOGE(FAILED, "Split graph-stage for graph %s failed.", root_graph_->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -74,6 +75,7 @@ Status StagePartitioner::Partition() { | |||||
| "maybe stage_level was not set correctly.", root_graph_->GetName().c_str()); | "maybe stage_level was not set correctly.", root_graph_->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_DUMP(root_graph_, "AfterStagePartition"); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -26,9 +26,9 @@ namespace { | |||||
| namespace ge { | namespace ge { | ||||
| Status CondPass::Run(NodePtr &node) { | Status CondPass::Run(NodePtr &node) { | ||||
| ComputeGraphPtr graph = nullptr; | ComputeGraphPtr graph = nullptr; | ||||
| OutDataAnchorPtr cond_out_anchor = nullptr; | |||||
| OutDataAnchorPtr peer_out_anchor = nullptr; | |||||
| InDataAnchorPtr cond_in_anchor = nullptr; | InDataAnchorPtr cond_in_anchor = nullptr; | ||||
| Status ret = GetCondInfo(node, graph, cond_out_anchor, cond_in_anchor); | |||||
| Status ret = GetCondInfo(node, graph, peer_out_anchor, cond_in_anchor); | |||||
| if (ret == NOT_CHANGED) { | if (ret == NOT_CHANGED) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } else if (ret != SUCCESS) { | } else if (ret != SUCCESS) { | ||||
| @@ -48,18 +48,18 @@ Status CondPass::Run(NodePtr &node) { | |||||
| if (cond_tensor.MutableShape().GetDim(0) == UNKNOWN_DIM_NUM) { | if (cond_tensor.MutableShape().GetDim(0) == UNKNOWN_DIM_NUM) { | ||||
| GELOGI("Output tensor rank of Cond is unknown."); | GELOGI("Output tensor rank of Cond is unknown."); | ||||
| if (cond_tensor.GetDataType() == DT_STRING) { | if (cond_tensor.GetDataType() == DT_STRING) { | ||||
| GE_CHK_STATUS_RET(HandleStringCond(graph, cond_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", | |||||
| GE_CHK_STATUS_RET(HandleStringCond(graph, peer_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", | |||||
| op_desc->GetName().c_str()) | op_desc->GetName().c_str()) | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (!cond_tensor.GetShape().IsScalar()) { | if (!cond_tensor.GetShape().IsScalar()) { | ||||
| GE_CHK_STATUS_RET(HandleNonScalarCond(graph, cond_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.", | |||||
| GE_CHK_STATUS_RET(HandleNonScalarCond(graph, peer_out_anchor, cond_in_anchor), "HandleNonScalarCond for %s failed.", | |||||
| op_desc->GetName().c_str()) | op_desc->GetName().c_str()) | ||||
| } else { | } else { | ||||
| switch (cond_tensor.GetDataType()) { | switch (cond_tensor.GetDataType()) { | ||||
| case DT_STRING: | case DT_STRING: | ||||
| GE_CHK_STATUS_RET(HandleStringCond(graph, cond_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", | |||||
| GE_CHK_STATUS_RET(HandleStringCond(graph, peer_out_anchor, cond_in_anchor), "HandleStringCond for %s failed.", | |||||
| op_desc->GetName().c_str()) | op_desc->GetName().c_str()) | ||||
| break; | break; | ||||
| case DT_BOOL: | case DT_BOOL: | ||||
| @@ -69,7 +69,7 @@ Status CondPass::Run(NodePtr &node) { | |||||
| case DT_INT16: | case DT_INT16: | ||||
| case DT_INT8: | case DT_INT8: | ||||
| case DT_INT64: | case DT_INT64: | ||||
| GE_CHK_STATUS_RET(HandleScalarCond(graph, cond_out_anchor, cond_in_anchor, cond_tensor.GetDataType()), | |||||
| GE_CHK_STATUS_RET(HandleScalarCond(graph, peer_out_anchor, cond_in_anchor, cond_tensor.GetDataType()), | |||||
| "HandleScalarCond for %s failed.", op_desc->GetName().c_str()) | "HandleScalarCond for %s failed.", op_desc->GetName().c_str()) | ||||
| break; | break; | ||||
| case DT_INT32: | case DT_INT32: | ||||
| @@ -96,21 +96,21 @@ Status CondPass::Run(NodePtr &node) { | |||||
| /// @brief Get cond info for if / while | /// @brief Get cond info for if / while | ||||
| /// @param [in] node: If / While op | /// @param [in] node: If / While op | ||||
| /// @param [out] graph: owner_graph of if node / while_cond subgraph | /// @param [out] graph: owner_graph of if node / while_cond subgraph | ||||
| /// @param [out] cond_out_anchor: peer_cond_anchor | |||||
| /// @param [out] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [out] cond_in_anchor: cond_input | /// @param [out] cond_in_anchor: cond_input | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, | |||||
| Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor) { | InDataAnchorPtr &cond_in_anchor) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| std::string type = node->GetType(); | std::string type = node->GetType(); | ||||
| if (kIfOpTypes.count(type) != 0) { | if (kIfOpTypes.count(type) != 0) { | ||||
| if (GetCondInfoForIf(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | |||||
| if (GetCondInfoForIf(node, graph, peer_out_anchor, cond_in_anchor) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get cond_info for if node failed."); | GELOGE(FAILED, "Get cond_info for if node failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| } else if (kWhileOpTypes.count(type) != 0) { | } else if (kWhileOpTypes.count(type) != 0) { | ||||
| if (GetCondInfoForWhile(node, graph, cond_out_anchor, cond_in_anchor) != SUCCESS) { | |||||
| if (GetCondInfoForWhile(node, graph, peer_out_anchor, cond_in_anchor) != SUCCESS) { | |||||
| GELOGE(FAILED, "Get cond_info for while node failed."); | GELOGE(FAILED, "Get cond_info for while node failed."); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -126,19 +126,19 @@ Status CondPass::GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDat | |||||
| /// @brief Get cond info for if node | /// @brief Get cond info for if node | ||||
| /// @param [in] node: If op | /// @param [in] node: If op | ||||
| /// @param [out] graph: owner_graph of if node | /// @param [out] graph: owner_graph of if node | ||||
| /// @param [out] cond_out_anchor: peer_cond_anchor | |||||
| /// @param [out] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [out] cond_in_anchor: cond_input of if | /// @param [out] cond_in_anchor: cond_input of if | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, | |||||
| Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor) { | InDataAnchorPtr &cond_in_anchor) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| graph = node->GetOwnerComputeGraph(); | graph = node->GetOwnerComputeGraph(); | ||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT); | cond_in_anchor = node->GetInDataAnchor(IF_COND_INPUT); | ||||
| GE_CHECK_NOTNULL(cond_in_anchor); | GE_CHECK_NOTNULL(cond_in_anchor); | ||||
| cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(cond_out_anchor); | |||||
| peer_out_anchor = cond_in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -146,11 +146,11 @@ Status CondPass::GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, O | |||||
| /// @brief Get cond info for while node | /// @brief Get cond info for while node | ||||
| /// @param [in] node: While op | /// @param [in] node: While op | ||||
| /// @param [out] graph: while_cond subgraph | /// @param [out] graph: while_cond subgraph | ||||
| /// @param [out] cond_out_anchor: peer_cond_anchor | |||||
| /// @param [out] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [out] cond_in_anchor: input of NetOutput in cond_graph | /// @param [out] cond_in_anchor: input of NetOutput in cond_graph | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, | |||||
| Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor) { | InDataAnchorPtr &cond_in_anchor) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| @@ -177,8 +177,8 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph | |||||
| cond_in_anchor = net_output_node->GetInDataAnchor(0); | cond_in_anchor = net_output_node->GetInDataAnchor(0); | ||||
| GE_CHECK_NOTNULL(cond_in_anchor); | GE_CHECK_NOTNULL(cond_in_anchor); | ||||
| cond_out_anchor = cond_in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(cond_out_anchor); | |||||
| peer_out_anchor = cond_in_anchor->GetPeerOutAnchor(); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -186,56 +186,56 @@ Status CondPass::GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph | |||||
| /// | /// | ||||
| /// @brief Process Cond Op with non-scalar cond_input: cond->Size->If / NetOutput(while) | /// @brief Process Cond Op with non-scalar cond_input: cond->Size->If / NetOutput(while) | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor: peer_cond_anchor | |||||
| /// @param [in] in_anchor: cond_input | |||||
| /// @param [in] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [in] cond_in_anchor: cond_input | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor) { | |||||
| Status CondPass::HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &cond_in_anchor) { | |||||
| GELOGI("Handle cond with non-scalar cond-input."); | GELOGI("Handle cond with non-scalar cond-input."); | ||||
| return InsertNode(graph, out_anchor, in_anchor, SIZE); | |||||
| return InsertNode(graph, peer_out_anchor, cond_in_anchor, SIZE); | |||||
| } | } | ||||
| /// | /// | ||||
| /// @brief Process Cond Op with scalar-string cond_input: cond->StringLength(int32)->If / NetOutput(while) | /// @brief Process Cond Op with scalar-string cond_input: cond->StringLength(int32)->If / NetOutput(while) | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor: peer_cond_anchor | |||||
| /// @param [in] in_anchor: cond_input | |||||
| /// @param [in] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [in] cond_in_anchor: cond_input | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor) { | |||||
| Status CondPass::HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &cond_in_anchor) { | |||||
| GELOGI("Handle cond with scalar-string cond-input."); | GELOGI("Handle cond with scalar-string cond-input."); | ||||
| return InsertNode(graph, out_anchor, in_anchor, kStringLength); | |||||
| return InsertNode(graph, peer_out_anchor, cond_in_anchor, kStringLength); | |||||
| } | } | ||||
| /// | /// | ||||
| /// @brief Process Cond Op with scalar cond_input: cond->Cast(2int32)->If / NetOutput(while) | /// @brief Process Cond Op with scalar cond_input: cond->Cast(2int32)->If / NetOutput(while) | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor: peer_cond_anchor | |||||
| /// @param [in] in_anchor: cond_input | |||||
| /// @param [in] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [in] cond_in_anchor: cond_input | |||||
| /// @param [in] src_type | /// @param [in] src_type | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor, DataType src_type) { | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| GE_CHECK_NOTNULL(out_anchor); | |||||
| GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); | |||||
| Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &cond_in_anchor, DataType src_type) { | |||||
| GE_CHECK_NOTNULL(cond_in_anchor); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor->GetOwnerNode()->GetOpDesc()); | |||||
| GELOGI("Handle cond with scalar cond-input."); | GELOGI("Handle cond with scalar cond-input."); | ||||
| GeTensorDesc tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); | |||||
| std::string cast_name = in_anchor->GetOwnerNode()->GetName() + "_Cast"; | |||||
| GeTensorDesc tensor = peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(peer_out_anchor->GetIdx()); | |||||
| std::string cast_name = cond_in_anchor->GetOwnerNode()->GetName() + "_Cast"; | |||||
| NodePtr cast_node = AddCastNode(graph, cast_name, tensor, src_type, DT_INT32); | NodePtr cast_node = AddCastNode(graph, cast_name, tensor, src_type, DT_INT32); | ||||
| if (cast_node == nullptr) { | if (cast_node == nullptr) { | ||||
| GELOGE(FAILED, "Add Cast node failed, name:%s.", cast_name.c_str()); | GELOGE(FAILED, "Add Cast node failed, name:%s.", cast_name.c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| if (GraphUtils::InsertNodeAfter(out_anchor, { in_anchor }, cast_node) != GRAPH_SUCCESS) { | |||||
| if (GraphUtils::InsertNodeAfter(peer_out_anchor, { cond_in_anchor }, cast_node) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", | GELOGE(FAILED, "Insert Cast node %s between %s->%s failed.", | ||||
| cast_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| cast_node->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| cond_in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -245,27 +245,27 @@ Status CondPass::HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnc | |||||
| /// | /// | ||||
| /// @brief Insert node | /// @brief Insert node | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor | |||||
| /// @param [in] in_anchor | |||||
| /// @param [in] peer_out_anchor | |||||
| /// @param [in] in_data_anchor | |||||
| /// @param [in] type | /// @param [in] type | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor, const std::string &type) { | |||||
| GE_CHECK_NOTNULL(out_anchor); | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &in_data_anchor, const std::string &type) { | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | |||||
| GE_CHECK_NOTNULL(in_data_anchor); | |||||
| GELOGD("Begin to insert %s node.", type.c_str()); | GELOGD("Begin to insert %s node.", type.c_str()); | ||||
| GE_CHECK_NOTNULL(out_anchor->GetOwnerNode()->GetOpDesc()); | |||||
| GE_CHECK_NOTNULL(in_anchor->GetOwnerNode()->GetOpDesc()); | |||||
| GeTensorDesc in_tensor = out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx()); | |||||
| GeTensorDesc out_tensor = in_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(out_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(peer_out_anchor->GetOwnerNode()->GetOpDesc()); | |||||
| GE_CHECK_NOTNULL(in_data_anchor->GetOwnerNode()->GetOpDesc()); | |||||
| GeTensorDesc in_tensor = peer_out_anchor->GetOwnerNode()->GetOpDesc()->GetOutputDesc(peer_out_anchor->GetIdx()); | |||||
| GeTensorDesc out_tensor = in_data_anchor->GetOwnerNode()->GetOpDesc()->GetInputDesc(in_data_anchor->GetIdx()); | |||||
| out_tensor.SetDataType(DT_INT32); | out_tensor.SetDataType(DT_INT32); | ||||
| out_tensor.SetOriginDataType(DT_INT32); | out_tensor.SetOriginDataType(DT_INT32); | ||||
| out_tensor.SetShape(in_tensor.GetShape()); | out_tensor.SetShape(in_tensor.GetShape()); | ||||
| out_tensor.SetOriginShape(in_tensor.GetOriginShape()); | out_tensor.SetOriginShape(in_tensor.GetOriginShape()); | ||||
| OpDescBuilder op_desc_builder(in_anchor->GetOwnerNode()->GetName() + "_" + type, type); | |||||
| OpDescBuilder op_desc_builder(in_data_anchor->GetOwnerNode()->GetName() + "_" + type, type); | |||||
| OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); | OpDescPtr op_desc = op_desc_builder.AddInput("x", in_tensor).AddOutput("y", out_tensor).Build(); | ||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| GELOGE(FAILED, "Create op_desc failed."); | GELOGE(FAILED, "Create op_desc failed."); | ||||
| @@ -278,10 +278,10 @@ Status CondPass::InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr | |||||
| } | } | ||||
| AddRePassNode(new_node); | AddRePassNode(new_node); | ||||
| if (GraphUtils::InsertNodeAfter(out_anchor, { in_anchor }, new_node) != GRAPH_SUCCESS) { | |||||
| if (GraphUtils::InsertNodeAfter(peer_out_anchor, { in_data_anchor }, new_node) != GRAPH_SUCCESS) { | |||||
| GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), | GELOGE(FAILED, "Insert %s node %s between %s->%s failed.", type.c_str(), | ||||
| new_node->GetName().c_str(), out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| in_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| new_node->GetName().c_str(), peer_out_anchor->GetOwnerNode()->GetName().c_str(), | |||||
| in_data_anchor->GetOwnerNode()->GetName().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -28,76 +28,76 @@ class CondPass : public BaseNodePass { | |||||
| /// @brief Get cond info for if / while | /// @brief Get cond info for if / while | ||||
| /// @param [in] node: If / While op | /// @param [in] node: If / While op | ||||
| /// @param [out] graph: owner_graph of if node / while_cond subgraph | /// @param [out] graph: owner_graph of if node / while_cond subgraph | ||||
| /// @param [out] cond_out_anchor: peer_cond_anchor | |||||
| /// @param [out] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [out] cond_in_anchor: cond_input | /// @param [out] cond_in_anchor: cond_input | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| static Status GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor); | |||||
| static Status GetCondInfo(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor); | |||||
| /// | /// | ||||
| /// @brief Get cond info for if node | /// @brief Get cond info for if node | ||||
| /// @param [in] node: If op | /// @param [in] node: If op | ||||
| /// @param [out] graph: owner_graph of if node | /// @param [out] graph: owner_graph of if node | ||||
| /// @param [out] cond_out_anchor: peer_cond_anchor | |||||
| /// @param [out] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [out] cond_in_anchor: cond_input of if | /// @param [out] cond_in_anchor: cond_input of if | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| static Status GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor); | |||||
| static Status GetCondInfoForIf(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor); | |||||
| /// | /// | ||||
| /// @brief Get cond info for while node | /// @brief Get cond info for while node | ||||
| /// @param [in] node: While op | /// @param [in] node: While op | ||||
| /// @param [out] graph: while_cond subgraph | /// @param [out] graph: while_cond subgraph | ||||
| /// @param [out] cond_out_anchor: peer_cond_anchor | |||||
| /// @param [out] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [out] cond_in_anchor: input of NetOutput in cond_graph | /// @param [out] cond_in_anchor: input of NetOutput in cond_graph | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| static Status GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &cond_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor); | |||||
| static Status GetCondInfoForWhile(const NodePtr &node, ComputeGraphPtr &graph, OutDataAnchorPtr &peer_out_anchor, | |||||
| InDataAnchorPtr &cond_in_anchor); | |||||
| /// | /// | ||||
| /// @brief Process Cond Op with non-scalar cond_input | /// @brief Process Cond Op with non-scalar cond_input | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor: peer_cond_anchor | |||||
| /// @param [in] in_anchor: cond_input | |||||
| /// @param [in] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [in] cond_in_anchor: cond_input | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor); | |||||
| Status HandleNonScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &cond_in_anchor); | |||||
| /// | /// | ||||
| /// @brief Process Cond Op with scalar-string cond_input | /// @brief Process Cond Op with scalar-string cond_input | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor: peer_cond_anchor | |||||
| /// @param [in] in_anchor: cond_input | |||||
| /// @param [in] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [in] cond_in_anchor: cond_input | |||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor); | |||||
| Status HandleStringCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &cond_in_anchor); | |||||
| /// | /// | ||||
| /// @brief Process Cond Op with scalar cond_input | /// @brief Process Cond Op with scalar cond_input | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor: peer_cond_anchor | |||||
| /// @param [in] in_anchor: cond_input | |||||
| /// @param [in] peer_out_anchor: peer_cond_anchor | |||||
| /// @param [in] cond_in_anchor: cond_input | |||||
| /// @param [in] src_type | /// @param [in] src_type | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor, DataType src_type); | |||||
| Status HandleScalarCond(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &cond_in_anchor, DataType src_type); | |||||
| /// | /// | ||||
| /// @brief Insert node | /// @brief Insert node | ||||
| /// @param [in] graph | /// @param [in] graph | ||||
| /// @param [in] out_anchor | |||||
| /// @param [in] in_anchor | |||||
| /// @param [in] peer_out_anchor | |||||
| /// @param [in] in_data_anchor | |||||
| /// @param [in] type | /// @param [in] type | ||||
| /// @return Status | /// @return Status | ||||
| /// | /// | ||||
| Status InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &out_anchor, | |||||
| const InDataAnchorPtr &in_anchor, const std::string &type); | |||||
| Status InsertNode(const ComputeGraphPtr &graph, const OutDataAnchorPtr &peer_out_anchor, | |||||
| const InDataAnchorPtr &in_data_anchor, const std::string &type); | |||||
| /// | /// | ||||
| /// @brief Add cast node | /// @brief Add cast node | ||||
| @@ -1,155 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/dynamic_single_op_reset_shape_pass.h" | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/tensor_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| const int64_t kDynamicShapeDim = -2; | |||||
| const char *const kEngineNameAiCpu = "DNN_VM_AICPU_ASCEND"; | |||||
| const char *const kEngineNameAiCpuTf = "DNN_VM_AICPU"; | |||||
| } // namespace | |||||
| Status DynamicSingleOpResetShapePass::Run(ComputeGraphPtr graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| std::shared_ptr<GELib> instance = ge::GELib::GetInstance(); | |||||
| if (instance == nullptr || !instance->InitFlag()) { | |||||
| GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Run CompileNodesPass failed."); | |||||
| return ge::GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| // pass if graph has not aicpu node. | |||||
| bool is_not_aicpu = false; | |||||
| if (CheckAllAicpuNodes(graph, is_not_aicpu) != SUCCESS) { | |||||
| GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Check if graph has not aicpu node failed."); | |||||
| return ge::GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| if (is_not_aicpu) { | |||||
| GELOGI("The graph [%s] has not aicpu node, whose aicpu nodes would not be reset dynamic shape", | |||||
| graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| // pass input and output node | |||||
| if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP || | |||||
| node->GetType() == NETOUTPUT) { | |||||
| continue; | |||||
| } | |||||
| // pass node without attr: ATTR_SINGLE_OP_SCENE | |||||
| bool single_aicpu_unknown = false; | |||||
| if (!AttrUtils::GetBool(node->GetOpDesc(), ATTR_SINGLE_OP_SCENE, single_aicpu_unknown) || | |||||
| !single_aicpu_unknown) { | |||||
| continue; | |||||
| } | |||||
| // reset aicpu shape to unknown shape | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| if (ResetOpShape(op_desc) != SUCCESS) { | |||||
| GELOGE(ge::GE_CLI_GE_NOT_INITIALIZED, "Reset node[%s] dynamic shapr failed.", node->GetName().c_str()); | |||||
| return ge::GE_CLI_GE_NOT_INITIALIZED; | |||||
| } | |||||
| GELOGD("Reset dynamic aicpu node [%s] shape success!", node->GetName().c_str()); | |||||
| } | |||||
| GELOGD("Reset dynamic aicpu nodes shape of graph [%s] success!", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DynamicSingleOpResetShapePass::CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu) { | |||||
| is_not_aicpu = false; | |||||
| for (const auto &node : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| // pass input and output node | |||||
| if (node->GetType() == DATA || node->GetType() == CONSTANT || node->GetType() == CONSTANTOP || | |||||
| node->GetType() == NETOUTPUT) { | |||||
| continue; | |||||
| } | |||||
| // find if there are aicpu nodes. | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| string engine_name = op_desc->GetOpEngineName(); | |||||
| if (engine_name.empty()) { | |||||
| GELOGE(GRAPH_FAILED, "Get engine failed of node[%s].", node->GetName().c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| if (engine_name != kEngineNameAiCpu && engine_name != kEngineNameAiCpuTf) { | |||||
| is_not_aicpu = true; | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool DynamicSingleOpResetShapePass::CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc) { | |||||
| bool is_const = false; | |||||
| (void)AttrUtils::GetBool(input_tensor_desc, CONST_ATTR_NAME_INPUT, is_const); | |||||
| return is_const; | |||||
| } | |||||
| Status DynamicSingleOpResetShapePass::ResetOpShape(OpDescPtr &op_desc) { | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| std::vector<int64_t> dynamic_shape_dims = {kDynamicShapeDim}; | |||||
| GeShape dynamic_shape(dynamic_shape_dims); | |||||
| (void)ResetInputTensorShape(op_desc, dynamic_shape); | |||||
| (void)ResetOutputTensorShape(op_desc, dynamic_shape); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DynamicSingleOpResetShapePass::ResetInputTensorShape(OpDescPtr &op_desc, | |||||
| const GeShape &dynamic_shape) { | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (size_t i = 0; i < op_desc->GetAllInputsDesc().size(); i++) { | |||||
| auto input_desc = op_desc->MutableInputDesc(static_cast<uint32_t>(i)); | |||||
| GE_CHECK_NOTNULL(input_desc); | |||||
| // pass scalar input desc | |||||
| auto dims_ori = input_desc->GetShape().GetDims(); | |||||
| if (dims_ori.size() == 0) { | |||||
| continue; | |||||
| } | |||||
| // pass const input | |||||
| if (CheckIfConstInput(input_desc)) { | |||||
| continue; | |||||
| } | |||||
| input_desc->SetShape(dynamic_shape); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DynamicSingleOpResetShapePass::ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape) { | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (size_t i = 0; i < op_desc->GetAllOutputsDesc().size(); i++) { | |||||
| auto output_desc = op_desc->MutableOutputDesc(static_cast<uint32_t>(i)); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| // pass scalar input desc | |||||
| auto output_dims_ori = output_desc->GetShape().GetDims(); | |||||
| if (output_dims_ori.size() == 0) { | |||||
| continue; | |||||
| } | |||||
| output_desc->SetShape(dynamic_shape); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,36 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ | |||||
| #include "graph/graph.h" | |||||
| #include "inc/graph_pass.h" | |||||
| #include "init/gelib.h" | |||||
| namespace ge { | |||||
| class DynamicSingleOpResetShapePass : public GraphPass { | |||||
| public: | |||||
| Status Run(ComputeGraphPtr graph) override; | |||||
| private: | |||||
| Status ResetOpShape(OpDescPtr &op_desc); | |||||
| Status ResetInputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape); | |||||
| Status ResetOutputTensorShape(OpDescPtr &op_desc, const GeShape &dynamic_shape); | |||||
| Status CheckAllAicpuNodes(const ComputeGraphPtr &graph, bool &is_not_aicpu); | |||||
| bool CheckIfConstInput(const GeTensorDescPtr &input_tensor_desc); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_DYNAMIC_SINGLE_OP_RESET_SHAPE_PASS_H_ | |||||
| @@ -17,6 +17,7 @@ | |||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| namespace ge { | namespace ge { | ||||
| const size_t kTwoInputNodesSize = 2; | const size_t kTwoInputNodesSize = 2; | ||||
| @@ -32,53 +33,110 @@ Status MarkAgnosticPass::Run(ComputeGraphPtr graph) { | |||||
| GELOGD("Op: %s, Index:0,has no input", node->GetName().c_str()); | GELOGD("Op: %s, Index:0,has no input", node->GetName().c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| AttrUtils::SetInt(op_tensor, "_format_continuous", 1); | |||||
| AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | |||||
| AttrUtils::SetInt(op_tensor, ATTR_NAME_FORMAT_CONTINUOUS, 1); | |||||
| AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC, 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_INPUT, std::vector<int64_t>({1})); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (node_type == IDENTITY) { | if (node_type == IDENTITY) { | ||||
| GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic for identity node %s", node->GetName().c_str()); | ||||
| AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
| AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC, 1); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (node_type == REFMERGE || node_type == REFSWITCH) { | if (node_type == REFMERGE || node_type == REFSWITCH) { | ||||
| GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic for regmerge and refswitch node %s", node->GetName().c_str()); | ||||
| AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_input", std::vector<int64_t>({1})); | |||||
| AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC, 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_INPUT, std::vector<int64_t>({1})); | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (node_type == MERGE) { | if (node_type == MERGE) { | ||||
| GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | GELOGD("Mark format agnostic and continuous for merge node %s", node->GetName().c_str()); | ||||
| const auto &input_nodes = node->GetInAllNodes(); | |||||
| /// Enter-----------+ | |||||
| /// +-> Merge | |||||
| /// NextIteration---+ | |||||
| if (input_nodes.size() == kTwoInputNodesSize) { | |||||
| if (input_nodes.at(0)->GetType() == ENTER && input_nodes.at(1)->GetType() == NEXTITERATION) { | |||||
| continue; | |||||
| } | |||||
| } | |||||
| const OpDescPtr op_desc = node->GetOpDesc(); | |||||
| const GeTensorDescPtr op_tensor = op_desc->MutableOutputDesc(0); | |||||
| if (op_tensor == nullptr) { | |||||
| GELOGD("Op: %s, Index:0,has no output", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| AttrUtils::SetInt(op_tensor, "_format_continuous", 1); | |||||
| // Merge----------->NetOutput only set format_cofntinuous attr | |||||
| const auto &output_nodes = node->GetOutAllNodes(); | |||||
| if (output_nodes.size() > 0) { | |||||
| // Always set continuous attr for merge output 0 | |||||
| GE_CHK_STATUS_RET(SetContinuousAttr(node, {0})); | |||||
| // Merge-->NetOutput only set merge output 0's continuous attr | |||||
| const auto &output_nodes = node->GetOutDataNodes(); | |||||
| if (!output_nodes.empty()) { | |||||
| if (output_nodes.at(0)->GetType() == NETOUTPUT) { | if (output_nodes.at(0)->GetType() == NETOUTPUT) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| } | } | ||||
| AttrUtils::SetInt(node->GetOpDesc(), "_format_agnostic", 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), "_format_agnostic_except_output", std::vector<int64_t>({1})); | |||||
| // Set format agnostic attr for merge in and out tensordesc | |||||
| AttrUtils::SetInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC, 1); | |||||
| AttrUtils::SetListInt(node->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC_EXCEPT_OUTPUT, std::vector<int64_t>({1})); | |||||
| // Set attr for enter and nextiteration | |||||
| if (HandWhileLoop(node) != SUCCESS) { | |||||
| GELOGE(FAILED, "Node: %s type merge handle while loop failed", node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| continue; | continue; | ||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| bool MarkAgnosticPass::IsWhileLoop(const NodePtr &merge_node, NodePtr &enter, NodePtr &next) { | |||||
| auto node_type = NodeUtils::GetNodeType(*merge_node); | |||||
| if (node_type != MERGE) { | |||||
| GELOGW("Node %s type %s is not merge op.", merge_node->GetName().c_str(), node_type.c_str()); | |||||
| return false; | |||||
| } | |||||
| /// Enter-----------+ | |||||
| /// +-> Merge | |||||
| /// NextIteration---+ | |||||
| auto input_nodes = merge_node->GetInDataNodes(); | |||||
| if (input_nodes.size() != kTwoInputNodesSize) { | |||||
| GELOGD("Node %s type %s with [data input size[%zu]] is not enter-merge-nextiteration target.", | |||||
| merge_node->GetName().c_str(), node_type.c_str(), input_nodes.size()); | |||||
| return false; | |||||
| } | |||||
| auto in_node0 = input_nodes.at(0); | |||||
| auto in_node1 = input_nodes.at(1); | |||||
| auto in_type0 = NodeUtils::GetNodeType(in_node0); | |||||
| auto in_type1 = NodeUtils::GetNodeType(in_node1); | |||||
| if ((in_type0 != ENTER || in_type1 != NEXTITERATION) && (in_type0 != NEXTITERATION || in_type1 != ENTER)) { | |||||
| GELOGD("Node %s type %s with [data input0's type %s input1's type %s] is not enter-merge-nextiteration target.", | |||||
| merge_node->GetName().c_str(), node_type.c_str(), in_type0.c_str(), in_type1.c_str()); | |||||
| return false; | |||||
| } | |||||
| enter = in_node0; | |||||
| next = in_node1; | |||||
| return true; | |||||
| } | |||||
| Status MarkAgnosticPass::HandWhileLoop(const NodePtr &node) { | |||||
| NodePtr enter = nullptr; | |||||
| NodePtr next = nullptr; | |||||
| if (!IsWhileLoop(node, enter, next)) { | |||||
| return SUCCESS; | |||||
| } | |||||
| GE_CHECK_NOTNULL(enter); | |||||
| GE_CHECK_NOTNULL(next); | |||||
| // Set continuous attr | |||||
| GE_CHK_STATUS_RET(SetContinuousAttr(enter, {0})); | |||||
| GE_CHK_STATUS_RET(SetContinuousAttr(next, {0})); | |||||
| // Set format agnostic attr | |||||
| (void)AttrUtils::SetInt(enter->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC, 1); | |||||
| (void)AttrUtils::SetInt(next->GetOpDesc(), ATTR_NAME_FORMAT_AGNOSTIC, 1); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MarkAgnosticPass::SetContinuousAttr(const NodePtr &node, const std::vector<uint32_t> &indexes) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| // This flag is for fe performance optimization | |||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_REFRESH_CONTINUOUS_FLAG, true); | |||||
| for (auto index : indexes) { | |||||
| auto out = op_desc->MutableOutputDesc(index); | |||||
| GE_CHECK_NOTNULL(out); | |||||
| // This attr is for out's dtype and format continuous with it's peer input | |||||
| (void)AttrUtils::SetInt(out, ATTR_NAME_FORMAT_CONTINUOUS, 1); | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -22,6 +22,11 @@ namespace ge { | |||||
// Graph pass that tags nodes with format-agnostic / format-continuous attributes.
| class MarkAgnosticPass : public GraphPass { | class MarkAgnosticPass : public GraphPass { | ||||
| public: | public: | ||||
// GraphPass entry point, invoked with the compute graph to process.
| Status Run(ComputeGraphPtr graph) override; | Status Run(ComputeGraphPtr graph) override; | ||||
| private: | |||||
// Returns true when `node` is a Merge with exactly two data inputs, one Enter and
// one NextIteration; on success `enter` and `next` receive the two input nodes.
| bool IsWhileLoop(const NodePtr& node, NodePtr& enter, NodePtr& next); | |||||
// For a Merge matching IsWhileLoop, marks output 0 of both input nodes as
// continuous and both input nodes as format agnostic; otherwise does nothing.
| Status HandWhileLoop(const NodePtr& node); | |||||
// Sets the refresh-continuous flag on the node's op desc and the
// format-continuous attr on each output desc listed in `index`.
| Status SetContinuousAttr(const NodePtr& node, const std::vector<uint32_t>& index); | |||||
| }; | }; | ||||
| } | } | ||||
| @@ -109,6 +109,7 @@ Status MultiBatchClonePass::Run(ComputeGraphPtr graph) { | |||||
| GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); | GE_CHK_STATUS_RET(CreateSubgraphs(graph, branch), "Construct subgraph failed."); | ||||
| GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); | GE_CHK_STATUS_RET(PruneDirectOutput(graph), "Prune direct output failed"); | ||||
| GE_CHK_STATUS_RET(UpdateSubgraphOutput(), "Update subgraph output failed"); | |||||
| GELOGD("MultiBatchClonePass Leave"); | GELOGD("MultiBatchClonePass Leave"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1057,8 +1058,6 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
| subgraph->SetParentGraph(graph); | subgraph->SetParentGraph(graph); | ||||
| graph->AddSubgraph(subgraph->GetName(), subgraph); | graph->AddSubgraph(subgraph->GetName(), subgraph); | ||||
| all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | all_branch_output_[subgraph] = subgraph->FindFirstNodeMatchType(NETOUTPUT); | ||||
| GE_CHK_STATUS_RET(UpdateSubgraphOutput(all_branch_output_[subgraph]), | |||||
| "Update %s failed", all_branch_output_[subgraph]->GetName().c_str()); | |||||
| const string key_name = "branches" + std::to_string(i); | const string key_name = "branches" + std::to_string(i); | ||||
| op_desc->AddSubgraphName(key_name); | op_desc->AddSubgraphName(key_name); | ||||
| @@ -1085,21 +1084,22 @@ Status MultiBatchClonePass::CreateSubgraphs(const ComputeGraphPtr &graph, const | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Update output_node in Subgraph. | /// @brief Update output_node in Subgraph. | ||||
| /// @param [in] const NodePtr &output_node: output_node in Subgraph. | |||||
| /// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
| /// | /// | ||||
| Status MultiBatchClonePass::UpdateSubgraphOutput(const NodePtr &output_node) { | |||||
| const auto &op_desc = output_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { | |||||
| GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); | |||||
| GE_CHECK_NOTNULL(tensor); | |||||
| if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||||
| GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| Status MultiBatchClonePass::UpdateSubgraphOutput() { | |||||
| for (const auto &item : all_branch_output_) { | |||||
| const auto &output_node = item.second; | |||||
| const auto &op_desc = output_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (size_t index = 0; index < op_desc->GetInputsSize(); ++index) { | |||||
| GeTensorDescPtr tensor = op_desc->MutableInputDesc(index); | |||||
| GE_CHECK_NOTNULL(tensor); | |||||
| if (!AttrUtils::SetInt(tensor, ATTR_NAME_PARENT_NODE_INDEX, index)) { | |||||
| GELOGE(FAILED, "Failed to set parent index for node %s", output_node->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -136,10 +136,9 @@ class MultiBatchClonePass : public GraphPass { | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Update output_node in Subgraph. | /// @brief Update output_node in Subgraph. | ||||
| /// @param [in] const NodePtr &output_node: output_node in Subgraph. | |||||
| /// @return 0: SUCCESS / others: FAILED | /// @return 0: SUCCESS / others: FAILED | ||||
| /// | /// | ||||
| Status UpdateSubgraphOutput(const NodePtr &output_node); | |||||
| Status UpdateSubgraphOutput(); | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| @@ -15,29 +15,50 @@ | |||||
| */ | */ | ||||
| #include "graph/passes/reshape_remove_pass.h" | #include "graph/passes/reshape_remove_pass.h" | ||||
| #include <map> | |||||
| #include <string> | |||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "framework/common/types.h" | |||||
| #include "graph/passes/pass_utils.h" | #include "graph/passes/pass_utils.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const int kReshapeDataIndex = 0; | const int kReshapeDataIndex = 0; | ||||
// Integer keys used by ReshapeRemovePass::Run to switch on the node type.
| enum OpHashValue { | |||||
| kReshapeType = 0, | |||||
| kReformatType = 1, | |||||
// Fallback key used when the node type is not found in kToBeDeleteOp.
| kOpNoDelete = -1 | |||||
| }; | |||||
// Maps the op types this pass may remove to their switch key.
// NOTE(review): the map is non-const and Run reads it with operator[]; making it
// const would require switching that read to find()/at().
| std::map<std::string, OpHashValue> kToBeDeleteOp = { | |||||
| {RESHAPE, kReshapeType}, | |||||
| {REFORMAT, kReformatType} | |||||
| }; | |||||
| } | } | ||||
| Status ReshapeRemovePass::Run(NodePtr &node) { | Status ReshapeRemovePass::Run(NodePtr &node) { | ||||
| GE_CHECK_NOTNULL(node); | GE_CHECK_NOTNULL(node); | ||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | GE_CHECK_NOTNULL(node->GetOpDesc()); | ||||
| if (node->GetType() != RESHAPE && node->GetType() != REFORMAT) { | |||||
| return SUCCESS; | |||||
| } | |||||
| bool is_shape_unknown = false; | |||||
| if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { | |||||
| if (is_shape_unknown) { | |||||
| GELOGI("op:%s is unknown shape, can not be deleted.", | |||||
| node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| int key = kToBeDeleteOp.find(node->GetType()) == kToBeDeleteOp.end() ? kOpNoDelete : kToBeDeleteOp[node->GetType()]; | |||||
| switch(key) { | |||||
| case kReshapeType: { | |||||
| bool is_shape_unknown = false; | |||||
| if (NodeUtils::GetNodeUnknownShapeStatus(*node, is_shape_unknown) == GRAPH_SUCCESS) { | |||||
| if (is_shape_unknown) { | |||||
| GELOGI("op:%s is unknown shape, can not be deleted.", | |||||
| node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| break; | |||||
| } | } | ||||
| case kReformatType: | |||||
| break; | |||||
| default: | |||||
| return SUCCESS; | |||||
| } | } | ||||
| GELOGI("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); | GELOGI("Remove %s node %s", node->GetType().c_str(), node->GetName().c_str()); | ||||
| @@ -460,6 +460,7 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat | |||||
| .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) | .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) | ||||
| .Build(); | .Build(); | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true); | |||||
| if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { | if (GraphUtils::InsertNodeAfter(out_anchor, in_anchors, graph->AddNode(op_desc)) != GRAPH_SUCCESS) { | ||||
| GELOGE(FAILED, "Insert IDENTITY node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); | GELOGE(FAILED, "Insert IDENTITY node %s after %s failed.", name.c_str(), in_node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| @@ -967,6 +967,13 @@ Status ParseDynamicInputShapeRange(const std::string &shape_range, | |||||
| // unknown dim, should get range. | // unknown dim, should get range. | ||||
| auto range_left = StringToLongNoThrow(range_pair_set.at(0).c_str()); | auto range_left = StringToLongNoThrow(range_pair_set.at(0).c_str()); | ||||
| auto range_right = StringToLongNoThrow(range_pair_set.at(1).c_str()); | auto range_right = StringToLongNoThrow(range_pair_set.at(1).c_str()); | ||||
| if (range_left < 0 || range_right < 0) { | |||||
| GELOGE(PARAM_INVALID, | |||||
| "Shape range of input is invalid. Given range pair [%ld,%ld], while correct example: " | |||||
| "\"[1~20,3,3~6,-1],[1~20,3,3~6,-1]\"", | |||||
| range_left, range_right); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| range_pair = std::make_pair(range_left, range_right); | range_pair = std::make_pair(range_left, range_right); | ||||
| } else { | } else { | ||||
| GELOGE(PARAM_INVALID, | GELOGE(PARAM_INVALID, | ||||
| @@ -983,22 +990,31 @@ Status ParseDynamicInputShapeRange(const std::string &shape_range, | |||||
| Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, | Status GetDynamicInputShapeRange(const std::vector<GeTensor> &user_input, const std::map<string, string> &graph_option, | ||||
| vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { | vector<vector<std::pair<int64_t, int64_t>>> &range_vec) { | ||||
| // check both mode and shape_range option are all enabled | |||||
| auto mode_iter = graph_option.find(OPTION_EXEC_DYNAMIC_EXECUTE_MODE); | auto mode_iter = graph_option.find(OPTION_EXEC_DYNAMIC_EXECUTE_MODE); | ||||
| if (mode_iter == graph_option.end()) { | |||||
| GELOGD("Graph Option: Can not find %s option in graph options.", OPTION_EXEC_DYNAMIC_EXECUTE_MODE); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGD("Graph Option: dynamic_input_mode value is %s.", mode_iter->second.c_str()); | |||||
| if (mode_iter->second != "dynamic_execute") { | |||||
| return SUCCESS; | |||||
| bool enable_dynamic_execute_mode = (mode_iter != graph_option.end()) && (mode_iter->second == "dynamic_execute"); | |||||
| if (!enable_dynamic_execute_mode) { | |||||
| GELOGD("Graph Option: Can not find %s option in graph options or option value is empty", | |||||
| OPTION_EXEC_DYNAMIC_EXECUTE_MODE); | |||||
| } | } | ||||
| auto iter = graph_option.find(OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | auto iter = graph_option.find(OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | ||||
| if (iter == graph_option.end()) { | |||||
| GELOGE(PARAM_INVALID, "Graph option %s is required when %s is dynamic_execute", OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE, | |||||
| OPTION_EXEC_DYNAMIC_EXECUTE_MODE); | |||||
| bool enable_input_shape_range = (iter != graph_option.end()) && (!iter->second.empty()); | |||||
| if (!enable_input_shape_range) { | |||||
| GELOGD("Graph Option: Can not find %s option in graph options or option value is empty", | |||||
| OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | |||||
| } | |||||
| if (enable_dynamic_execute_mode && enable_input_shape_range) { | |||||
| GELOGD("GraphOption: %s value is dynamic_execute, %s value is %s.", OPTION_EXEC_DYNAMIC_EXECUTE_MODE, | |||||
| OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE, iter->second.c_str()); | |||||
| } else if (!enable_dynamic_execute_mode && !enable_input_shape_range) { | |||||
| return SUCCESS; | |||||
| } else { | |||||
| GELOGE(PARAM_INVALID, "Graph option: %s and %s should be enabled at the same time.", | |||||
| OPTION_EXEC_DYNAMIC_EXECUTE_MODE, OPTION_EXEC_DATA_INPUTS_SHAPE_RANGE); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| GELOGD("GraphOption: dynamic_inputs_shape_range value is %s.", iter->second.c_str()); | |||||
| auto ret = ParseDynamicInputShapeRange(iter->second, range_vec); | auto ret = ParseDynamicInputShapeRange(iter->second, range_vec); | ||||
| GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed."); | GE_CHK_STATUS_RET(ret, "Parse dynamic input shape range failed."); | ||||
| if (range_vec.size() != user_input.size()) { | if (range_vec.size() != user_input.size()) { | ||||
| @@ -15,6 +15,7 @@ | |||||
| */ | */ | ||||
| #include "hybrid_execution_context.h" | #include "hybrid_execution_context.h" | ||||
| #include <atomic> | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -23,7 +24,14 @@ const uint32_t kEndOfSequence = 0x0704000a; | |||||
| const uint32_t kEndOfSequenceNew = 507005; | const uint32_t kEndOfSequenceNew = 507005; | ||||
| const int32_t kModelAbortNormal = 0x0704000e; | const int32_t kModelAbortNormal = 0x0704000e; | ||||
| const int32_t kModelAbortNormalNew = 507024; | const int32_t kModelAbortNormalNew = 507024; | ||||
| std::atomic_ulong context_id_gen {}; | |||||
| } // namespace | } // namespace | ||||
| GraphExecutionContext::GraphExecutionContext() { | |||||
| context_id = context_id_gen++; | |||||
| } | |||||
| void GraphExecutionContext::SetErrorCode(Status error_code) { | void GraphExecutionContext::SetErrorCode(Status error_code) { | ||||
| std::lock_guard<std::mutex> lk(mu); | std::lock_guard<std::mutex> lk(mu); | ||||
| this->status = error_code; | this->status = error_code; | ||||
| @@ -48,11 +48,15 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| struct GraphExecutionContext { | struct GraphExecutionContext { | ||||
| GraphExecutionContext(); | |||||
| ~GraphExecutionContext() = default; | |||||
| void SetErrorCode(Status error_code); | void SetErrorCode(Status error_code); | ||||
| Status GetStatus() const; | Status GetStatus() const; | ||||
| Status Synchronize(rtStream_t rt_stream); | Status Synchronize(rtStream_t rt_stream); | ||||
| uint64_t session_id = 0; | uint64_t session_id = 0; | ||||
| uint64_t context_id = 0; | |||||
| const HybridModel *model = nullptr; | const HybridModel *model = nullptr; | ||||
| const GEThreadLocalContext *ge_context = nullptr; | const GEThreadLocalContext *ge_context = nullptr; | ||||
| rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
| @@ -67,6 +71,8 @@ struct GraphExecutionContext { | |||||
| std::atomic_bool is_eos_; | std::atomic_bool is_eos_; | ||||
| long profiling_level = 0; | long profiling_level = 0; | ||||
| long iteration = 0; | long iteration = 0; | ||||
| private: | |||||
| Status status = SUCCESS; | Status status = SUCCESS; | ||||
| mutable std::mutex mu; | mutable std::mutex mu; | ||||
| }; | }; | ||||
| @@ -75,7 +81,8 @@ struct GraphExecutionContext { | |||||
| do { \ | do { \ | ||||
| if ((context != nullptr) && (context)->profiler != nullptr) { \ | if ((context != nullptr) && (context)->profiler != nullptr) { \ | ||||
| if (node_name != nullptr) { \ | if (node_name != nullptr) { \ | ||||
| context->profiler->RecordEvent(evt_type, "tid:%lu [%s] [%s] " fmt, GeLog::GetTid(), node_name, category, \ | |||||
| context->profiler->RecordEvent(evt_type, "tid:%lu [%s@%ld] [%s] " fmt, \ | |||||
| GeLog::GetTid(), node_name, context->iteration, category, \ | |||||
| ##__VA_ARGS__); \ | ##__VA_ARGS__); \ | ||||
| } else { \ | } else { \ | ||||
| context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \ | context->profiler->RecordEvent(evt_type, "tid:%lu [%s] " fmt, GeLog::GetTid(), category, ##__VA_ARGS__); \ | ||||
| @@ -25,6 +25,7 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| const int kDataOutputIndex = 0; | const int kDataOutputIndex = 0; | ||||
| const size_t kMinimumPiplineStages = 2; | |||||
| } | } | ||||
| HybridModelAsyncExecutor::HybridModelAsyncExecutor(HybridModel *model) | HybridModelAsyncExecutor::HybridModelAsyncExecutor(HybridModel *model) | ||||
| : model_(model), run_flag_(false) { | : model_(model), run_flag_(false) { | ||||
| @@ -95,7 +96,17 @@ Status HybridModelAsyncExecutor::Init() { | |||||
| executor_ = std::unique_ptr<HybridModelExecutor>(new(std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); | executor_ = std::unique_ptr<HybridModelExecutor>(new(std::nothrow) HybridModelExecutor(model_, device_id_, stream_)); | ||||
| GE_CHECK_NOTNULL(executor_); | GE_CHECK_NOTNULL(executor_); | ||||
| GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine"); | GE_CHK_STATUS_RET(executor_->Init(), "Failed to init hybrid engine"); | ||||
| GELOGI("HybridModel stage nums:%zu", model_->GetRootGraphItem()->NumGroups()); | |||||
| if (model_->GetRootGraphItem()->NumGroups() >= kMinimumPiplineStages) { | |||||
| pipe_executor_ = | |||||
| std::unique_ptr<HybridModelPipelineExecutor>(new(std::nothrow) HybridModelPipelineExecutor(model_, device_id_)); | |||||
| GE_CHECK_NOTNULL(pipe_executor_); | |||||
| GE_CHK_STATUS_RET(pipe_executor_->Init(), "Failed to init hybrid engine"); | |||||
| } | |||||
| GE_CHK_STATUS_RET(InitInputDesc(), "Failed to init input tensors"); | GE_CHK_STATUS_RET(InitInputDesc(), "Failed to init input tensors"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -135,7 +146,18 @@ Status HybridModelAsyncExecutor::RunInternal() { | |||||
| CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); | CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_FMK, JOBSUBSTATE_GRAPH_EXEC); | ||||
| continue, "PreRun failed."); // [No need to check value] | continue, "PreRun failed."); // [No need to check value] | ||||
| ret = executor_->Execute(args); | |||||
| if (pipe_executor_ != nullptr) { | |||||
| GELOGI("HybridModel will execute in pipeline mode"); | |||||
| auto iter_per_run = std::getenv("ITER_NUM"); | |||||
| if (iter_per_run) { | |||||
| args.num_loops = static_cast<int>(strtol(iter_per_run, nullptr, 10)); | |||||
| } | |||||
| ret = pipe_executor_->Execute(args); | |||||
| } else { | |||||
| GELOGI("HybridModel will execute in singleline mode"); | |||||
| ge::GetContext().SetSessionId(executor_->GetContext()->session_id); | |||||
| ret = executor_->Execute(args); | |||||
| } | |||||
| ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); | ret = HandleResult(ret, current_data.index, args, data_wrapper->GetOutput()); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | CsaInteract::GetInstance().StoreInternalErrorCode(ret, ERROR_MODULE_RUNTIME, JOBSUBSTATE_GRAPH_EXEC); | ||||
| @@ -219,7 +241,22 @@ Status HybridModelAsyncExecutor::PrepareInputs(const InputData ¤t_data, Hy | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| auto &tensor_desc = input_tensor_desc_[input_index]; | auto &tensor_desc = input_tensor_desc_[input_index]; | ||||
| tensor_desc->SetShape(GeShape(current_data.shapes[input_index])); | |||||
| GeShape shape(current_data.shapes[input_index]); | |||||
| std::vector<std::pair<int64_t, int64_t>> range; | |||||
| auto range_ret = tensor_desc->GetShapeRange(range); | |||||
| GE_CHK_BOOL_RET_STATUS(range_ret == GRAPH_SUCCESS, INTERNAL_ERROR, | |||||
| "Get shape range failed, ret=%u.", range_ret); | |||||
| for (size_t k = 0; k < range.size(); ++k) { | |||||
| if (k >= shape.GetDimNum()) { | |||||
| break; | |||||
| } | |||||
| if (shape.GetDim(k) < range[k].first || shape.GetDim(k) > range[k].second) { | |||||
| GELOGE(PARAM_INVALID, "Dim out of range, shape idx = %zu, dim idx = %zu, dim = %ld, range = [%ld, %ld]", | |||||
| input_index, k, shape.GetDim(k), range[k].first, range[k].second); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| } | |||||
| tensor_desc->SetShape(shape); | |||||
| args.input_desc[input_index] = tensor_desc; | args.input_desc[input_index] = tensor_desc; | ||||
| GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); | GELOGD("Update shape of input[%zu] to [%s]", input_index, tensor_desc->MutableShape().ToString().c_str()); | ||||
| GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorMemorySizeInBytes(*tensor_desc, tensor_size), | ||||
| @@ -23,6 +23,7 @@ | |||||
| #include "external/ge/ge_api_types.h" | #include "external/ge/ge_api_types.h" | ||||
| #include "graph/load/model_manager/data_inputer.h" | #include "graph/load/model_manager/data_inputer.h" | ||||
| #include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
| #include "hybrid/executor/hybrid_model_pipeline_executor.h" | |||||
| #include "runtime/stream.h" | #include "runtime/stream.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -81,6 +82,7 @@ class HybridModelAsyncExecutor { | |||||
| std::atomic_bool run_flag_; | std::atomic_bool run_flag_; | ||||
| std::unique_ptr<DataInputer> data_inputer_; | std::unique_ptr<DataInputer> data_inputer_; | ||||
| std::unique_ptr<HybridModelExecutor> executor_; | std::unique_ptr<HybridModelExecutor> executor_; | ||||
| std::unique_ptr<HybridModelPipelineExecutor> pipe_executor_; | |||||
| std::future<Status> future_; | std::future<Status> future_; | ||||
| uint64_t iterator_count_ = 0; | uint64_t iterator_count_ = 0; | ||||
| @@ -87,7 +87,7 @@ Status HybridModelExecutor::ExecuteGraphInternal(SubgraphExecutor &executor, | |||||
| Status HybridModelExecutor::Cleanup() { | Status HybridModelExecutor::Cleanup() { | ||||
| GELOGD("Start to cleanup."); | GELOGD("Start to cleanup."); | ||||
| context_.callback_manager->Destroy(); | context_.callback_manager->Destroy(); | ||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.session_id)); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | |||||
| GELOGD("Cleanup successfully."); | GELOGD("Cleanup successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -105,7 +105,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | ||||
| context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | context_.allocator = NpuMemoryAllocator::GetAllocator(device_id_); | ||||
| GE_CHECK_NOTNULL(context_.allocator); | GE_CHECK_NOTNULL(context_.allocator); | ||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new(std::nothrow)CallbackManager(stream_)); | |||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new(std::nothrow)CallbackManager()); | |||||
| GE_CHECK_NOTNULL(context_.callback_manager); | GE_CHECK_NOTNULL(context_.callback_manager); | ||||
| context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); | context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); | ||||
| const char *profiling_level = std::getenv(kEnvProfilingLevel); | const char *profiling_level = std::getenv(kEnvProfilingLevel); | ||||
| @@ -126,7 +126,7 @@ Status HybridModelExecutor::InitExecutionContext() { | |||||
| Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { | Status HybridModelExecutor::ResetExecutionContext(GraphExecutionContext &context) { | ||||
| GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); | GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); | ||||
| string ctx_id = std::to_string(context.session_id); | |||||
| string ctx_id = std::to_string(context.context_id); | |||||
| RuntimeInferenceContext::DestroyContext(ctx_id); | RuntimeInferenceContext::DestroyContext(ctx_id); | ||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -32,6 +32,7 @@ class HybridModelExecutor { | |||||
| std::vector<TensorValue> outputs; | std::vector<TensorValue> outputs; | ||||
| std::vector<ConstGeTensorDescPtr> output_desc; | std::vector<ConstGeTensorDescPtr> output_desc; | ||||
| bool is_eos = false; | bool is_eos = false; | ||||
| int num_loops = 10; | |||||
| }; | }; | ||||
| HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); | HybridModelExecutor(HybridModel *model, uint32_t device_id, rtStream_t stream); | ||||
| @@ -0,0 +1,284 @@ | |||||
| #include "hybrid_model_pipeline_executor.h" | |||||
| #include "common/math/math_util.h" | |||||
| #include "graph/ge_context.h" | |||||
| #include "graph/runtime_inference_context.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
namespace {
// Number of StageExecutor instances the pipeline executor creates.
constexpr int kNumExecutors = 2;
// Radix handed to std::strtol when the profiling level is parsed.
constexpr int kIntBase = 10;
// Name of the environment variable that carries the profiling level.
constexpr const char *kEnvProfilingLevel = "HYBRID_PROFILING_LEVEL";
}  // namespace
| StageExecutor::StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config) | |||||
| : id_(id), model_(model), pipe_config_(config) {} | |||||
| StageExecutor::~StageExecutor() { GELOGD("~StageExecutor(), id = %d", id_); } | |||||
| Status StageExecutor::Init() { | |||||
| GELOGD("[Executor: %d] Start to init StateExecutor", id_); | |||||
| context_.rt_context = pipe_config_->rt_context; | |||||
| GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | |||||
| GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); | |||||
| context_.stream = stream_; | |||||
| root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
| GE_CHECK_NOTNULL(root_graph_executor_); | |||||
| GELOGD("[Executor: %d] Init stage executor successfully", id_); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::ResetExecutionContext(GraphExecutionContext &context) { | |||||
| GE_CHK_STATUS_RET_NOLOG(context.callback_manager->Init()); | |||||
| string ctx_id = std::to_string(context.context_id); | |||||
| RuntimeInferenceContext::DestroyContext(ctx_id); | |||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::CreateContext(ctx_id), "Failed to Destroy RuntimeInferenceContext"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc, | |||||
| int iteration_count) { | |||||
| GELOGD("Start"); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | |||||
| int num_loops = iteration_count / pipe_config_->num_executors; | |||||
| if (id_ < iteration_count % iteration_count) { | |||||
| num_loops += 1; | |||||
| } | |||||
| FMK_INT32_MULCHECK(num_loops, pipe_config_->num_stages); | |||||
| num_loops *= pipe_config_->num_stages; | |||||
| GELOGD("[Executor: %d] loop count = %d", id_, num_loops); | |||||
| for (int loop_idx = 0; loop_idx < num_loops; ++loop_idx) { | |||||
| GELOGD("[Executor: %d] Start to wait for task.", id_); | |||||
| StageTask task_info; | |||||
| task_queue_.Pop(task_info); | |||||
| GELOGD("[Executor: %d] Got task, stage = %d, iteration = %ld", id_, task_info.stage, task_info.iteration); | |||||
| if (task_info.iteration >= pipe_config_->iteration_end) { | |||||
| GELOGE(INTERNAL_ERROR, "[Executor: %d] Unexpected iteration: %d", id_, task_info.iteration); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| if (task_info.event != nullptr) { | |||||
| GELOGD("[%d] Add StreamWaitEvent", id_); | |||||
| GE_CHK_RT_RET(rtStreamWaitEvent(stream_, task_info.event)); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] [Stage = %d] End", task_info.iteration - 1, | |||||
| task_info.stage); | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] [Stage = %d] Start", task_info.iteration, | |||||
| task_info.stage); | |||||
| if (task_info.stage == 0) { | |||||
| GELOGD("[Executor: %d] To ResetExecutionContext", id_); | |||||
| GE_CHK_STATUS_RET(ResetExecutionContext(context_), "[Executor: %d] Failed to reset context", id_); | |||||
| context_.iteration = task_info.iteration; | |||||
| GE_CHK_STATUS_RET_NOLOG(SetInputs(inputs, input_desc)); | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Stage = %d] PartialExecuteAsync Start", task_info.stage); | |||||
| GE_CHK_STATUS_RET(root_graph_executor_->PartialExecuteAsync(task_info.stage)); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Stage = %d] PartialExecuteAsync End", task_info.stage); | |||||
| GELOGD("[Executor: %d] PartialExecuteAsync successfully.", id_); | |||||
| // notify next execution unit | |||||
| StageTask next_task; | |||||
| next_task.stage = task_info.stage; | |||||
| next_task.iteration = task_info.iteration + 1; | |||||
| auto sync_result = Synchronize(); | |||||
| if (sync_result != SUCCESS) { | |||||
| GELOGE(sync_result, "[Executor: %d] Failed to sync result. iteration = %d", id_, task_info.iteration); | |||||
| context_.profiler->Dump(std::cout); | |||||
| context_.callback_manager->Destroy(); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | |||||
| return sync_result; | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] [Stage = %d] End", task_info.iteration, task_info.stage); | |||||
| // if not end stage | |||||
| if (task_info.stage >= pipe_config_->num_stages - 1) { | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %d] Schedule End", task_info.iteration); | |||||
| GELOGD("[Executor: %d] End of iteration [%ld]", id_, task_info.iteration); | |||||
| context_.callback_manager->Destroy(); | |||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | |||||
| } | |||||
| next_executor_->ExecuteAsync(next_task); | |||||
| GELOGD("[Executor: %d] Push item successfully.", id_); | |||||
| } | |||||
| GELOGD("[Executor: %d] Process task ended.", id_); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::ExecuteAsync(const StageTask &args) { | |||||
| (void)task_queue_.Push(args); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::Synchronize() { | |||||
| auto ret = root_graph_executor_->Synchronize(); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Synchronize] End, ret = %u", ret); | |||||
| return ret; | |||||
| } | |||||
| HybridModelPipelineExecutor::HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id) | |||||
| : model_(model), device_id_(device_id) { | |||||
| config_.num_executors = kNumExecutors; | |||||
| config_.num_stages = model_->GetRootGraphItem()->NumGroups(); | |||||
| config_.device_id = device_id_; | |||||
| } | |||||
| Status StageExecutor::InitExecutionContext() { | |||||
| GE_CHK_RT_RET(rtCtxCreate(&context_.rt_gen_context, RT_CTX_GEN_MODE, 0)); | |||||
| GE_CHK_RT_RET(rtCtxSetCurrent(context_.rt_context)); | |||||
| context_.model = model_; | |||||
| context_.session_id = ::ge::GetContext().SessionId(); | |||||
| GELOGD("session id from model = %lu, from context = %lu", model_->GetSessionId(), context_.session_id); | |||||
| context_.allocator = NpuMemoryAllocator::GetAllocator(pipe_config_->device_id); | |||||
| GE_CHECK_NOTNULL(context_.allocator); | |||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new (std::nothrow) CallbackManager()); | |||||
| GE_CHECK_NOTNULL(context_.callback_manager); | |||||
| context_.dump_properties = PropertiesManager::Instance().GetDumpProperties(context_.session_id); | |||||
| if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { | |||||
| context_.trace_enabled = true; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::SetInputs(const vector<TensorValue> &inputs, const vector<ConstGeTensorDescPtr> &input_desc) { | |||||
| root_graph_executor_->InitForPartialExecution(inputs, input_desc); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status StageExecutor::GetOutputs(vector<TensorValue> &outputs, vector<ConstGeTensorDescPtr> &output_desc) { | |||||
| return root_graph_executor_->GetOutputs(outputs, output_desc); | |||||
| } | |||||
| void StageExecutor::Reset() { | |||||
| task_queue_.Stop(); | |||||
| task_queue_.Clear(); | |||||
| task_queue_.Restart(); | |||||
| } | |||||
| Status HybridModelPipelineExecutor::Init() { | |||||
| const char *profiling_level = std::getenv(kEnvProfilingLevel); | |||||
| if (profiling_level != nullptr) { | |||||
| context_.profiling_level = std::strtol(profiling_level, nullptr, kIntBase); | |||||
| GELOGD("Got profiling level = %ld", context_.profiling_level); | |||||
| if (context_.profiling_level > 0) { | |||||
| context_.profiler.reset(new (std::nothrow) HybridProfiler()); | |||||
| GE_CHECK_NOTNULL(context_.profiler); | |||||
| } | |||||
| } | |||||
| GELOGD("Number of stages = %d, number of executors = %d", config_.num_stages, config_.num_executors); | |||||
| GE_CHK_RT_RET(rtCtxGetCurrent(&config_.rt_context)); | |||||
| GE_CHK_STATUS_RET_NOLOG(InitStageExecutors()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelPipelineExecutor::InitStageExecutors() { | |||||
| for (int i = 0; i < config_.num_executors; ++i) { | |||||
| auto stage_executor = std::unique_ptr<StageExecutor>(new (std::nothrow) StageExecutor(i, model_, &config_)); | |||||
| GE_CHECK_NOTNULL(stage_executor); | |||||
| GE_CHK_STATUS_RET_NOLOG(stage_executor->Init()); | |||||
| if (context_.profiler != nullptr) { | |||||
| // will call unique_ptr::release later | |||||
| stage_executor->context_.profiler.reset(context_.profiler.get()); | |||||
| stage_executor->context_.profiling_level = context_.profiling_level; | |||||
| } | |||||
| stage_executors_.emplace_back(std::move(stage_executor)); | |||||
| } | |||||
| // build propagation loop | |||||
| for (int i = 0; i < config_.num_executors - 1; ++i) { | |||||
| stage_executors_[i]->SetNext(stage_executors_[i + 1].get()); | |||||
| } | |||||
| stage_executors_[config_.num_executors - 1]->SetNext(stage_executors_[0].get()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelPipelineExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| int loop_count = args.num_loops; | |||||
| GE_CHECK_GE(loop_count, 2); | |||||
| auto &inputs = args.inputs; | |||||
| auto &input_desc = args.input_desc; | |||||
| // Start schedulers | |||||
| std::vector<std::future<Status>> futures; | |||||
| for (size_t i = 0; i < stage_executors_.size(); ++i) { | |||||
| GELOGD("Starting executor %zu", i); | |||||
| auto executor = stage_executors_[i].get(); | |||||
| executor->Reset(); | |||||
| auto future = std::async( | |||||
| [loop_count, executor, inputs, input_desc]() { return executor->Start(inputs, input_desc, loop_count); }); | |||||
| futures.emplace_back(std::move(future)); | |||||
| } | |||||
| // Push initial tasks | |||||
| GELOGD("Start to execute with loops, loop count = %d", loop_count); | |||||
| config_.iteration_end = iteration_ + loop_count; | |||||
| for (int i = 0; i < config_.num_stages; ++i) { | |||||
| StageExecutor::StageTask task_info; | |||||
| task_info.stage = i; | |||||
| task_info.iteration = iteration_; | |||||
| stage_executors_[0]->ExecuteAsync(task_info); | |||||
| } | |||||
| // Wait for end of iterations | |||||
| bool has_error = false; | |||||
| for (size_t i = 0; i < stage_executors_.size(); ++i) { | |||||
| GELOGD("Start to sync result of executor[%zu]", i); | |||||
| auto ret = futures[i].get(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Executor: %zu] Failed to schedule tasks.", i); | |||||
| has_error = true; | |||||
| continue; | |||||
| } | |||||
| ret = stage_executors_[i]->Synchronize(); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "[Executor: %zu] Failed to synchronize result.", i); | |||||
| has_error = true; | |||||
| continue; | |||||
| } | |||||
| } | |||||
| // record for profiling analyzer | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[Cleanup] End"); | |||||
| if (context_.profiler != nullptr) { | |||||
| context_.profiler->Dump(std::cout); | |||||
| } | |||||
| iteration_ = config_.iteration_end; | |||||
| if (has_error) { | |||||
| GELOGE(FAILED, "Error occurred while execution"); | |||||
| return FAILED; | |||||
| } | |||||
| auto last_iter_executor_idx = loop_count % stage_executors_.size(); | |||||
| GE_CHK_STATUS_RET(stage_executors_[last_iter_executor_idx]->GetOutputs(args.outputs, args.output_desc), | |||||
| "Failed to get output from executor[%zu]", last_iter_executor_idx); | |||||
| return SUCCESS; | |||||
| } | |||||
| HybridModelPipelineExecutor::~HybridModelPipelineExecutor() { | |||||
| GELOGD("~HybridModelPipelineExecutor()"); | |||||
| for (auto &executor : stage_executors_) { | |||||
| (void)executor->context_.profiler.release(); | |||||
| } | |||||
| } | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,88 @@ | |||||
| #ifndef GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ | |||||
| #define GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ | |||||
| #include "common/blocking_queue.h" | |||||
| #include "common/thread_pool.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | |||||
| #include "hybrid/executor/rt_callback_manager.h" | |||||
| #include "hybrid/executor/subgraph_executor.h" | |||||
| #include "hybrid_model_executor.h" | |||||
| namespace ge { | |||||
| namespace hybrid { | |||||
| struct PipeExecutionConfig { | |||||
| uint32_t device_id; | |||||
| rtContext_t rt_context; | |||||
| int num_executors; | |||||
| int num_stages; | |||||
| long iteration_end; | |||||
| }; | |||||
| class StageExecutor { | |||||
| public: | |||||
| struct StageTask { | |||||
| rtEvent_t event = nullptr; | |||||
| int stage = 0; | |||||
| long iteration = 0; | |||||
| }; | |||||
| StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config); | |||||
| ~StageExecutor(); | |||||
| Status Init(); | |||||
| void Reset(); | |||||
| Status Start(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc, | |||||
| int loop_count); | |||||
| Status SetInputs(const std::vector<TensorValue> &inputs, const std::vector<ConstGeTensorDescPtr> &input_desc); | |||||
| Status ExecuteAsync(const StageTask &args); | |||||
| Status GetOutputs(std::vector<TensorValue> &outputs, std::vector<ConstGeTensorDescPtr> &output_desc); | |||||
| Status Synchronize(); | |||||
| void SetNext(StageExecutor *next_executor) { next_executor_ = next_executor; } | |||||
| private: | |||||
| friend class HybridModelPipelineExecutor; | |||||
| static Status ResetExecutionContext(GraphExecutionContext &context); | |||||
| Status InitExecutionContext(); | |||||
| int id_; | |||||
| HybridModel *model_; | |||||
| PipeExecutionConfig *pipe_config_; | |||||
| BlockingQueue<StageTask> task_queue_; | |||||
| std::unique_ptr<SubgraphExecutor> root_graph_executor_; | |||||
| GraphExecutionContext context_; | |||||
| StageExecutor *next_executor_; | |||||
| rtStream_t stream_ = nullptr; | |||||
| }; | |||||
| class HybridModelPipelineExecutor { | |||||
| public: | |||||
| HybridModelPipelineExecutor(HybridModel *model, uint32_t device_id); | |||||
| ~HybridModelPipelineExecutor(); | |||||
| Status Init(); | |||||
| Status InitStageExecutors(); | |||||
| Status Execute(HybridModelExecutor::ExecuteArgs &args); | |||||
| private: | |||||
| HybridModel *model_; | |||||
| uint32_t device_id_; | |||||
| std::vector<std::unique_ptr<StageExecutor>> stage_executors_; | |||||
| PipeExecutionConfig config_; | |||||
| GraphExecutionContext context_; | |||||
| long iteration_ = 0; | |||||
| }; | |||||
| } // namespace hybrid | |||||
| } // namespace ge | |||||
| #endif // GE_HYBRID_EXECUTOR_HYBRID_MODEL_PIPELINE_EXECUTOR_H_ | |||||
| @@ -24,7 +24,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| const int kMaxEvents = 10000; | |||||
| const int kMaxEvents = 1024 * 500; | |||||
| const int kEventDescMax = 512; | const int kEventDescMax = 512; | ||||
| const int kMaxEventTypes = 8; | const int kMaxEventTypes = 8; | ||||
| const int kIndent = 8; | const int kIndent = 8; | ||||
| @@ -46,11 +46,14 @@ void HybridProfiler::RecordEvent(EventType event_type, const char *fmt, ...) { | |||||
| } | } | ||||
| va_end(args); | va_end(args); | ||||
| std::string event = buf; | |||||
| auto index = counter_++; | auto index = counter_++; | ||||
| if (index >= static_cast<int>(events_.size())) { | |||||
| GELOGE(INTERNAL_ERROR, "index out of range. index = %d, max event size = %zu", index, events_.size()); | |||||
| return; | |||||
| } | |||||
| auto &evt = events_[index]; | auto &evt = events_[index]; | ||||
| evt.timestamp = std::chrono::system_clock::now(); | evt.timestamp = std::chrono::system_clock::now(); | ||||
| evt.desc = std::move(event); | |||||
| evt.desc = std::string(buf); | |||||
| evt.event_type = event_type; | evt.event_type = event_type; | ||||
| } | } | ||||
| @@ -78,7 +81,7 @@ void HybridProfiler::Dump(std::ostream &output_stream) { | |||||
| auto cost_dump = std::chrono::duration_cast<std::chrono::microseconds>(end_dump - start_dump).count(); | auto cost_dump = std::chrono::duration_cast<std::chrono::microseconds>(end_dump - start_dump).count(); | ||||
| output_stream << std::setw(kIndent) << elapsed_dump << "\t\t" << cost_dump | output_stream << std::setw(kIndent) << elapsed_dump << "\t\t" << cost_dump | ||||
| << "\t\t" << "[Dump profiling]" << std::endl; | << "\t\t" << "[Dump profiling]" << std::endl; | ||||
| events_.clear(); | |||||
| Reset(); | |||||
| } | } | ||||
| void HybridProfiler::Reset() { | void HybridProfiler::Reset() { | ||||
| @@ -34,6 +34,14 @@ ShapeInferenceState::ShapeInferenceState(const NodeItem &node_item) : node_item( | |||||
| GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", | GELOGD("[%s] ShapeInferenceState created, pending shape count = %d", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| this->num_pending_shapes_); | this->num_pending_shapes_); | ||||
| for (int i = 0; i < node_item.num_inputs; ++i){ | |||||
| input_tensor_desc.emplace_back(*node_item.MutableInputDesc(i)); | |||||
| } | |||||
| for (int i = 0; i < node_item.num_outputs; ++i){ | |||||
| output_tensor_desc.emplace_back(*node_item.MutableOutputDesc(i)); | |||||
| } | |||||
| } | } | ||||
| Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target) { | ||||
| @@ -56,11 +64,10 @@ Status ShapeInferenceState::UpdateInputShape(int idx, const GeTensorDesc &target | |||||
| tensor_size); | tensor_size); | ||||
| std::lock_guard<std::mutex> lk(mu_); | std::lock_guard<std::mutex> lk(mu_); | ||||
| auto tensor_desc = node_item.MutableInputDesc(idx); | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| tensor_desc->SetShape(target.GetShape()); | |||||
| tensor_desc->SetOriginShape(target.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*tensor_desc, tensor_size); | |||||
| auto &input_desc = input_tensor_desc[idx]; | |||||
| input_desc.SetShape(target.GetShape()); | |||||
| input_desc.SetOriginShape(target.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(input_desc, tensor_size); | |||||
| if (--num_pending_shapes_ <= 0) { | if (--num_pending_shapes_ <= 0) { | ||||
| ready_cv_.notify_all(); | ready_cv_.notify_all(); | ||||
| } | } | ||||
| @@ -115,12 +122,27 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| } | } | ||||
| } | } | ||||
| for (size_t i = 0; i < input_tensor_desc.size(); ++i) { | |||||
| auto dst_tensor_desc = node_item.op_desc->MutableInputDesc(i); | |||||
| if (dst_tensor_desc == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto &tensor_desc = input_tensor_desc[i]; | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(tensor_desc, tensor_size); | |||||
| dst_tensor_desc->SetShape(tensor_desc.MutableShape()); | |||||
| dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); | |||||
| (void) TensorUtils::SetSize(*dst_tensor_desc, tensor_size); | |||||
| } | |||||
| for (auto &p : shape_futures) { | for (auto &p : shape_futures) { | ||||
| auto idx = p.first; | auto idx = p.first; | ||||
| auto &future = p.second; | auto &future = p.second; | ||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] Start", idx); | ||||
| GeTensorDescPtr src_tensor_desc; | |||||
| GE_CHK_STATUS_RET_NOLOG(future.GetTensorDesc(src_tensor_desc)); | |||||
| const GeTensorDesc* src_tensor_desc = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(future.GetTensorDesc(&src_tensor_desc)); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | GE_CHECK_NOTNULL(src_tensor_desc); | ||||
| RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | RECORD_SHAPE_INFERENCE_EVENT(&context, node_item.NodeName().c_str(), "[AwaitShape] [idx = %u] End", idx); | ||||
| @@ -142,10 +164,28 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ShapeFuture::ShapeFuture(NodePtr src_node, | |||||
| const vector<GeTensorDesc> &ShapeInferenceState::GetOutputTensorDesc() const { | |||||
| return output_tensor_desc; | |||||
| } | |||||
| Status ShapeInferenceState::UpdateOutputDesc() { | |||||
| for (size_t i = 0; i < output_tensor_desc.size(); ++i) { | |||||
| auto src_tensor_desc = node_item.MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | |||||
| auto &dst_tensor_desc = output_tensor_desc[i]; | |||||
| dst_tensor_desc.SetShape(src_tensor_desc->MutableShape()); | |||||
| dst_tensor_desc.SetOriginShape(src_tensor_desc->GetOriginShape()); | |||||
| int64_t tensor_size = -1; | |||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | |||||
| (void) TensorUtils::SetSize(dst_tensor_desc, tensor_size); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| ShapeFuture::ShapeFuture(NodeState *src_node, | |||||
| uint32_t src_index, | uint32_t src_index, | ||||
| SubgraphContext *subgraph_context) | SubgraphContext *subgraph_context) | ||||
| : src_node_(std::move(src_node)), src_index_(src_index), subgraph_context_(subgraph_context) { | |||||
| : src_node_(src_node), src_index_(src_index), subgraph_context_(subgraph_context) { | |||||
| } | } | ||||
| NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | NodeState::NodeState(const NodeItem &node_item, SubgraphContext *subgraph_context) | ||||
| @@ -187,6 +227,13 @@ Status NodeState::WaitForPrepareDone() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NodeState::UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape) { | |||||
| auto self_tensor_desc = op_desc_->MutableOutputDesc(index); | |||||
| GE_CHECK_NOTNULL(self_tensor_desc); | |||||
| self_tensor_desc->SetShape(shape); | |||||
| self_tensor_desc->SetOriginShape(ori_shape); | |||||
| return SUCCESS; | |||||
| } | |||||
| void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) { | void NodeState::SetTaskContext(std::shared_ptr<TaskContext> &task_context) { | ||||
| task_context_ = task_context; | task_context_ = task_context; | ||||
| @@ -198,17 +245,19 @@ std::shared_ptr<TaskContext> NodeState::GetTaskContext() { | |||||
| Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | Status ShapeFuture::Get(GeShape &ori_shape, GeShape &shape) { | ||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | ||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled"); | |||||
| shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->MutableShape(); | |||||
| ori_shape = src_node_->GetOpDesc()->MutableOutputDesc(src_index_)->GetOriginShape(); | |||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled"); | |||||
| auto &output_desc = src_node_->GetShapeInferenceState().GetOutputTensorDesc().at(src_index_); | |||||
| shape = output_desc.GetShape(); | |||||
| ori_shape = output_desc.GetOriginShape(); | |||||
| GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | GELOGD("Get shape from %s:%u. shape = [%s]", src_node_->GetName().c_str(), src_index_, shape.ToString().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeFuture::GetTensorDesc(GeTensorDescPtr &tensor_desc) { | |||||
| Status ShapeFuture::GetTensorDesc(const GeTensorDesc **tensor_desc) { | |||||
| GE_CHECK_NOTNULL(tensor_desc); | |||||
| GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | GELOGD("Start to wait node: %s for getting shape", src_node_->GetName().c_str()); | ||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_), "cancelled"); | |||||
| tensor_desc = src_node_->GetOpDesc()->MutableOutputDesc(src_index_); | |||||
| HYBRID_CHK_STATUS_RET(subgraph_context_->Await(src_node_->GetNodeItem()->node), "cancelled"); | |||||
| *tensor_desc = &src_node_->GetShapeInferenceState().GetOutputTensorDesc().at(src_index_); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| @@ -30,16 +30,17 @@ class NodeTask; | |||||
| struct GraphExecutionContext; | struct GraphExecutionContext; | ||||
| class SubgraphContext; | class SubgraphContext; | ||||
| class TaskContext; | class TaskContext; | ||||
| class NodeState; | |||||
| class ShapeFuture { | class ShapeFuture { | ||||
| public: | public: | ||||
| ShapeFuture(NodePtr src_node, uint32_t src_index, SubgraphContext *subgraph_context); | |||||
| ShapeFuture(NodeState *src_node, uint32_t src_index, SubgraphContext *subgraph_context); | |||||
| ~ShapeFuture() = default; | ~ShapeFuture() = default; | ||||
| Status Get(GeShape &ori_shape, GeShape &shape); | Status Get(GeShape &ori_shape, GeShape &shape); | ||||
| Status GetTensorDesc(GeTensorDescPtr &tensor_desc); | |||||
| Status GetTensorDesc(const GeTensorDesc **tensor_desc); | |||||
| private: | private: | ||||
| NodePtr src_node_; | |||||
| NodeState *src_node_; | |||||
| uint32_t src_index_; | uint32_t src_index_; | ||||
| SubgraphContext *subgraph_context_; | SubgraphContext *subgraph_context_; | ||||
| }; | }; | ||||
| @@ -53,10 +54,19 @@ struct ShapeInferenceState { | |||||
| Status AwaitShapesReady(const GraphExecutionContext &context); | Status AwaitShapesReady(const GraphExecutionContext &context); | ||||
| Status UpdateOutputDesc(); | |||||
| const vector<GeTensorDesc> &GetOutputTensorDesc() const; | |||||
| const NodeItem &node_item; | const NodeItem &node_item; | ||||
| private: | private: | ||||
| friend struct NodeState; | |||||
| std::vector<std::pair<int, ShapeFuture>> shape_futures; | std::vector<std::pair<int, ShapeFuture>> shape_futures; | ||||
| // do not directly update op_desc, in case race condition across pipelines | |||||
| std::vector<GeTensorDesc> input_tensor_desc; | |||||
| std::vector<GeTensorDesc> output_tensor_desc; | |||||
| int num_pending_shapes_ = 0; | int num_pending_shapes_ = 0; | ||||
| std::condition_variable ready_cv_; | std::condition_variable ready_cv_; | ||||
| std::mutex mu_; | std::mutex mu_; | ||||
| @@ -88,6 +98,8 @@ struct NodeState { | |||||
| return shape_inference_state_; | return shape_inference_state_; | ||||
| } | } | ||||
| Status UpdateOutputShapes(int index, const GeShape &shape, const GeShape &ori_shape); | |||||
| const shared_ptr<NodeTask> &GetKernelTask() const { | const shared_ptr<NodeTask> &GetKernelTask() const { | ||||
| return kernel_task_; | return kernel_task_; | ||||
| } | } | ||||
| @@ -21,14 +21,11 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| CallbackManager::CallbackManager(rtStream_t stream) : stream_(stream) { | |||||
| } | |||||
| Status CallbackManager::RegisterCallback(rtCallback_t callback, void *user_data) { | |||||
| Status CallbackManager::RegisterCallback(rtStream_t stream, rtCallback_t callback, void *user_data) { | |||||
| GELOGD("To register callback"); | GELOGD("To register callback"); | ||||
| rtEvent_t event = nullptr; | rtEvent_t event = nullptr; | ||||
| GE_CHK_RT_RET(rtEventCreate(&event)); | GE_CHK_RT_RET(rtEventCreate(&event)); | ||||
| auto rt_ret = rtEventRecord(event, stream_); | |||||
| auto rt_ret = rtEventRecord(event, stream); | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Failed to invoke rtEventRecord, error code = %d", rt_ret); | GELOGE(RT_FAILED, "Failed to invoke rtEventRecord, error code = %d", rt_ret); | ||||
| (void) rtEventDestroy(event); | (void) rtEventDestroy(event); | ||||
| @@ -112,11 +109,11 @@ void CallbackManager::RtCallbackFunc(void *data) { | |||||
| delete callback_func; | delete callback_func; | ||||
| } | } | ||||
| Status CallbackManager::RegisterCallback(const std::function<void()> &callback) { | |||||
| Status CallbackManager::RegisterCallback(rtStream_t stream, const std::function<void()> &callback) { | |||||
| auto func = std::unique_ptr<std::function<void()>>(new(std::nothrow) std::function<void()>(callback)); | auto func = std::unique_ptr<std::function<void()>>(new(std::nothrow) std::function<void()>(callback)); | ||||
| GE_CHECK_NOTNULL(func); | GE_CHECK_NOTNULL(func); | ||||
| GELOGD("Callback registered"); | GELOGD("Callback registered"); | ||||
| return RegisterCallback(RtCallbackFunc, func.release()); | |||||
| return RegisterCallback(stream, RtCallbackFunc, func.release()); | |||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -30,23 +30,21 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| class CallbackManager { | class CallbackManager { | ||||
| public: | public: | ||||
| explicit CallbackManager(rtStream_t stream); | |||||
| CallbackManager() = default; | |||||
| ~CallbackManager() = default; | ~CallbackManager() = default; | ||||
| Status Init(); | Status Init(); | ||||
| Status Destroy(); | Status Destroy(); | ||||
| Status RegisterCallback(rtCallback_t callback, void *user_data); | |||||
| Status RegisterCallback(const std::function<void()> &callback); | |||||
| Status RegisterCallback(rtStream_t stream, rtCallback_t callback, void *user_data); | |||||
| Status RegisterCallback(rtStream_t stream, const std::function<void()> &callback); | |||||
| private: | private: | ||||
| Status CallbackProcess(rtContext_t context); | Status CallbackProcess(rtContext_t context); | ||||
| static void RtCallbackFunc(void *data); | static void RtCallbackFunc(void *data); | ||||
| BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_; | BlockingQueue<std::pair<rtEvent_t, std::pair<rtCallback_t, void *>>> callback_queue_; | ||||
| rtStream_t stream_; | |||||
| std::future<Status> ret_future_; | std::future<Status> ret_future_; | ||||
| }; | }; | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| @@ -24,6 +24,7 @@ namespace ge { | |||||
| namespace hybrid { | namespace hybrid { | ||||
| namespace { | namespace { | ||||
| constexpr int kDefaultThreadNum = 4; | constexpr int kDefaultThreadNum = 4; | ||||
| constexpr int kDefaultQueueSize = 16; | |||||
| constexpr int kDataInputIndex = 0; | constexpr int kDataInputIndex = 0; | ||||
| } | } | ||||
| @@ -31,7 +32,8 @@ SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionCo | |||||
| : graph_item_(graph_item), | : graph_item_(graph_item), | ||||
| context_(context), | context_(context), | ||||
| force_infer_shape_(force_infer_shape), | force_infer_shape_(force_infer_shape), | ||||
| pre_run_pool_(kDefaultThreadNum) { | |||||
| pre_run_pool_(kDefaultThreadNum), | |||||
| ready_queue_(kDefaultQueueSize) { | |||||
| } | } | ||||
| SubgraphExecutor::~SubgraphExecutor() { | SubgraphExecutor::~SubgraphExecutor() { | ||||
| @@ -169,7 +171,7 @@ Status SubgraphExecutor::ExecuteAsyncForKnownShape(const std::vector<TensorValue | |||||
| GE_CHECK_NOTNULL(node_state); | GE_CHECK_NOTNULL(node_state); | ||||
| node_state->SetKernelTask(node_item->kernel_task); | node_state->SetKernelTask(node_item->kernel_task); | ||||
| known_shape_task_context_ = TaskContext::Create(*node_item, context_, subgraph_context_.get()); | |||||
| known_shape_task_context_ = TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(known_shape_task_context_); | GE_CHECK_NOTNULL(known_shape_task_context_); | ||||
| HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_), | HYBRID_CHK_STATUS_RET(ExecutionEngine::ExecuteAsync(*node_state, known_shape_task_context_, *context_), | ||||
| @@ -201,11 +203,11 @@ Status SubgraphExecutor::ExecuteAsync(TaskContext &task_context) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SubgraphExecutor::PrepareNodes() { | |||||
| GELOGD("[%s] Start to prepare nodes. force infer shape = %s.", | |||||
| Status SubgraphExecutor::PrepareNodes(int group) { | |||||
| GELOGD("[%s] Start to prepare nodes. group = %d", | |||||
| graph_item_->GetName().c_str(), | graph_item_->GetName().c_str(), | ||||
| force_infer_shape_ ? "true" : "false"); | |||||
| auto &all_nodes = graph_item_->GetAllNodes(); | |||||
| group); | |||||
| auto &all_nodes = graph_item_->GetAllNodes(group); | |||||
| for (auto all_node : all_nodes) { | for (auto all_node : all_nodes) { | ||||
| auto &node_item = *all_node; | auto &node_item = *all_node; | ||||
| // for while op | // for while op | ||||
| @@ -240,7 +242,8 @@ Status SubgraphExecutor::PrepareNodes() { | |||||
| } else { | } else { | ||||
| node_state->SetKernelTask(node_item.kernel_task); | node_state->SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| auto unique_task_context = TaskContext::Create(*node_state->GetNodeItem(), context_, subgraph_context_.get()); | |||||
| auto unique_task_context = | |||||
| TaskContext::Create(node_state.get(), context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | GE_CHECK_NOTNULL(unique_task_context); | ||||
| const auto &task = node_state->GetKernelTask(); | const auto &task = node_state->GetKernelTask(); | ||||
| if (task == nullptr) { | if (task == nullptr) { | ||||
| @@ -265,15 +268,17 @@ Status SubgraphExecutor::PrepareNodes() { | |||||
| GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | GELOGD("[%s] Push node [%s] to queue.", graph_item_->GetName().c_str(), node_item.NodeName().c_str()); | ||||
| } | } | ||||
| GELOGD("[%s] Done preparing nodes successfully.", graph_item_->GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) { | |||||
| const auto &node_item = *node_state.GetNodeItem(); | |||||
| Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const { | |||||
| GetContext().SetSessionId(context_->context_id); | |||||
| HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), | HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), | ||||
| "[%s] Failed to InferShape.", node_state.GetName().c_str()); | |||||
| HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_item), | |||||
| "[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str()); | |||||
| "[%s] Failed to InferShape.", node_state.GetName().c_str()); | |||||
| GetContext().SetSessionId(context_->session_id); | |||||
| HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_state), | |||||
| "[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str()); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -285,7 +290,7 @@ Status SubgraphExecutor::PrepareForExecution(GraphExecutionContext *ctx, NodeSta | |||||
| } else { | } else { | ||||
| node_state.SetKernelTask(node_item.kernel_task); | node_state.SetKernelTask(node_item.kernel_task); | ||||
| } | } | ||||
| auto unique_task_context = TaskContext::Create(*node_state.GetNodeItem(), context_, subgraph_context_.get()); | |||||
| auto unique_task_context = TaskContext::Create(&node_state, context_, subgraph_context_.get()); | |||||
| GE_CHECK_NOTNULL(unique_task_context); | GE_CHECK_NOTNULL(unique_task_context); | ||||
| const auto &task = node_state.GetKernelTask(); | const auto &task = node_state.GetKernelTask(); | ||||
| if (task == nullptr) { | if (task == nullptr) { | ||||
| @@ -336,11 +341,11 @@ Status SubgraphExecutor::LaunchTasks() { | |||||
| } | } | ||||
| } | } | ||||
| Status SubgraphExecutor::ScheduleTasks() { | |||||
| Status SubgraphExecutor::ScheduleTasks(int group) { | |||||
| GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); | ||||
| auto prepare_future = std::async(std::launch::async, [&]() -> Status { | auto prepare_future = std::async(std::launch::async, [&]() -> Status { | ||||
| GetContext().SetSessionId(context_->session_id); | GetContext().SetSessionId(context_->session_id); | ||||
| auto ret = PrepareNodes(); | |||||
| auto ret = PrepareNodes(group); | |||||
| ready_queue_.Push(nullptr); | ready_queue_.Push(nullptr); | ||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| @@ -481,5 +486,14 @@ Status SubgraphExecutor::EnableOutputZeroCopy(const vector<TensorValue> &outputs | |||||
| GELOGD("Done enabling zero copy for outputs successfully."); | GELOGD("Done enabling zero copy for outputs successfully."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SubgraphExecutor::PartialExecuteAsync(int task_group) { | |||||
| return ScheduleTasks(task_group); | |||||
| } | |||||
| Status SubgraphExecutor::InitForPartialExecution(const vector<TensorValue> &inputs, | |||||
| const vector<ConstGeTensorDescPtr> &input_desc) { | |||||
| return Init(inputs, input_desc); | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -36,6 +36,11 @@ class SubgraphExecutor { | |||||
| SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); | SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); | ||||
| ~SubgraphExecutor(); | ~SubgraphExecutor(); | ||||
| Status InitForPartialExecution(const std::vector<TensorValue> &inputs, | |||||
| const std::vector<ConstGeTensorDescPtr> &input_desc); | |||||
| Status PartialExecuteAsync(int task_group); | |||||
| /** | /** | ||||
| * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | * Execute subgraph async, output tensor address(not data) and output tensor descriptions are | ||||
| * valid after this method returned | * valid after this method returned | ||||
| @@ -89,15 +94,15 @@ class SubgraphExecutor { | |||||
| private: | private: | ||||
| Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); | Status PrepareForExecution(GraphExecutionContext *ctx, NodeState &node_state); | ||||
| Status EnableOutputZeroCopy(const std::vector<TensorValue> &outputs); | Status EnableOutputZeroCopy(const std::vector<TensorValue> &outputs); | ||||
| static Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state); | |||||
| Status InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const; | |||||
| Status Init(const std::vector<TensorValue> &inputs, | Status Init(const std::vector<TensorValue> &inputs, | ||||
| const std::vector<ConstGeTensorDescPtr> &input_desc); | const std::vector<ConstGeTensorDescPtr> &input_desc); | ||||
| Status InitInputsForUnknownShape(const std::vector<TensorValue> &inputs, | Status InitInputsForUnknownShape(const std::vector<TensorValue> &inputs, | ||||
| const std::vector<ConstGeTensorDescPtr> &input_desc); | const std::vector<ConstGeTensorDescPtr> &input_desc); | ||||
| Status InitInputsForKnownShape(const std::vector<TensorValue> &inputs); | Status InitInputsForKnownShape(const std::vector<TensorValue> &inputs); | ||||
| Status ExecuteAsyncForKnownShape(const std::vector<TensorValue> &inputs); | Status ExecuteAsyncForKnownShape(const std::vector<TensorValue> &inputs); | ||||
| Status ScheduleTasks(); | |||||
| Status PrepareNodes(); | |||||
| Status ScheduleTasks(int group = -1); | |||||
| Status PrepareNodes(int group = -1); | |||||
| Status LaunchTasks(); | Status LaunchTasks(); | ||||
| Status SetOutputsToParentNode(TaskContext &task_context); | Status SetOutputsToParentNode(TaskContext &task_context); | ||||
| @@ -125,16 +125,16 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) { | |||||
| RT_MEMCPY_DEVICE_TO_HOST)); | RT_MEMCPY_DEVICE_TO_HOST)); | ||||
| } | } | ||||
| tensor.SetData(std::move(host_buffer)); | tensor.SetData(std::move(host_buffer)); | ||||
| string session_id = std::to_string(context_->GetSessionId()); | |||||
| string context_id = std::to_string(graph_context_->context_id); | |||||
| RuntimeInferenceContext *runtime_infer_ctx = nullptr; | RuntimeInferenceContext *runtime_infer_ctx = nullptr; | ||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(session_id, &runtime_infer_ctx), | |||||
| "Failed to get RuntimeInferenceContext, session_id = %s", session_id.c_str()); | |||||
| GE_CHK_GRAPH_STATUS_RET(RuntimeInferenceContext::GetContext(context_id, &runtime_infer_ctx), | |||||
| "Failed to get RuntimeInferenceContext, context_id = %s", context_id.c_str()); | |||||
| GE_CHK_STATUS_RET(runtime_infer_ctx->SetTensor(node_item.node_id, output_idx, std::move(tensor)), | GE_CHK_STATUS_RET(runtime_infer_ctx->SetTensor(node_item.node_id, output_idx, std::move(tensor)), | ||||
| "Failed to SetTensor, node = %s, output_index = %d", node_item.NodeName().c_str(), output_idx); | "Failed to SetTensor, node = %s, output_index = %d", node_item.NodeName().c_str(), output_idx); | ||||
| GELOGD("[%s] Output[%d] cached successfully in session: %s. node_id = %d, shape = [%s]", | |||||
| GELOGD("[%s] Output[%d] cached successfully in context: %s. node_id = %d, shape = [%s]", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| output_idx, | output_idx, | ||||
| session_id.c_str(), | |||||
| context_id.c_str(), | |||||
| node_item.node_id, | node_item.node_id, | ||||
| ge_tensor_desc->GetShape().ToString().c_str()); | ge_tensor_desc->GetShape().ToString().c_str()); | ||||
| @@ -332,6 +332,7 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_SHAPE_RANGE || node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| // update output tensor sizes | // update output tensor sizes | ||||
| GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | GE_CHK_STATUS_RET_NOLOG(ShapeInferenceEngine::CalcOutputTensorSizes(node_item)); | ||||
| GE_CHK_STATUS_RET_NOLOG(context_->GetNodeState()->GetShapeInferenceState().UpdateOutputDesc()); | |||||
| } | } | ||||
| // PropagateOutputs for type == DEPEND_COMPUTE | // PropagateOutputs for type == DEPEND_COMPUTE | ||||
| if (node_item.shape_inference_type == DEPEND_COMPUTE) { | if (node_item.shape_inference_type == DEPEND_COMPUTE) { | ||||
| @@ -363,7 +364,7 @@ Status ExecutionEngine::ExecuteAsync(NodeState &node_state, | |||||
| RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); | RECORD_EXECUTION_EVENT(&execution_context, task_context->GetNodeName(), "Start"); | ||||
| auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context)); | auto cb = std::shared_ptr<NodeDoneCallback>(new(std::nothrow) NodeDoneCallback(&execution_context, task_context)); | ||||
| GE_CHECK_NOTNULL(cb); | GE_CHECK_NOTNULL(cb); | ||||
| auto callback = [&, cb]() { | |||||
| auto callback = [task_context, cb]() { | |||||
| auto ret = cb->OnNodeDone(); | auto ret = cb->OnNodeDone(); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| task_context->OnError(ret); | task_context->OnError(ret); | ||||
| @@ -109,7 +109,8 @@ Status ShapeInferenceEngine::AwaitDependentNodes(NodeState &node_state) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| Status ShapeInferenceEngine::PropagateOutputShapes(NodeState &node_state) { | |||||
| auto &node_item = *node_state.GetNodeItem(); | |||||
| if (node_item.is_output_shape_static) { | if (node_item.is_output_shape_static) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -140,9 +141,8 @@ Status ShapeInferenceEngine::PropagateOutputShapes(const NodeItem &node_item) { | |||||
| // in case type 3 and 4, shape will be valid after computing is done | // in case type 3 and 4, shape will be valid after computing is done | ||||
| auto &infer_state = dst_node_state->GetShapeInferenceState(); | auto &infer_state = dst_node_state->GetShapeInferenceState(); | ||||
| if (shape_is_future) { | if (shape_is_future) { | ||||
| ShapeFuture future(node_item.node, i, subgraph_context_); | |||||
| infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, | |||||
| std::move(future)); | |||||
| ShapeFuture future(&node_state, i, subgraph_context_); | |||||
| infer_state.UpdateInputShapeFuture(dst_input_index_and_node.first, std::move(future)); | |||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc)); | GE_CHK_STATUS_RET_NOLOG(infer_state.UpdateInputShape(dst_input_index_and_node.first, *output_desc)); | ||||
| } | } | ||||
| @@ -32,7 +32,7 @@ class ShapeInferenceEngine { | |||||
| Status InferShapeForSubgraph(const NodeItem &node_item, const FusedSubgraph &fused_subgraph); | Status InferShapeForSubgraph(const NodeItem &node_item, const FusedSubgraph &fused_subgraph); | ||||
| Status PropagateOutputShapes(const NodeItem &node_item); | |||||
| Status PropagateOutputShapes(NodeState &node_state); | |||||
| static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | static Status CalcOutputTensorSizes(const NodeItem &node_item, bool fallback_with_range = false); | ||||
| @@ -30,6 +30,19 @@ const vector<NodeItem *> &hybrid::GraphItem::GetAllNodes() const { | |||||
| return node_items_; | return node_items_; | ||||
| } | } | ||||
| const vector<NodeItem *> &GraphItem::GetAllNodes(int group) const { | |||||
| if (group == -1) { | |||||
| return GetAllNodes(); | |||||
| } | |||||
| if (group >= static_cast<int>(grouped_node_items_.size())) { | |||||
| static vector<NodeItem *> empty_nodes; | |||||
| return empty_nodes; | |||||
| } | |||||
| return grouped_node_items_[group]; | |||||
| } | |||||
| const vector<const NodeItem *> &GraphItem::GetInputNodes() const { | const vector<const NodeItem *> &GraphItem::GetInputNodes() const { | ||||
| return input_nodes_; | return input_nodes_; | ||||
| } | } | ||||
| @@ -74,5 +87,28 @@ const NodeItem *GraphItem::GetOutputNode() const { | |||||
| const vector<std::pair<const NodeItem *, int>> &GraphItem::GetOutputEdges() const { | const vector<std::pair<const NodeItem *, int>> &GraphItem::GetOutputEdges() const { | ||||
| return output_edges_; | return output_edges_; | ||||
| } | } | ||||
| Status GraphItem::GroupNodes() { | |||||
| int last_group = INT32_MIN; | |||||
| std::set<int> seen_groups; | |||||
| for (auto node : node_items_) { | |||||
| int group = node->group; | |||||
| if (group != last_group) { | |||||
| if (seen_groups.find(group) != seen_groups.end()) { | |||||
| GELOGE(INTERNAL_ERROR, "Unordered node group found. node = %s, group = %d", node->NodeName().c_str(), group); | |||||
| return INTERNAL_ERROR; | |||||
| } else { | |||||
| last_group = group; | |||||
| seen_groups.insert(group); | |||||
| grouped_node_items_.emplace_back(std::vector<NodeItem *>()); | |||||
| } | |||||
| } | |||||
| GELOGD("Adding node [%s] to group %d", node->NodeName().c_str(), group); | |||||
| grouped_node_items_.back().emplace_back(node); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -26,7 +26,9 @@ class GraphItem { | |||||
| public: | public: | ||||
| GraphItem() = default; | GraphItem() = default; | ||||
| ~GraphItem(); | ~GraphItem(); | ||||
| Status GroupNodes(); | |||||
| const vector<NodeItem *> &GetAllNodes() const; | const vector<NodeItem *> &GetAllNodes() const; | ||||
| const vector<NodeItem *> &GetAllNodes(int group) const; | |||||
| const vector<const NodeItem *> &GetInputNodes() const; | const vector<const NodeItem *> &GetInputNodes() const; | ||||
| Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; | Status GetOutputDescList(std::vector<ConstGeTensorDescPtr> &output_desc_list) const; | ||||
| const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const; | const vector<std::pair<const NodeItem *, int>> &GetOutputEdges() const; | ||||
| @@ -46,6 +48,10 @@ class GraphItem { | |||||
| name_ = name; | name_ = name; | ||||
| } | } | ||||
| size_t NumGroups() const { | |||||
| return grouped_node_items_.size(); | |||||
| } | |||||
| const NodeItem *GetOutputNode() const; | const NodeItem *GetOutputNode() const; | ||||
| bool IsDynamic() const; | bool IsDynamic() const; | ||||
| @@ -56,6 +62,7 @@ class GraphItem { | |||||
| friend class HybridModelBuilder; | friend class HybridModelBuilder; | ||||
| std::string name_; | std::string name_; | ||||
| std::vector<NodeItem *> node_items_; | std::vector<NodeItem *> node_items_; | ||||
| std::vector<std::vector<NodeItem *>> grouped_node_items_; | |||||
| std::vector<const NodeItem *> input_nodes_; | std::vector<const NodeItem *> input_nodes_; | ||||
| const NodeItem *output_node_ = nullptr; | const NodeItem *output_node_ = nullptr; | ||||
| // <src_node, out_index> | // <src_node, out_index> | ||||
| @@ -52,7 +52,7 @@ Status HybridModel::Init(bool is_single_op) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| TensorValue* HybridModel::GetVariable(const string &name) const { | |||||
| TensorValue *HybridModel::GetVariable(const string &name) const { | |||||
| auto it = variable_tensors_.find(name); | auto it = variable_tensors_.find(name); | ||||
| if (it == variable_tensors_.end()) { | if (it == variable_tensors_.end()) { | ||||
| GELOGD("Failed to get variable tensor. var name = [%s]", name.c_str()); | GELOGD("Failed to get variable tensor. var name = [%s]", name.c_str()); | ||||
| @@ -113,7 +113,7 @@ GeModelPtr HybridModel::GetGeModel(const NodePtr &node) const { | |||||
| return it->second; | return it->second; | ||||
| } | } | ||||
| const GraphItem* HybridModel::GetRootGraphItem() const { | |||||
| const GraphItem *HybridModel::GetRootGraphItem() const { | |||||
| return root_graph_item_.get(); | return root_graph_item_.get(); | ||||
| } | } | ||||
| @@ -287,6 +287,16 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s | |||||
| src_node_item->NodeName().c_str()); | src_node_item->NodeName().c_str()); | ||||
| src_node_item->has_observer = true; | src_node_item->has_observer = true; | ||||
| node_item.dependents_for_execution.emplace_back(src_node); | node_item.dependents_for_execution.emplace_back(src_node); | ||||
| node_item.has_observer = true; | |||||
| for (auto &dst_node : ge_node->GetOutNodes()) { | |||||
| if (dst_node == nullptr) { | |||||
| continue; | |||||
| } | |||||
| NodeItem *dst_node_item = nullptr; | |||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(dst_node, &dst_node_item)); | |||||
| dst_node_item->dependents_for_execution.emplace_back(ge_node); | |||||
| } | |||||
| } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { | } else if (src_node_item->shape_inference_type == DEPEND_COMPUTE) { | ||||
| GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", | GELOGD("[%s] Add input data dependent node [%s] due to inference type = DEPEND_COMPUTE", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| @@ -614,6 +624,15 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||||
| continue; | continue; | ||||
| } | } | ||||
| if (op_desc->HasAttr(ATTR_STAGE_LEVEL)) { | |||||
| uint32_t stage_level = UINT32_MAX; | |||||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level)) { | |||||
| for (const auto &stage_node : subgraph->GetAllNodes()) { | |||||
| GELOGD("Set ATTR_STAGE_LEVEL on node %s, stage_level=%u", stage_node->GetName().c_str(), stage_level); | |||||
| (void)AttrUtils::SetInt(stage_node->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level); | |||||
| } | |||||
| } | |||||
| } | |||||
| GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | GE_CHK_GRAPH_STATUS_RET(UnfoldSubgraph(root_graph, *merged_graph, *subgraph), | ||||
| "[%s] Failed to merge subgraph.", | "[%s] Failed to merge subgraph.", | ||||
| subgraph->GetName().c_str()); | subgraph->GetName().c_str()); | ||||
| @@ -621,6 +640,14 @@ Status HybridModelBuilder::UnfoldSubgraphs(ComputeGraph &root_graph, ComputeGrap | |||||
| // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | // invoke before adding subgraphs. in case modify node id in known-shaped subgraphs. | ||||
| GE_CHK_GRAPH_STATUS_RET(merged_graph->TopologicalSorting(), "Failed to invoke TopologicalSorting on merged graph."); | GE_CHK_GRAPH_STATUS_RET(merged_graph->TopologicalSorting(), "Failed to invoke TopologicalSorting on merged graph."); | ||||
| GE_DUMP(merged_graph, "hybrid_merged_graph_BeforeStageSort"); | |||||
| merged_graph->TopologicalSorting([](const NodePtr &a, const NodePtr &b) -> bool { | |||||
| uint32_t a_level = UINT32_MAX; | |||||
| (void)AttrUtils::GetInt(a->GetOpDesc(), ATTR_STAGE_LEVEL, a_level); | |||||
| uint32_t b_level = UINT32_MAX; | |||||
| (void)AttrUtils::GetInt(b->GetOpDesc(), ATTR_STAGE_LEVEL, b_level); | |||||
| return a_level < b_level; | |||||
| }); | |||||
| for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | for (auto &remained_subgraph : root_graph.GetAllSubgraphs()) { | ||||
| GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | GELOGD("Adding subgraph [%s] to merged-graph.", remained_subgraph->GetName().c_str()); | ||||
| @@ -732,6 +759,7 @@ Status HybridModelBuilder::LoadGraph() { | |||||
| GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); | GE_CHK_STATUS_RET(LoadDynamicSubgraph(*root_graph, true), "Failed to load root graph."); | ||||
| GELOGD("Done loading root graph successfully."); | GELOGD("Done loading root graph successfully."); | ||||
| GE_CHK_STATUS_RET(hybrid_model_.root_graph_item_->GroupNodes(), "Failed to group nodes for root graph"); | |||||
| for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | for (auto &sub_graph : root_graph->GetAllSubgraphs()) { | ||||
| GE_CHECK_NOTNULL(sub_graph); | GE_CHECK_NOTNULL(sub_graph); | ||||
| @@ -805,6 +833,7 @@ Status HybridModelBuilder::VarNodeToTensor(const NodePtr &var_node, std::unique_ | |||||
| // var size is only for checking, will not allocate any memory by it | // var size is only for checking, will not allocate any memory by it | ||||
| tensor.reset(new(std::nothrow)TensorValue(dev_mem, static_cast<size_t>(var_size))); | tensor.reset(new(std::nothrow)TensorValue(dev_mem, static_cast<size_t>(var_size))); | ||||
| GE_CHECK_NOTNULL(tensor); | GE_CHECK_NOTNULL(tensor); | ||||
| GELOGI("Get var memory addr %p for node %s, size = %ld, mem_type=%u", dev_mem, var_name.c_str(), var_size, mem_type); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -1737,8 +1766,14 @@ Status HybridModelBuilder::CreateProfilingNodeBefore(GraphItem &graph_item, cons | |||||
| for (const auto &task_def : task_def_lists) { | for (const auto &task_def : task_def_lists) { | ||||
| hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | ||||
| } | } | ||||
| if (op_desc->HasAttr(ATTR_STAGE_LEVEL)) { | |||||
| uint32_t stage_level = UINT32_MAX; | |||||
| (void)ge::AttrUtils::GetInt(op_desc, ATTR_STAGE_LEVEL, stage_level); | |||||
| (void)ge::AttrUtils::SetInt(node_ptr->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level); | |||||
| } | |||||
| NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | ||||
| GE_CHECK_NOTNULL(node_item); | |||||
| node_item->input_start = 0; | node_item->input_start = 0; | ||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| graph_item.node_items_.emplace_back(node_item); | graph_item.node_items_.emplace_back(node_item); | ||||
| @@ -1812,8 +1847,14 @@ Status HybridModelBuilder::CreateProfilingNodeAfter(GraphItem &graph_item, const | |||||
| for (const auto &task_def : task_def_lists) { | for (const auto &task_def : task_def_lists) { | ||||
| hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | hybrid_model_.task_defs_[profiling_node].emplace_back(task_def); | ||||
| } | } | ||||
| if (op_desc->HasAttr(ATTR_STAGE_LEVEL)) { | |||||
| uint32_t stage_level = UINT32_MAX; | |||||
| (void)ge::AttrUtils::GetInt(op_desc, ATTR_STAGE_LEVEL, stage_level); | |||||
| (void)ge::AttrUtils::SetInt(profiling_node->GetOpDesc(), ATTR_STAGE_LEVEL, stage_level); | |||||
| } | |||||
| NodeItem *node_item = nullptr; | NodeItem *node_item = nullptr; | ||||
| GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | GE_CHK_STATUS_RET_NOLOG(GetOrCreateNodeItem(profiling_node, &node_item)); | ||||
| GE_CHECK_NOTNULL(node_item); | |||||
| node_item->input_start = 0; | node_item->input_start = 0; | ||||
| node_item->output_start = 0; | node_item->output_start = 0; | ||||
| graph_item.node_items_.emplace_back(node_item); | graph_item.node_items_.emplace_back(node_item); | ||||
| @@ -21,8 +21,8 @@ | |||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/utils/node_utils.h" | #include "graph/utils/node_utils.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| #include "hybrid/executor/worker/shape_inference_engine.h" | #include "hybrid/executor/worker/shape_inference_engine.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -146,6 +146,20 @@ Status NodeItem::InitInputsAndOutputs() { | |||||
| GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | GE_CHECK_LE(op_desc->GetOutputsSize(), INT32_MAX); | ||||
| num_inputs = static_cast<int>(op_desc->GetInputsSize()); | num_inputs = static_cast<int>(op_desc->GetInputsSize()); | ||||
| num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | num_outputs = static_cast<int>(op_desc->GetOutputsSize()); | ||||
| if (AttrUtils::GetInt(op_desc, ::ge::ATTR_STAGE_LEVEL, group)) { | |||||
| GELOGD("[%s] Got stage level from op_desc = %d", op_desc->GetName().c_str(), group); | |||||
| } else { | |||||
| if (AttrUtils::GetInt(node->GetOwnerComputeGraph(), ::ge::ATTR_STAGE_LEVEL, group)) { | |||||
| GELOGD("[%s] Got stage level from parent graph = %d", op_desc->GetName().c_str(), group); | |||||
| } else { | |||||
| auto parent_node = node->GetOwnerComputeGraph()->GetParentNode(); | |||||
| if ((parent_node != nullptr) && (AttrUtils::GetInt(parent_node->GetOpDesc(), ::ge::ATTR_STAGE_LEVEL, group))) { | |||||
| GELOGD("[%s] Got stage level from parent node = %d", op_desc->GetName().c_str(), group); | |||||
| } else { | |||||
| GELOGD("[%s] Node do not set stage level", op_desc->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| } | |||||
| ResolveOptionalInputs(); | ResolveOptionalInputs(); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -222,8 +236,8 @@ void NodeItem::ResolveUnknownShapeType() { | |||||
| Status NodeItem::Init() { | Status NodeItem::Init() { | ||||
| GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | GE_CHK_STATUS_RET_NOLOG(InitInputsAndOutputs()); | ||||
| GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | GE_CHK_STATUS_RET_NOLOG(ResolveDynamicState()); | ||||
| ResolveUnknownShapeType(); | |||||
| if (is_dynamic) { | if (is_dynamic) { | ||||
| ResolveUnknownShapeType(); | |||||
| GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs()); | GE_CHK_STATUS_RET_NOLOG(ResolveStaticInputsAndOutputs()); | ||||
| GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | GE_CHK_STATUS_RET(ParseFusedSubgraph(*this), "[%s] Failed to parse fused subgraph", node_name.c_str()); | ||||
| } | } | ||||
| @@ -244,6 +258,7 @@ std::string NodeItem::DebugString() const { | |||||
| ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); | ss << ", is_dynamic = " << (is_dynamic ? "True" : "False"); | ||||
| ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False"); | ss << ", is_output_static = " << (is_output_shape_static ? "True" : "False"); | ||||
| ss << ", unknown_shape_op_type = " << shape_inference_type; | ss << ", unknown_shape_op_type = " << shape_inference_type; | ||||
| ss << ", stage = " << group; | |||||
| ss << ", input_start = " << input_start; | ss << ", input_start = " << input_start; | ||||
| ss << ", num_inputs = " << num_inputs; | ss << ", num_inputs = " << num_inputs; | ||||
| ss << ", output_start = " << output_start; | ss << ", output_start = " << output_start; | ||||
| @@ -74,6 +74,7 @@ struct NodeItem { | |||||
| NodePtr node; | NodePtr node; | ||||
| OpDesc *op_desc; | OpDesc *op_desc; | ||||
| int node_id = -1; | int node_id = -1; | ||||
| int group = -1; | |||||
| int num_inputs = 0; | int num_inputs = 0; | ||||
| int num_outputs = 0; | int num_outputs = 0; | ||||
| int input_start = -1; | int input_start = -1; | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "hybrid/node_executor/aicore/aicore_op_task.h" | #include "hybrid/node_executor/aicore/aicore_op_task.h" | ||||
| #include "framework/common/taskdown_common.h" | #include "framework/common/taskdown_common.h" | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/ge_context.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/node_executor/aicore/aicore_task_builder.h" | #include "hybrid/node_executor/aicore/aicore_task_builder.h" | ||||
| #include "graph/load/model_manager/tbe_handle_store.h" | #include "graph/load/model_manager/tbe_handle_store.h" | ||||
| @@ -198,9 +199,12 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { | |||||
| tiling_info.clear_atomic = true; | tiling_info.clear_atomic = true; | ||||
| auto execution_context = context.GetExecutionContext(); | auto execution_context = context.GetExecutionContext(); | ||||
| GetContext().SetSessionId(execution_context->context_id); | |||||
| RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); | RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); | ||||
| GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); | GE_CHK_STATUS_RET(CalcTilingInfo(node, tiling_info)); | ||||
| RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); | RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] End"); | ||||
| GetContext().SetSessionId(execution_context->session_id); | |||||
| // update op args by tiling info | // update op args by tiling info | ||||
| block_dim_ = static_cast<uint32_t>(tiling_info.block_dim); | block_dim_ = static_cast<uint32_t>(tiling_info.block_dim); | ||||
| @@ -70,7 +70,6 @@ Status AiCoreTaskBuilder::BuildTask(std::unique_ptr<NodeTask> &node_task, | |||||
| auto atomic_task = | auto atomic_task = | ||||
| std::unique_ptr<AtomicAddrCleanOpTask>(new(std::nothrow)AtomicAddrCleanOpTask()); | std::unique_ptr<AtomicAddrCleanOpTask>(new(std::nothrow)AtomicAddrCleanOpTask()); | ||||
| GE_CHECK_NOTNULL(atomic_task); | GE_CHECK_NOTNULL(atomic_task); | ||||
| atomic_task->SetSingleOp(is_single_op); | |||||
| GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, task_defs_.front()), | GE_CHK_STATUS_RET(atomic_task->Init(*op_desc_, task_defs_.front()), | ||||
| "[%s] Failed to init task for AtomicAddrClean", | "[%s] Failed to init task for AtomicAddrClean", | ||||
| op_desc_->GetName().c_str()); | op_desc_->GetName().c_str()); | ||||
| @@ -28,6 +28,7 @@ namespace hybrid { | |||||
| namespace { | namespace { | ||||
| // mem need release | // mem need release | ||||
| constexpr uint64_t kReleaseFlag = 1; | constexpr uint64_t kReleaseFlag = 1; | ||||
| const char *const kAicpuAllshape = "_AllShape"; | |||||
| } | } | ||||
| REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, AiCpuNodeExecutor); | REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, AiCpuNodeExecutor); | ||||
| REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_CUSTOM, AiCpuNodeExecutor); | REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_CUSTOM, AiCpuNodeExecutor); | ||||
| @@ -60,6 +61,7 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ | |||||
| GELOGD("To update aicpu_task ext_info session_info session_id to %lu", session_id); | GELOGD("To update aicpu_task ext_info session_info session_id to %lu", session_id); | ||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), | GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), | ||||
| "UpdateSessionInfoSessionId failed."); | "UpdateSessionInfoSessionId failed."); | ||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(!node_item_->is_dynamic), "UpdateExecuteMode failed."); | |||||
| // copy task args buf | // copy task args buf | ||||
| GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), | GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), | ||||
| @@ -74,7 +76,7 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AicpuNodeTaskBase::UpdateOutputShapeFromExtInfo() { | |||||
| Status AicpuNodeTaskBase::UpdateOutputShapeFromExtInfo(TaskContext &task_context) { | |||||
| if (node_item_->num_outputs == 0) { | if (node_item_->num_outputs == 0) { | ||||
| GELOGD("Task [%s] output_num is 0, no need update output shape.", node_name_.c_str()); | GELOGD("Task [%s] output_num is 0, no need update output shape.", node_name_.c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -91,19 +93,19 @@ Status AicpuNodeTaskBase::UpdateOutputShapeFromExtInfo() { | |||||
| // not support update data type now, just for param | // not support update data type now, just for param | ||||
| DataType data_type; | DataType data_type; | ||||
| aicpu_ext_handle_.GetOutputShapeAndType(i, shape, data_type); | aicpu_ext_handle_.GetOutputShapeAndType(i, shape, data_type); | ||||
| auto output_desc = node_item_->MutableOutputDesc(i); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(shape, i, output_desc), | |||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(task_context, shape, i), | |||||
| "Update node %s [%d]th output shape failed.", | "Update node %s [%d]th output shape failed.", | ||||
| node_name_.c_str(), i); | node_name_.c_str(), i); | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(const GeShape &shape_new, | |||||
| int32_t output_index, GeTensorDescPtr &output_desc) { | |||||
| Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(TaskContext &task_context, | |||||
| const GeShape &shape_new, | |||||
| int32_t output_index) { | |||||
| auto output_desc = task_context.MutableOutputDesc(output_index); | |||||
| GE_CHECK_NOTNULL(output_desc); | |||||
| auto shape_old = output_desc->GetShape(); | auto shape_old = output_desc->GetShape(); | ||||
| output_desc->SetShape(shape_new); | |||||
| GELOGD("Update node[%s] out[%d] shape from %s to %s.", node_name_.c_str(), output_index, | GELOGD("Update node[%s] out[%d] shape from %s to %s.", node_name_.c_str(), output_index, | ||||
| shape_old.ToString().c_str(), shape_new.ToString().c_str()); | shape_old.ToString().c_str(), shape_new.ToString().c_str()); | ||||
| @@ -111,9 +113,9 @@ Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(const GeShape &shape_new, | |||||
| auto origin_format = output_desc->GetOriginFormat(); | auto origin_format = output_desc->GetOriginFormat(); | ||||
| auto format = output_desc->GetFormat(); | auto format = output_desc->GetFormat(); | ||||
| if (origin_format == format) { | if (origin_format == format) { | ||||
| output_desc->SetOriginShape(shape_new); | |||||
| return SUCCESS; | |||||
| return task_context.GetNodeState()->UpdateOutputShapes(output_index, shape_new, shape_new); | |||||
| } | } | ||||
| // if format is not same need convert shape | // if format is not same need convert shape | ||||
| std::vector<int64_t> origin_dims_new; | std::vector<int64_t> origin_dims_new; | ||||
| auto trans_ret = formats::TransShape(format, shape_new.GetDims(), | auto trans_ret = formats::TransShape(format, shape_new.GetDims(), | ||||
| @@ -122,7 +124,8 @@ Status AicpuNodeTaskBase::UpdateShapeToOutputDesc(const GeShape &shape_new, | |||||
| "Node[%s] out[%d] originFormat[%d] is not same as format[%d], but TransShape failed, shape=%s.", | "Node[%s] out[%d] originFormat[%d] is not same as format[%d], but TransShape failed, shape=%s.", | ||||
| node_name_.c_str(), output_index, origin_format, format, shape_new.ToString().c_str()); | node_name_.c_str(), output_index, origin_format, format, shape_new.ToString().c_str()); | ||||
| auto origin_shape_new = GeShape(origin_dims_new); | auto origin_shape_new = GeShape(origin_dims_new); | ||||
| output_desc->SetOriginShape(origin_shape_new); | |||||
| GE_CHK_STATUS_RET(task_context.GetNodeState()->UpdateOutputShapes(output_index, shape_new, origin_shape_new), | |||||
| "Node[%s] failed to update update shape, index = %d", node_name_.c_str(), output_index); | |||||
| GELOGD("Node[%s] out[%d] originFormat[%d] is not same as format[%d], need update from %s ro %s.", | GELOGD("Node[%s] out[%d] originFormat[%d] is not same as format[%d], need update from %s ro %s.", | ||||
| node_name_.c_str(), output_index, origin_format, format, | node_name_.c_str(), output_index, origin_format, format, | ||||
| origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str()); | origin_shape_old.ToString().c_str(), origin_shape_new.ToString().c_str()); | ||||
| @@ -136,7 +139,6 @@ Status AicpuNodeTaskBase::UpdateExtInfo() { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(false), "UpdateExecuteMode failed."); | |||||
| for (auto i = 0; i < node_item_->num_inputs; ++i) { | for (auto i = 0; i < node_item_->num_inputs; ++i) { | ||||
| auto input_desc = node_item_->MutableInputDesc(i); | auto input_desc = node_item_->MutableInputDesc(i); | ||||
| GE_CHECK_NOTNULL(input_desc); | GE_CHECK_NOTNULL(input_desc); | ||||
| @@ -176,10 +178,14 @@ Status AicpuNodeTaskBase::UpdateArgs(TaskContext &context) { | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(UpdateIoAddr(context), "Node[%s] update io addr failed.", node_name_.c_str()); | GE_CHK_STATUS_RET(UpdateIoAddr(context), "Node[%s] update io addr failed.", node_name_.c_str()); | ||||
| if (node_item_->is_dynamic) { | |||||
| // dynamic node need update ext info. | |||||
| bool all_shape = false; | |||||
| const OpDescPtr op_desc = node_item_->GetOpDesc(); | |||||
| (void)AttrUtils::GetBool(op_desc, kAicpuAllshape, all_shape); | |||||
| if (node_item_->is_dynamic || all_shape) { | |||||
| // dynamic node and all_shape kernel need update ext info. | |||||
| GE_CHK_STATUS_RET(UpdateExtInfo(), "Node[%s] update ext info failed.", node_name_.c_str()); | GE_CHK_STATUS_RET(UpdateExtInfo(), "Node[%s] update ext info failed.", node_name_.c_str()); | ||||
| } | } | ||||
| GELOGD("Node[%s] update args end.", node_name_.c_str()); | GELOGD("Node[%s] update args end.", node_name_.c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -513,7 +519,6 @@ Status AicpuTfNodeTask::UpdateShapeByHbmBuffer(TaskContext &context, | |||||
| node_name_.c_str(), node_item_->num_outputs, out_shape_hbm.size()); | node_name_.c_str(), node_item_->num_outputs, out_shape_hbm.size()); | ||||
| for (auto i = 0; i < node_item_->num_outputs; ++i) { | for (auto i = 0; i < node_item_->num_outputs; ++i) { | ||||
| const auto &result_summary = output_summary_host_[i]; | const auto &result_summary = output_summary_host_[i]; | ||||
| auto output_desc = node_item_->MutableOutputDesc(i); | |||||
| std::vector<int64_t> shape_dims; | std::vector<int64_t> shape_dims; | ||||
| if (result_summary.shape_data_size > 0) { | if (result_summary.shape_data_size > 0) { | ||||
| const auto &shape_hbm = out_shape_hbm[i]; | const auto &shape_hbm = out_shape_hbm[i]; | ||||
| @@ -531,7 +536,7 @@ Status AicpuTfNodeTask::UpdateShapeByHbmBuffer(TaskContext &context, | |||||
| GELOGD("Node[%s] [%d]th output dim[%u]=%ld.", node_name_.c_str(), i, dim_idx, shape_addr[dim_idx]); | GELOGD("Node[%s] [%d]th output dim[%u]=%ld.", node_name_.c_str(), i, dim_idx, shape_addr[dim_idx]); | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(GeShape(shape_dims), i, output_desc), | |||||
| GE_CHK_STATUS_RET(UpdateShapeToOutputDesc(context, GeShape(shape_dims), i), | |||||
| "Node[%s] update [%d]th output shape failed.", | "Node[%s] update [%d]th output shape failed.", | ||||
| node_name_.c_str(), i); | node_name_.c_str(), i); | ||||
| } | } | ||||
| @@ -634,7 +639,7 @@ Status AicpuTfNodeTask::TaskCallback(TaskContext &context) { | |||||
| // check need update shape, call update shape. | // check need update shape, call update shape. | ||||
| if (unknown_type_ == DEPEND_SHAPE_RANGE) { | if (unknown_type_ == DEPEND_SHAPE_RANGE) { | ||||
| // check result | // check result | ||||
| callback_ret = UpdateOutputShapeFromExtInfo(); | |||||
| callback_ret = UpdateOutputShapeFromExtInfo(context); | |||||
| } else if (unknown_type_ == DEPEND_COMPUTE) { | } else if (unknown_type_ == DEPEND_COMPUTE) { | ||||
| callback_ret = UpdateShapeAndDataByResultSummary(context); | callback_ret = UpdateShapeAndDataByResultSummary(context); | ||||
| } | } | ||||
| @@ -781,7 +786,7 @@ Status AicpuNodeTask::TaskCallback(TaskContext &context) { | |||||
| // check need update shape, call update shape. | // check need update shape, call update shape. | ||||
| if (node_item_->is_dynamic && unknown_type_ == DEPEND_SHAPE_RANGE) { | if (node_item_->is_dynamic && unknown_type_ == DEPEND_SHAPE_RANGE) { | ||||
| // check result | // check result | ||||
| callback_ret = UpdateOutputShapeFromExtInfo(); | |||||
| callback_ret = UpdateOutputShapeFromExtInfo(context); | |||||
| } else { | } else { | ||||
| GELOGD("Node[%s] unknown shape type is %d no need update output shape.", | GELOGD("Node[%s] unknown shape type is %d no need update output shape.", | ||||
| node_name_.c_str(), unknown_type_); | node_name_.c_str(), unknown_type_); | ||||
| @@ -49,9 +49,9 @@ class AicpuNodeTaskBase : public NodeTask { | |||||
| virtual Status UpdateExtInfo(); | virtual Status UpdateExtInfo(); | ||||
| virtual Status UpdateOutputShapeFromExtInfo(); | |||||
| virtual Status UpdateOutputShapeFromExtInfo(TaskContext &task_context); | |||||
| Status UpdateShapeToOutputDesc(const GeShape &shape_new, int32_t output_index, GeTensorDescPtr &output_desc); | |||||
| Status UpdateShapeToOutputDesc(TaskContext &task_context, const GeShape &shape_new, int32_t output_index); | |||||
| virtual Status LaunchTask(TaskContext &context) = 0; | virtual Status LaunchTask(TaskContext &context) = 0; | ||||
| @@ -36,7 +36,7 @@ const std::map<std::string, std::vector<uint32_t>> | |||||
| {BROADCASTGRADIENTARGS, {}} | {BROADCASTGRADIENTARGS, {}} | ||||
| }; | }; | ||||
| const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; | |||||
| const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP}; | |||||
| Status RefInputTask::UpdateArgs(TaskContext &) { | Status RefInputTask::UpdateArgs(TaskContext &) { | ||||
| // no need update args | // no need update args | ||||
| @@ -22,6 +22,8 @@ | |||||
| #include "graph/manager/util/hcom_util.h" | #include "graph/manager/util/hcom_util.h" | ||||
| #include "graph/runtime_inference_context.h" | #include "graph/runtime_inference_context.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/types.h" | |||||
| #include "hccl/hcom.h" | |||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -96,13 +98,13 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed"); | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclRootId(op_desc, root_id), "GetHcclRootId failed"); | ||||
| } | } | ||||
| op_info.root = root_id; | op_info.root = root_id; | ||||
| auto callback = [this, op_desc](HcclResult status) { | |||||
| auto callback = [op_desc, done_callback](HcclResult status) { | |||||
| if (status != HCCL_SUCCESS) { | if (status != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", | GELOGE(HCCL_E_INTERNAL, "node %s call HcomExecEnqueueOperation failed, ret: 0x%X", | ||||
| op_desc->GetName().c_str(), status); | op_desc->GetName().c_str(), status); | ||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(this->hccl_mutex_); | |||||
| this->cond_.notify_all(); | |||||
| done_callback(); | |||||
| GELOGI("node %s hccl callback success.", op_desc->GetName().c_str()); | GELOGI("node %s hccl callback success.", op_desc->GetName().c_str()); | ||||
| }; | }; | ||||
| int32_t count = 0; | int32_t count = 0; | ||||
| @@ -119,11 +121,6 @@ Status HcclNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| return HCCL_E_INTERNAL; | return HCCL_E_INTERNAL; | ||||
| } | } | ||||
| // pending until hccl finished | |||||
| std::unique_lock<std::mutex> ulock(hccl_mutex_); | |||||
| cond_.wait(ulock); | |||||
| GE_CHK_STATUS_RET_NOLOG(context.RegisterCallback(done_callback)); | |||||
| GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName()); | GELOGI("[%s] HcclNodeTask::ExecuteAsync success.", context.GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -165,7 +162,8 @@ Status RdmaNodeTask::Init(TaskContext &context) { | |||||
| Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccessAddrInfo> &addr_infos) { | ||||
| RuntimeInferenceContext *ctx = nullptr; | RuntimeInferenceContext *ctx = nullptr; | ||||
| GE_CHK_STATUS_RET(RuntimeInferenceContext::GetContext(std::to_string(context.GetSessionId()), &ctx)); | |||||
| GE_CHK_STATUS_RET( | |||||
| RuntimeInferenceContext::GetContext(std::to_string(context.GetExecutionContext()->context_id), &ctx)); | |||||
| ge::Tensor remote_tensor; | ge::Tensor remote_tensor; | ||||
| GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | GE_CHK_STATUS_RET(ctx->GetTensor(remote_index_.first, remote_index_.second, remote_tensor)); | ||||
| @@ -224,7 +222,7 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
| Tensor offset_tensor; | Tensor offset_tensor; | ||||
| GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor)) | GE_CHK_STATUS_RET(ctx->GetTensor(offset_index_.first, offset_index_.second, offset_tensor)) | ||||
| if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) { | if (static_cast<int64_t>(offset_tensor.GetSize() / GetSizeByDataType(data_type)) != row_num) { | ||||
| GELOGE(PARAM_INVALID, "num of offset and remote addr mismatch, offset size=%zu, remote_addr size=%lld, dtype=%s", | |||||
| GELOGE(PARAM_INVALID, "num of offset and remote addr mismatch, offset size=%zu, remote_addr size=%ld, dtype=%s", | |||||
| offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str()); | offset_tensor.GetSize(), row_num, TypeUtils::DataTypeToSerialString(data_type).c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -246,7 +244,7 @@ Status RdmaNodeTask::ExtractTensor(TaskContext &context, vector<HcomRemoteAccess | |||||
| auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); | auto local_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(tv->MutableData())); | ||||
| auto device_len = tv->GetSize() / row_num; | auto device_len = tv->GetSize() / row_num; | ||||
| if (device_len <= 0 || device_len > data[kVarTableIdxLen]) { | if (device_len <= 0 || device_len > data[kVarTableIdxLen]) { | ||||
| GELOGE(FAILED, "Local embedding length is out of range, expect %lld, but %lld exactly.", | |||||
| GELOGE(FAILED, "Local embedding length is out of range, expect %ld, but %ld exactly.", | |||||
| data[kVarTableIdxLen], device_len); | data[kVarTableIdxLen], device_len); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -282,12 +280,13 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto callback = [this](HcclResult status) { | |||||
| TaskContext *p_ctx = &context; | |||||
| auto callback = [p_ctx, done_callback](HcclResult status) { | |||||
| if (status != HCCL_SUCCESS) { | if (status != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", status); | |||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", status); | |||||
| p_ctx->SetStatus(FAILED); | |||||
| } | } | ||||
| std::lock_guard<std::mutex> lock(this->hccl_mutex_); | |||||
| this->cond_.notify_all(); | |||||
| done_callback(); | |||||
| GELOGI("rdma callback success."); | GELOGI("rdma callback success."); | ||||
| }; | }; | ||||
| @@ -297,15 +296,10 @@ Status RdmaNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do | |||||
| } | } | ||||
| HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | HcclResult hccl_ret = HcomExecEnqueueRemoteAccess(context.GetNodeItem().NodeType(), addr_infos, callback); | ||||
| if (hccl_ret != HCCL_SUCCESS) { | if (hccl_ret != HCCL_SUCCESS) { | ||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExecInitialize failed, ret: 0x%X", hccl_ret); | |||||
| GELOGE(HCCL_E_INTERNAL, "Call HcomExcutorInitialize failed, ret: 0x%X", hccl_ret); | |||||
| return HCCL_E_INTERNAL; | return HCCL_E_INTERNAL; | ||||
| } | } | ||||
| // pending until hccl finished | |||||
| std::unique_lock<std::mutex> ulock(hccl_mutex_); | |||||
| cond_.wait(ulock); | |||||
| (void)context.RegisterCallback(done_callback); | |||||
| GELOGI("[%s] RdmaNodeTask::ExecuteAsync success.", context.GetNodeName()); | GELOGI("[%s] RdmaNodeTask::ExecuteAsync success.", context.GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -17,6 +17,7 @@ | |||||
| #include "rts_node_executor.h" | #include "rts_node_executor.h" | ||||
| #include "common/debug/log.h" | #include "common/debug/log.h" | ||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "common/types.h" | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "hybrid/model/hybrid_model.h" | #include "hybrid/model/hybrid_model.h" | ||||
| #include "runtime/rt.h" | #include "runtime/rt.h" | ||||
| @@ -50,6 +51,20 @@ Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status ReadVariableOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | |||||
| GELOGD("[%s] Start to execute.", context.GetNodeName()); | |||||
| for (int i = 0; i < context.NumInputs(); ++i) { | |||||
| GE_CHK_STATUS_RET(DoCopyTensor(context, i)); | |||||
| } | |||||
| if (done_callback) { | |||||
| GE_CHK_STATUS_RET(context.RegisterCallback(done_callback)); | |||||
| } | |||||
| GELOGD("[%s] Done executing successfully.", context.GetNodeName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { | ||||
| GELOGD("[%s] Start to execute.", context.GetNodeName()); | GELOGD("[%s] Start to execute.", context.GetNodeName()); | ||||
| GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); | GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); | ||||
| @@ -111,6 +126,8 @@ Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, | |||||
| task = MakeShared<IdentityNodeTask>(); | task = MakeShared<IdentityNodeTask>(); | ||||
| } else if (op_type == IDENTITYN) { | } else if (op_type == IDENTITYN) { | ||||
| task = MakeShared<IdentityNNodeTask>(); | task = MakeShared<IdentityNNodeTask>(); | ||||
| } else if (op_type == READVARIABLEOP) { | |||||
| task = MakeShared<ReadVariableOpNodeTask>(); | |||||
| } else if (op_type == PROFILINGTRAININGTRACE) { | } else if (op_type == PROFILINGTRAININGTRACE) { | ||||
| auto *task_defs = model.GetTaskDefs(node); | auto *task_defs = model.GetTaskDefs(node); | ||||
| if (task_defs == nullptr || task_defs->empty()) { | if (task_defs == nullptr || task_defs->empty()) { | ||||
| @@ -36,6 +36,11 @@ class IdentityNNodeTask : public IdentityNodeTask { | |||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | ||||
| }; | }; | ||||
| class ReadVariableOpNodeTask : public IdentityNodeTask { | |||||
| public: | |||||
| Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; | |||||
| }; | |||||
| class ProfilingTraceNodeTask : public NodeTask { | class ProfilingTraceNodeTask : public NodeTask { | ||||
| public: | public: | ||||
| explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {} | explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {} | ||||
| @@ -27,10 +27,12 @@ | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| TaskContext::TaskContext(GraphExecutionContext *execution_context, | TaskContext::TaskContext(GraphExecutionContext *execution_context, | ||||
| const NodeItem *node_item, | |||||
| NodeState *node_state, | |||||
| SubgraphContext *subgraph_context) | SubgraphContext *subgraph_context) | ||||
| : node_item_(node_item), execution_context_(execution_context), subgraph_context_(subgraph_context) { | |||||
| } | |||||
| : node_state_(node_state), | |||||
| node_item_(node_state->GetNodeItem()), | |||||
| execution_context_(execution_context), | |||||
| subgraph_context_(subgraph_context) {} | |||||
| TaskContext::~TaskContext() { | TaskContext::~TaskContext() { | ||||
| GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | GELOGD("[%s] TaskContext destroyed.", node_item_->NodeName().c_str()); | ||||
| @@ -47,9 +49,10 @@ TaskContext::~TaskContext() { | |||||
| } | } | ||||
| } | } | ||||
| std::unique_ptr<TaskContext> TaskContext::Create(const NodeItem &node_item, | |||||
| std::unique_ptr<TaskContext> TaskContext::Create(NodeState *node_state, | |||||
| GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
| SubgraphContext *subgraph_context) { | SubgraphContext *subgraph_context) { | ||||
| const NodeItem &node_item = *node_state->GetNodeItem(); | |||||
| GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | GELOGI("[%s] To create task context, input start = %d, num_inputs = %d, output start = %d, num_outputs = %d.", | ||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| node_item.input_start, | node_item.input_start, | ||||
| @@ -65,7 +68,7 @@ std::unique_ptr<TaskContext> TaskContext::Create(const NodeItem &node_item, | |||||
| } | } | ||||
| auto task_context = std::unique_ptr<TaskContext>( | auto task_context = std::unique_ptr<TaskContext>( | ||||
| new(std::nothrow)TaskContext(execution_context, &node_item, subgraph_context)); | |||||
| new(std::nothrow)TaskContext(execution_context, node_state, subgraph_context)); | |||||
| if (task_context == nullptr) { | if (task_context == nullptr) { | ||||
| GELOGE(MEMALLOC_FAILED, "[%s] Failed to create instance of TaskContext.", node_item.NodeName().c_str()); | GELOGE(MEMALLOC_FAILED, "[%s] Failed to create instance of TaskContext.", node_item.NodeName().c_str()); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -154,7 +157,7 @@ Status TaskContext::RegisterCallback(const std::function<void()> &callback_fun) | |||||
| GELOGW("[%s] Callback is NULL", GetNodeName()); | GELOGW("[%s] Callback is NULL", GetNodeName()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| auto ret = execution_context_->callback_manager->RegisterCallback(callback_fun); | |||||
| auto ret = execution_context_->callback_manager->RegisterCallback(GetStream(), callback_fun); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | GELOGE(ret, "[%s] Failed to register callback", GetNodeName()); | ||||
| execution_context_->callback_manager->Destroy(); | execution_context_->callback_manager->Destroy(); | ||||
| @@ -309,7 +312,7 @@ Status TaskContext::SetOutput(int index, const TensorValue &tensor) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| rtStream_t TaskContext::GetStream() { | |||||
| rtStream_t TaskContext::GetStream() const { | |||||
| return execution_context_->stream; | return execution_context_->stream; | ||||
| } | } | ||||
| @@ -536,6 +539,10 @@ Status TaskContext::SaveProfilingTaskDescInfo(uint32_t task_id, uint32_t stream | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| NodeState *TaskContext::GetNodeState() const { | |||||
| return node_state_; | |||||
| } | |||||
| Status TaskContext::SaveProfilingGraphDescInfo(uint32_t task_id, uint32_t stream_id) { | Status TaskContext::SaveProfilingGraphDescInfo(uint32_t task_id, uint32_t stream_id) { | ||||
| if (ProfilingManager::Instance().ProfilingModelExecuteOn()) { | if (ProfilingManager::Instance().ProfilingModelExecuteOn()) { | ||||
| const NodeItem &node_item = GetNodeItem(); | const NodeItem &node_item = GetNodeItem(); | ||||
| @@ -25,6 +25,7 @@ | |||||
| #include "framework/common/ge_types.h" | #include "framework/common/ge_types.h" | ||||
| #include "hybrid/common/tensor_value.h" | #include "hybrid/common/tensor_value.h" | ||||
| #include "hybrid/common/npu_memory_allocator.h" | #include "hybrid/common/npu_memory_allocator.h" | ||||
| #include "hybrid/executor/node_state.h" | |||||
| #include "hybrid/executor/rt_callback_manager.h" | #include "hybrid/executor/rt_callback_manager.h" | ||||
| #include "hybrid/model/node_item.h" | #include "hybrid/model/node_item.h" | ||||
| @@ -35,7 +36,7 @@ class SubgraphContext; | |||||
| class TaskContext { | class TaskContext { | ||||
| public: | public: | ||||
| static std::unique_ptr<TaskContext> Create(const NodeItem &node_item, | |||||
| static std::unique_ptr<TaskContext> Create(NodeState *node_state, | |||||
| GraphExecutionContext *execution_context, | GraphExecutionContext *execution_context, | ||||
| SubgraphContext *subgraph_context); | SubgraphContext *subgraph_context); | ||||
| @@ -45,6 +46,7 @@ class TaskContext { | |||||
| int NumOutputs() const; | int NumOutputs() const; | ||||
| size_t NumWorkspaces() const; | size_t NumWorkspaces() const; | ||||
| const NodeItem &GetNodeItem() const; | const NodeItem &GetNodeItem() const; | ||||
| NodeState *GetNodeState() const; | |||||
| const char *GetNodeName() const; | const char *GetNodeName() const; | ||||
| TensorValue *MutableInput(int index); | TensorValue *MutableInput(int index); | ||||
| ConstGeTensorDescPtr GetInputDesc(int index) const; | ConstGeTensorDescPtr GetInputDesc(int index) const; | ||||
| @@ -58,7 +60,7 @@ class TaskContext { | |||||
| const TensorValue *GetOutput(int index) const; | const TensorValue *GetOutput(int index) const; | ||||
| TensorValue *MutableOutput(int index); | TensorValue *MutableOutput(int index); | ||||
| TensorValue *GetVariable(const std::string &name); | TensorValue *GetVariable(const std::string &name); | ||||
| rtStream_t GetStream(); | |||||
| rtStream_t GetStream() const; | |||||
| int64_t GetSessionId() const; | int64_t GetSessionId() const; | ||||
| uint64_t GetIterationNumber() const; | uint64_t GetIterationNumber() const; | ||||
| @@ -119,12 +121,13 @@ class TaskContext { | |||||
| private: | private: | ||||
| TaskContext(GraphExecutionContext *execution_context, | TaskContext(GraphExecutionContext *execution_context, | ||||
| const NodeItem *node_item, | |||||
| NodeState *node_state, | |||||
| SubgraphContext *subgraph_context); | SubgraphContext *subgraph_context); | ||||
| static string TensorDesc2String(const GeTensorDesc &desc); | static string TensorDesc2String(const GeTensorDesc &desc); | ||||
| Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr); | Status AllocateTensor(const GeTensorDesc &tensor_desc, TensorValue &tensor, AllocationAttr *attr); | ||||
| NodeState *node_state_ = nullptr; | |||||
| const NodeItem *node_item_ = nullptr; | const NodeItem *node_item_ = nullptr; | ||||
| bool force_infer_shape_ = false; | bool force_infer_shape_ = false; | ||||
| GraphExecutionContext *execution_context_; | GraphExecutionContext *execution_context_; | ||||
| @@ -44,6 +44,7 @@ | |||||
| #include "omm/csa_interact.h" | #include "omm/csa_interact.h" | ||||
| #include "runtime/kernel.h" | #include "runtime/kernel.h" | ||||
| #include "opskernel_manager/ops_kernel_builder_manager.h" | #include "opskernel_manager/ops_kernel_builder_manager.h" | ||||
| #include "external/runtime/rt_error_codes.h" | |||||
| using Json = nlohmann::json; | using Json = nlohmann::json; | ||||
| @@ -76,6 +77,13 @@ Status GELib::Initialize(const map<string, string> &options) { | |||||
| GELOGE(ret, "GeLib initial failed."); | GELOGE(ret, "GeLib initial failed."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| ret = instancePtr_->SetAiCoreNum(new_options); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "GeLib initial: SetAiCoreNum failed."); | |||||
| return ret; | |||||
| } | |||||
| instancePtr_->SetDefaultPrecisionMode(new_options); | instancePtr_->SetDefaultPrecisionMode(new_options); | ||||
| if (new_options.find("ge.fpCeilingMode") == new_options.end()) { | if (new_options.find("ge.fpCeilingMode") == new_options.end()) { | ||||
| @@ -251,6 +259,24 @@ Status GELib::SetRTSocVersion(const map<string, string> &options, map<string, st | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GELib::SetAiCoreNum(map<string, string> &options) { | |||||
| // Already set or get AICORE_NUM from options in offline mode | |||||
| if (options.find(AICORE_NUM) != options.end()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| uint32_t aicore_num = 0; | |||||
| rtError_t ret = rtGetAiCoreCount(&aicore_num); | |||||
| if (ret == ACL_ERROR_RT_FEATURE_NOT_SUPPORT) { // offline without ATC Input of AiCoreNum | |||||
| return SUCCESS; | |||||
| } else if (ret == RT_ERROR_NONE) { // online-mode | |||||
| options.emplace(std::make_pair(AICORE_NUM, std::to_string(aicore_num))); | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGE(FAILED, "rtGetAiCoreCount failed."); | |||||
| return FAILED; | |||||
| } | |||||
| void GELib::InitOptions(const map<string, string> &options) { | void GELib::InitOptions(const map<string, string> &options) { | ||||
| this->options_.session_id = 0; | this->options_.session_id = 0; | ||||
| auto iter = options.find(OPTION_EXEC_SESSION_ID); | auto iter = options.find(OPTION_EXEC_SESSION_ID); | ||||
| @@ -81,6 +81,7 @@ class GE_FUNC_VISIBILITY GELib { | |||||
| Status InnerInitialize(const map<string, string> &options); | Status InnerInitialize(const map<string, string> &options); | ||||
| Status SystemInitialize(const map<string, string> &options); | Status SystemInitialize(const map<string, string> &options); | ||||
| Status SetRTSocVersion(const map<string, string> &options, map<string, string> &new_options); | Status SetRTSocVersion(const map<string, string> &options, map<string, string> &new_options); | ||||
| Status SetAiCoreNum(map<string, string> &options); | |||||
| void SetDefaultPrecisionMode(map<string, string> &new_options); | void SetDefaultPrecisionMode(map<string, string> &new_options); | ||||
| void RollbackInit(); | void RollbackInit(); | ||||
| void InitOptions(const map<string, string> &options); | void InitOptions(const map<string, string> &options); | ||||
| @@ -13,14 +13,17 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #ifndef KEEP_DTYPE_OPTION_H_ | |||||
| #define KEEP_DTYPE_OPTION_H_ | |||||
| #ifndef ATTR_OPTIONS_H_ | |||||
| #define ATTR_OPTIONS_H_ | |||||
| #include <string> | #include <string> | ||||
| #include "graph/compute_graph.h" | #include "graph/compute_graph.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/ge_error_codes.h" | |||||
| namespace ge { | namespace ge { | ||||
| Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype); | |||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name); | |||||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
| } // namespace | } // namespace | ||||
| #endif // KEEP_DTYPE_OPTION_H_ | |||||
| #endif // ATTR_OPTIONS_H_ | |||||
| @@ -13,7 +13,7 @@ | |||||
| * See the License for the specific language governing permissions and | * See the License for the specific language governing permissions and | ||||
| * limitations under the License. | * limitations under the License. | ||||
| */ | */ | ||||
| #include "keep_dtype_option.h" | |||||
| #include "attr_options.h" | |||||
| #include <fstream> | #include <fstream> | ||||
| #include <iostream> | #include <iostream> | ||||
| #include <sstream> | #include <sstream> | ||||
| @@ -26,20 +26,6 @@ namespace ge { | |||||
| namespace { | namespace { | ||||
| const size_t kMaxOpsNum = 10; | const size_t kMaxOpsNum = 10; | ||||
| } // namespace | } // namespace | ||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
| std::vector<std::string> original_op_names; | |||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
| return false; | |||||
| } | |||||
| for (auto &origin_name : original_op_names) { | |||||
| if (origin_name == op_name) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | ||||
| std::stringstream err_msg; | std::stringstream err_msg; | ||||
| @@ -67,20 +53,20 @@ void KeepDtypeReportError(const std::vector<std::string> &invalid_list) { | |||||
| GELOGE(FAILED, "%s", err_msg.str().c_str()); | GELOGE(FAILED, "%s", err_msg.str().c_str()); | ||||
| } | } | ||||
| Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep_dtype) { | |||||
| graphStatus KeepDtypeFunc(ComputeGraphPtr &graph, const std::string &cfg_path) { | |||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| if (keep_dtype.empty()) { | |||||
| return SUCCESS; | |||||
| if (cfg_path.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | } | ||||
| std::string real_path = RealPath(keep_dtype.c_str()); | |||||
| std::string real_path = RealPath(cfg_path.c_str()); | |||||
| if (real_path.empty()) { | if (real_path.empty()) { | ||||
| GELOGE(PARAM_INVALID, "Can not get real path for %s.", keep_dtype.c_str()); | |||||
| return PARAM_INVALID; | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | } | ||||
| std::ifstream ifs(real_path); | std::ifstream ifs(real_path); | ||||
| if (!ifs.is_open()) { | if (!ifs.is_open()) { | ||||
| GELOGE(FAILED, "Open file %s failed", keep_dtype.c_str()); | |||||
| return FAILED; | |||||
| GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | } | ||||
| std::string op_name; | std::string op_name; | ||||
| @@ -108,9 +94,9 @@ Status DealKeepDtypeOption(const ComputeGraphPtr &graph, const std::string &keep | |||||
| if (!invalid_list.empty()) { | if (!invalid_list.empty()) { | ||||
| KeepDtypeReportError(invalid_list); | KeepDtypeReportError(invalid_list); | ||||
| return PARAM_INVALID; | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | } | ||||
| return SUCCESS; | |||||
| return GRAPH_SUCCESS; | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -0,0 +1,36 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "attr_options.h" | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| namespace ge { | |||||
| bool IsOriginalOpFind(OpDescPtr &op_desc, const std::string &op_name) { | |||||
| std::vector<std::string> original_op_names; | |||||
| if (!AttrUtils::GetListStr(op_desc, ATTR_NAME_DATA_DUMP_ORIGIN_OP_NAMES, original_op_names)) { | |||||
| return false; | |||||
| } | |||||
| for (auto &origin_name : original_op_names) { | |||||
| if (origin_name == op_name) { | |||||
| return true; | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,64 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "attr_options.h" | |||||
| #include <fstream> | |||||
| #include <iostream> | |||||
| #include <sstream> | |||||
| #include <vector> | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "framework/common/util.h" | |||||
| #include "common/util/error_manager/error_manager.h" | |||||
| namespace ge { | |||||
| graphStatus WeightCompressFunc(ComputeGraphPtr &graph, const string &cfg_path) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| if (cfg_path.empty()) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| std::string real_path = RealPath(cfg_path.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Can not get real path for %s.", cfg_path.c_str()); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| std::ifstream ifs(real_path); | |||||
| if (!ifs.is_open()) { | |||||
| GELOGE(GRAPH_FAILED, "Open file %s failed", cfg_path.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string compress_nodes; | |||||
| ifs >> compress_nodes; | |||||
| ifs.close(); | |||||
| GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
| vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
| for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
| for (auto &node_ptr : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node_ptr); | |||||
| auto op_desc = node_ptr->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if ((op_desc->GetName() == compress_node_vec[i]) || IsOriginalOpFind(op_desc, compress_node_vec[i])) { | |||||
| if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
| GELOGE(GRAPH_FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -39,6 +39,7 @@ | |||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| #include "graph/passes/net_output_pass.h" | #include "graph/passes/net_output_pass.h" | ||||
| #include "graph/passes/data_pass.h" | #include "graph/passes/data_pass.h" | ||||
| #include "ir_build/attr_options/attr_options.h" | |||||
| using std::string; | using std::string; | ||||
| using namespace std; | using namespace std; | ||||
| @@ -52,8 +53,28 @@ const std::string IR_OPTION_LOG_LEVEL_DEFAULT = "default"; | |||||
| const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | const std::string IR_OPTION_BUFFER_OPTIMIZE_DEFAULT = "l2_optimize"; | ||||
| const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | const std::string IR_OPTION_DISABLE_REUSE_MEMORY_DEFAULT = "0"; | ||||
| const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | const std::string IR_OPTION_ENABLE_COMPRESS_WEIGHT_DEFAULT = "false"; | ||||
| const std::string KEEP_DTYPE_OPTION = "keep_dtype"; | |||||
| const std::string kInputShape = "input_shape"; | const std::string kInputShape = "input_shape"; | ||||
| const std::string kInputFormat = "input_format"; | const std::string kInputFormat = "input_format"; | ||||
| /** | |||||
| * @name SetOpAttrFun | |||||
| * @brief set attribute for operators in the configuration file | |||||
| * @param graph [IN/OUT] compute graph | |||||
| * @param cfg_path [IN] the config file path | |||||
| * @return graphStatus | |||||
| */ | |||||
| typedef graphStatus (*SetOpAttrFun)(ComputeGraphPtr &graph, const std::string &cfg_path); | |||||
| const std::map<aclgrphAttrType, SetOpAttrFun> kAttrTypeFuncMap = { | |||||
| {ATTR_TYPE_KEEP_DTYPE, KeepDtypeFunc}, | |||||
| {ATTR_TYPE_WEIGHT_COMPRESS, WeightCompressFunc} | |||||
| }; | |||||
| const std::map<aclgrphAttrType, std::string> kAttrTypeToStringMap = { | |||||
| {ATTR_TYPE_KEEP_DTYPE, KEEP_DTYPE_OPTION}, | |||||
| {ATTR_TYPE_WEIGHT_COMPRESS, ge::ir_option::COMPRESS_WEIGHT_CONF} | |||||
| }; | |||||
| } // namespace | } // namespace | ||||
| static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global_options) { | ||||
| @@ -703,4 +724,33 @@ graphStatus aclgrphGenerateForOp(const AscendString &op_type, const vector<Tenso | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| static std::string AttrTypeToSerialString(aclgrphAttrType attr_type) { | |||||
| auto it = kAttrTypeToStringMap.find(attr_type); | |||||
| if (it != kAttrTypeToStringMap.end()) { | |||||
| return it->second; | |||||
| } else { | |||||
| ErrorManager::GetInstance().ATCReportErrMessage("E19012", {"function", "reason"}, | |||||
| {"AttrTypeToSerialString", "attr_type[" + std::to_string(attr_type) + "] is not support"}); | |||||
| GELOGE(GRAPH_FAILED, "AttrTypeToSerialString: attr_type not support %u", attr_type); | |||||
| return "UNDEFINED"; | |||||
| } | |||||
| } | |||||
| graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path) { | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| if (cfg_path == nullptr) { | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| auto iter = kAttrTypeFuncMap.find(attr_type); | |||||
| if (iter == kAttrTypeFuncMap.end()) { | |||||
| GELOGE(GRAPH_FAILED, "attr type: %s is not support", AttrTypeToSerialString(attr_type).c_str()); | |||||
| return GRAPH_FAILED; | |||||
| } | |||||
| std::string path = cfg_path; | |||||
| return iter->second(compute_graph, path); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -10,7 +10,6 @@ protobuf_generate(ge PROTO_SRCS PROTO_HDRS ${PROTO_LIST}) | |||||
| set(SRC_LIST | set(SRC_LIST | ||||
| "main.cc" | "main.cc" | ||||
| "single_op_parser.cc" | "single_op_parser.cc" | ||||
| "keep_dtype_option.cc" | |||||
| "../session/omg.cc" | "../session/omg.cc" | ||||
| "../ir_build/atc_ir_common.cc" | "../ir_build/atc_ir_common.cc" | ||||
| ) | ) | ||||
| @@ -43,7 +43,7 @@ | |||||
| #include "parser/common/register_tbe.h" | #include "parser/common/register_tbe.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "single_op_parser.h" | #include "single_op_parser.h" | ||||
| #include "keep_dtype_option.h" | |||||
| #include "external/ge/ge_ir_build.h" | |||||
| using domi::BuildMode; | using domi::BuildMode; | ||||
| using domi::OpRegistrationData; | using domi::OpRegistrationData; | ||||
| @@ -913,6 +913,22 @@ static Status ConvertModelToJson(int fwk_type, const string &model_file, const s | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| static Status SetAttrOptions(ge::Graph &graph) { | |||||
| if (!FLAGS_keep_dtype.empty()) { | |||||
| if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_KEEP_DTYPE, FLAGS_keep_dtype.c_str()) != ge::GRAPH_SUCCESS) { | |||||
| return ge::FAILED; | |||||
| } | |||||
| } | |||||
| if (!FLAGS_compress_weight_conf.empty()) { | |||||
| if (ge::aclgrphSetOpAttr(graph, ge::ATTR_TYPE_WEIGHT_COMPRESS, FLAGS_compress_weight_conf.c_str()) | |||||
| != ge::GRAPH_SUCCESS) { | |||||
| return ge::FAILED; | |||||
| } | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| domi::Status GenerateModel(std::map<string, string> &options, std::string output) { | domi::Status GenerateModel(std::map<string, string> &options, std::string output) { | ||||
| ge::GeGenerator ge_generator; | ge::GeGenerator ge_generator; | ||||
| ge::Status geRet = ge::SUCCESS; | ge::Status geRet = ge::SUCCESS; | ||||
| @@ -969,7 +985,6 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
| atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes)); | atc_params.insert(std::pair<string, string>("input_fp16_nodes", FLAGS_input_fp16_nodes)); | ||||
| atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); | atc_params.insert(std::pair<string, string>("is_input_adjust_hw_layout", FLAGS_is_input_adjust_hw_layout)); | ||||
| atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); | atc_params.insert(std::pair<string, string>("is_output_adjust_hw_layout", FLAGS_is_output_adjust_hw_layout)); | ||||
| atc_params.insert(std::pair<string, string>("compress_weight_conf", FLAGS_compress_weight_conf)); | |||||
| atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); | atc_params.insert(std::pair<string, string>(string(ge::OUTPUT_DATATYPE), FLAGS_output_type)); | ||||
| atc_params.insert(std::pair<string, string>("output", output)); | atc_params.insert(std::pair<string, string>("output", output)); | ||||
| @@ -1003,11 +1018,10 @@ domi::Status GenerateModel(std::map<string, string> &options, std::string output | |||||
| } | } | ||||
| } | } | ||||
| Status ret = ge::DealKeepDtypeOption(ge::GraphUtils::GetComputeGraph(graph), FLAGS_keep_dtype); | |||||
| if (ret != SUCCESS) { | |||||
| if (SetAttrOptions(graph) != ge::SUCCESS) { | |||||
| (void)ge_generator.Finalize(); | (void)ge_generator.Finalize(); | ||||
| (void)ge::GELib::GetInstance()->Finalize(); | (void)ge::GELib::GetInstance()->Finalize(); | ||||
| return ret; | |||||
| return domi::FAILED; | |||||
| } | } | ||||
| geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | geRet = ge_generator.GenerateOfflineModel(graph, output, inputs); | ||||
| @@ -1347,10 +1361,10 @@ bool CheckMemInfo() { | |||||
| } | } | ||||
| // only check current available mem when auto_tune_mode is set. | // only check current available mem when auto_tune_mode is set. | ||||
| long current_mem_available = GetMemInfo("MemAvailable"); | long current_mem_available = GetMemInfo("MemAvailable"); | ||||
| GELOGI("Get mem available [%lu].", current_mem_available); | |||||
| GELOGI("Get mem available [%lu kB].", current_mem_available); | |||||
| std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl; | std::cout << "Current available mem is " << current_mem_available << "kB." << std::endl; | ||||
| if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) { | if ((current_mem_available > 0) && (current_mem_available < kMinAvailableMem)) { | ||||
| GELOGE(ge::PARAM_INVALID, "Current available mem [%lu] can not be smaller than [%lu] .", | |||||
| GELOGE(ge::PARAM_INVALID, "Current available mem [%lu kB] can not be smaller than [%lu kB] .", | |||||
| current_mem_available, kMinAvailableMem); | current_mem_available, kMinAvailableMem); | ||||
| ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"}, | ErrorManager::GetInstance().ATCReportErrMessage("E10044", {"value", "min_value"}, | ||||
| {to_string(current_mem_available), to_string(kMinAvailableMem)}); | {to_string(current_mem_available), to_string(kMinAvailableMem)}); | ||||
| @@ -1406,7 +1420,7 @@ int main(int argc, char* argv[]) { | |||||
| if (result != 0) { | if (result != 0) { | ||||
| DOMI_LOGE("ErrorManager outputErrMessage fail !"); | DOMI_LOGE("ErrorManager outputErrMessage fail !"); | ||||
| } | } | ||||
| GELOGI("Current mem available mem is [%lu]", GetMemInfo("MemAvailable")); | |||||
| GELOGI("Current mem available mem is [%lu kB]", GetMemInfo("MemAvailable")); | |||||
| return ret; | return ret; | ||||
| } else { | } else { | ||||
| std::cout << "ATC run success, welcome to the next use." << std::endl; | std::cout << "ATC run success, welcome to the next use." << std::endl; | ||||
| @@ -10,7 +10,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| main.cc \ | main.cc \ | ||||
| keep_dtype_option.cc \ | |||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
| @@ -64,7 +63,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| main.cc \ | main.cc \ | ||||
| keep_dtype_option.cc \ | |||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
| @@ -118,7 +116,6 @@ LOCAL_CFLAGS += -DPROTOBUF_INLINE_NOT_IN_HEADERS=0 -DCOMPILE_OMG_PACKAGE -O2 -Dg | |||||
| LOCAL_SRC_FILES := \ | LOCAL_SRC_FILES := \ | ||||
| main.cc \ | main.cc \ | ||||
| keep_dtype_option.cc \ | |||||
| single_op_parser.cc \ | single_op_parser.cc \ | ||||
| ../session/omg.cc \ | ../session/omg.cc \ | ||||
| ../ir_build/atc_ir_common.cc \ | ../ir_build/atc_ir_common.cc \ | ||||
| @@ -193,44 +193,6 @@ static Status CheckInputFp16Nodes(const ComputeGraphPtr &graph, const string &in | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| static Status SetWeightCompressNodes(const ComputeGraphPtr &graph, const string &compress_weight_conf) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| if (compress_weight_conf.empty()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| std::string real_path = RealPath(compress_weight_conf.c_str()); | |||||
| if (real_path.empty()) { | |||||
| GELOGE(PARAM_INVALID, "Can not get real path for %s.", compress_weight_conf.c_str()); | |||||
| return PARAM_INVALID; | |||||
| } | |||||
| std::ifstream ifs(real_path); | |||||
| if (!ifs.is_open()) { | |||||
| GELOGE(domi::FAILED, "Open file %s failed", compress_weight_conf.c_str()); | |||||
| return domi::FAILED; | |||||
| } | |||||
| std::string compress_nodes; | |||||
| ifs >> compress_nodes; | |||||
| ifs.close(); | |||||
| GELOGI("Compress weight of nodes: %s", compress_nodes.c_str()); | |||||
| vector<string> compress_node_vec = StringUtils::Split(compress_nodes, ';'); | |||||
| for (size_t i = 0; i < compress_node_vec.size(); ++i) { | |||||
| ge::NodePtr node = graph->FindNode(compress_node_vec[i]); | |||||
| if (node == nullptr) { | |||||
| GELOGW("node %s is not in graph", compress_node_vec[i].c_str()); | |||||
| continue; | |||||
| } | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (!ge::AttrUtils::SetBool(op_desc, ge::ATTR_NAME_COMPRESS_WEIGHT, true)) { | |||||
| GELOGE(domi::FAILED, "node %s SetBool failed.", compress_node_vec[i].c_str()); | |||||
| return domi::FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | static Status ParseOutputFp16NodesFormat(const string &is_output_fp16) { | ||||
| if (is_output_fp16.empty()) { | if (is_output_fp16.empty()) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -800,10 +762,6 @@ FMK_FUNC_HOST_VISIBILITY Status ParseGraph(ge::Graph &graph, const std::map<stri | |||||
| GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); | GE_RETURN_IF_ERROR(CheckInputShapeNode(compute_graph, is_dynamic_input, run_mode)); | ||||
| std::string compress_weight_conf; | |||||
| ParseAtcParms(atc_params, "compress_weight_conf", compress_weight_conf); | |||||
| GE_RETURN_IF_ERROR(SetWeightCompressNodes(compute_graph, compress_weight_conf)); | |||||
| // Verify the contents of the op_name_map | // Verify the contents of the op_name_map | ||||
| if (op_conf != nullptr && *op_conf != '\0') { | if (op_conf != nullptr && *op_conf != '\0') { | ||||
| GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), | GE_RETURN_WITH_LOG_IF_ERROR(CheckOpNameMap(compute_graph, op_conf), | ||||
| @@ -43,20 +43,21 @@ using std::vector; | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const size_t kDataOutputNum = 1; | const size_t kDataOutputNum = 1; | ||||
| } // namespace | |||||
| static Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||||
| for (const auto &node : comp_graph->GetAllNodes()) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| const auto &depends = op_desc->GetOpInferDepends(); | |||||
| if (!depends.empty()) { | |||||
| flag = true; | |||||
| return SUCCESS; | |||||
| bool NeedHybridModel(GeModelPtr &ge_model) { | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | |||||
| int32_t kernel_task_num = 0; | |||||
| for (int i = 0; i < tasks.size(); ++i) { | |||||
| if (static_cast<rtModelTaskType_t>(tasks[i].type()) == RT_MODEL_TASK_KERNEL) { | |||||
| kernel_task_num++; | |||||
| if (kernel_task_num > 1) { | |||||
| return true; | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | |||||
| return false; | |||||
| } | } | ||||
| } // namespace | |||||
| SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) | SingleOpModel::SingleOpModel(const std::string &model_name, const void *model_data, uint32_t model_size) | ||||
| : model_name_(model_name), ori_model_data_(model_data), ori_model_size_(model_size) {} | : model_name_(model_name), ori_model_data_(model_data), ori_model_size_(model_size) {} | ||||
| @@ -497,9 +498,7 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & | |||||
| auto ge_model = model_helper_.GetGeModel(); | auto ge_model = model_helper_.GetGeModel(); | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| bool infer_depend_flag = false; | |||||
| GE_CHK_STATUS_RET_NOLOG(IfInferDepend(ge_model, infer_depend_flag)); | |||||
| if (ge_model->GetModelTaskDefPtr()->task_size() > 1 || infer_depend_flag) { | |||||
| if (NeedHybridModel(ge_model)) { | |||||
| GELOGD("Build single op HybridModel."); | GELOGD("Build single op HybridModel."); | ||||
| GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); | GE_CHK_STATUS_RET_NOLOG(hybrid::NodeExecutorManager::GetInstance().EnsureInitialized()); | ||||
| auto root_model = model_helper_.GetGeRootModel(); | auto root_model = model_helper_.GetGeRootModel(); | ||||
| @@ -1,4 +1,4 @@ | |||||
| #!/usr/bin/python3.7 | |||||
| #!/usr/bin/python3 | |||||
| # -*- coding: UTF-8 -*- | # -*- coding: UTF-8 -*- | ||||
| #------------------------------------------------------------------- | #------------------------------------------------------------------- | ||||
| # Purpose: | # Purpose: | ||||
| @@ -219,6 +219,9 @@ const std::string HCOM_PARALLEL = "ge.hcomParallel"; | |||||
| // configure whether to use dynamic batch size | // configure whether to use dynamic batch size | ||||
| const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | const char *const kDynamicBatchSize = "ge.dynamicBatchSize"; | ||||
| // configure threshold of fusion data size for communication op | |||||
| const std::string FUSION_TENSOR_SIZE = "ge.fusionTensorSize"; | |||||
| const std::string INPUT_SHAPE = "ge.inputShape"; | const std::string INPUT_SHAPE = "ge.inputShape"; | ||||
| const std::string DYNAMIC_NODE_TYPE = "ge.dynamicNodeType"; | const std::string DYNAMIC_NODE_TYPE = "ge.dynamicNodeType"; | ||||
| @@ -50,6 +50,8 @@ struct ModelBufferData { | |||||
| uint64_t length; | uint64_t length; | ||||
| }; | }; | ||||
| enum aclgrphAttrType { ATTR_TYPE_KEEP_DTYPE = 0, ATTR_TYPE_WEIGHT_COMPRESS }; | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| * @brief build model.Notice the model is stored in buffer | * @brief build model.Notice the model is stored in buffer | ||||
| @@ -80,13 +82,16 @@ GE_FUNC_VISIBILITY void aclgrphBuildFinalize(); | |||||
| * @retval GRAPH_SUCCESS The function is successfully executed. | * @retval GRAPH_SUCCESS The function is successfully executed. | ||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, const std::map<AscendString, AscendString> &, | |||||
| ModelBufferData &)) | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<std::string, std::string> &build_options, | |||||
| ModelBufferData &model); | |||||
| ATTRIBUTED_DEPRECATED(GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &, | |||||
| const std::map<AscendString, AscendString> &, | |||||
| ModelBufferData &)) | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, | |||||
| const std::map<std::string, std::string> &build_options, | |||||
| ModelBufferData &model); | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, const std::map<AscendString, AscendString> &build_options, | |||||
| ModelBufferData &model); | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphBuildModel(const ge::Graph &graph, | |||||
| const std::map<AscendString, AscendString> &build_options, | |||||
| ModelBufferData &model); | |||||
| /** | /** | ||||
| * @ingroup AscendCL | * @ingroup AscendCL | ||||
| @@ -138,7 +143,17 @@ GE_FUNC_VISIBILITY graphStatus aclgrphDumpGraph(const ge::Graph &graph, const ch | |||||
| * @retval OtherValues Failure | * @retval OtherValues Failure | ||||
| */ | */ | ||||
| GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs, | GE_FUNC_VISIBILITY graphStatus aclgrphGenerateForOp(const AscendString &op_type, const std::vector<TensorDesc> &inputs, | ||||
| const std::vector<TensorDesc> &outputs, Graph &graph); | |||||
| const std::vector<TensorDesc> &outputs, Graph &graph); | |||||
| /** | |||||
| * @name aclgrphSetOpAttr | |||||
| * @brief set attribute for operators in the configuration file | |||||
| * @param graph [IN/OUT] compute graph | |||||
| * @param attr_type [In] attribute type | |||||
| * @param cfg_path [IN] the config file path | |||||
| * @return graphStatus | |||||
| */ | |||||
| GE_FUNC_VISIBILITY graphStatus aclgrphSetOpAttr(Graph &graph, aclgrphAttrType attr_type, const char *cfg_path); | |||||
| }; // namespace ge | }; // namespace ge | ||||
| #endif // INC_EXTERNAL_GE_IR_BUILD_H_ | #endif // INC_EXTERNAL_GE_IR_BUILD_H_ | ||||
| @@ -1 +1 @@ | |||||
| Subproject commit 2596725889c19c60a03440ab9e4e313070326ec0 | |||||
| Subproject commit 40e2d5c974eda1d1f5716b18fc776dede7da4370 | |||||
| @@ -1 +1 @@ | |||||
| Subproject commit 6516132e2eaeea2bf51cc790d52c83709588f5d8 | |||||
| Subproject commit 3c534dc831eeedd13ad86d9c2b52879f345403e0 | |||||
| @@ -354,6 +354,11 @@ rtError_t rtGetSocVersion(char *version, const uint32_t maxLen) | |||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| } | } | ||||
| rtError_t rtGetAiCoreCount(uint32_t *aiCoreCnt) | |||||
| { | |||||
| return RT_ERROR_NONE; | |||||
| } | |||||
| rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) | rtError_t rtSetTaskFailCallback(rtTaskFailCallback callback) | ||||
| { | { | ||||
| return RT_ERROR_NONE; | return RT_ERROR_NONE; | ||||
| @@ -49,13 +49,13 @@ include_directories(${GE_CODE_DIR}/metadef) | |||||
| include_directories(${GE_CODE_DIR}/metadef/graph) | include_directories(${GE_CODE_DIR}/metadef/graph) | ||||
| include_directories(${GE_CODE_DIR}/inc/external) | include_directories(${GE_CODE_DIR}/inc/external) | ||||
| include_directories(${GE_CODE_DIR}/metadef/inc/external) | include_directories(${GE_CODE_DIR}/metadef/inc/external) | ||||
| include_directories(${GE_CODE_DIR}/parser) | |||||
| include_directories(${GE_CODE_DIR}/parser/parser) | |||||
| include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | include_directories(${GE_CODE_DIR}/metadef/inc/external/graph) | ||||
| include_directories(${GE_CODE_DIR}/metadef/inc/graph) | include_directories(${GE_CODE_DIR}/metadef/inc/graph) | ||||
| include_directories(${GE_CODE_DIR}/inc/framework) | include_directories(${GE_CODE_DIR}/inc/framework) | ||||
| include_directories(${GE_CODE_DIR}/metadef/inc/common) | include_directories(${GE_CODE_DIR}/metadef/inc/common) | ||||
| include_directories(${GE_CODE_DIR}/metadef/third_party) | include_directories(${GE_CODE_DIR}/metadef/third_party) | ||||
| include_directories(${GE_CODE_DIR}/parser) | |||||
| include_directories(${GE_CODE_DIR}/parser/parser) | |||||
| include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc) | ||||
| include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/cce) | ||||
| include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | include_directories(${GE_CODE_DIR}/third_party/fwkacllib/inc/ops) | ||||
| @@ -65,25 +65,9 @@ include_directories(${CMAKE_BINARY_DIR}) | |||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge) | include_directories(${CMAKE_BINARY_DIR}/proto/ge) | ||||
| include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | include_directories(${CMAKE_BINARY_DIR}/proto/ge/proto) | ||||
| set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/common/properties_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/common/ge/plugin_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/common/ge/tbe_plugin_manager.cc" | |||||
| set(GRAPH_SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" | "${GE_CODE_DIR}/metadef/graph/option/ge_local_context.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" | "${GE_CODE_DIR}/metadef/graph/option/ge_context.cc" | ||||
| "${GE_CODE_DIR}/ge/common/types.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/op_map.cc" | |||||
| "${GE_CODE_DIR}/ge/common/fmk_error_codes.cc" | |||||
| "${GE_CODE_DIR}/ge/common/op/ge_op_utils.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/manager/util/variable_accelerate_ctrl.cc" | |||||
| "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/generator/ge_generator.cc" | |||||
| "${GE_CODE_DIR}/ge/generator/generator_api.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/common/omg_util.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/common/bcast.cc" | |||||
| "${GE_CODE_DIR}/ge/common/util.cc" | |||||
| "${GE_CODE_DIR}/ge/common/ge/op_tiling_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/init/gelib.cc" | |||||
| "${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | "${GE_CODE_DIR}/metadef/graph/ge_attr_define.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/anchor.cc" | "${GE_CODE_DIR}/metadef/graph/anchor.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | "${GE_CODE_DIR}/metadef/graph/ge_attr_value.cc" | ||||
| @@ -128,6 +112,38 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/metadef/register/tensor_assign.cpp" | "${GE_CODE_DIR}/metadef/register/tensor_assign.cpp" | ||||
| "${GE_CODE_DIR}/metadef/register/register_format_transfer.cc" | "${GE_CODE_DIR}/metadef/register/register_format_transfer.cc" | ||||
| "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | "${GE_CODE_DIR}/metadef/graph/format_refiner.cc" | ||||
| "${GE_CODE_DIR}/metadef/register/ops_kernel_builder_registry.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/op_tiling.cpp" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" | |||||
| ) | |||||
| set(PARSER_SRC_FILES | |||||
| "${GE_CODE_DIR}/parser/parser/common/op_map.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/pre_checker.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/convert/pb2json.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/parser_factory.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/model_saver.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/parser_types.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/parser_inner_ctx.cc" | |||||
| ) | |||||
| set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/common/properties_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/common/ge/plugin_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/common/ge/tbe_plugin_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/common/types.cc" | |||||
| "${GE_CODE_DIR}/ge/common/fmk_error_codes.cc" | |||||
| "${GE_CODE_DIR}/ge/common/op/ge_op_utils.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/manager/util/variable_accelerate_ctrl.cc" | |||||
| "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/generator/ge_generator.cc" | |||||
| "${GE_CODE_DIR}/ge/generator/generator_api.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/common/omg_util.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/common/bcast.cc" | |||||
| "${GE_CODE_DIR}/ge/common/util.cc" | |||||
| "${GE_CODE_DIR}/ge/common/ge/op_tiling_manager.cc" | |||||
| "${GE_CODE_DIR}/ge/init/gelib.cc" | |||||
| "${GE_CODE_DIR}/ge/engine_manager/dnnengine_manager.cc" | "${GE_CODE_DIR}/ge/engine_manager/dnnengine_manager.cc" | ||||
| "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" | "${GE_CODE_DIR}/ge/opskernel_manager/ops_kernel_manager.cc" | ||||
| "${GE_CODE_DIR}/ge/session/session_manager.cc" | "${GE_CODE_DIR}/ge/session/session_manager.cc" | ||||
| @@ -186,7 +202,6 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/passes/atomic_addr_clean_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/atomic_addr_clean_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/mark_same_addr_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/mark_same_addr_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/mark_graph_unknown_status_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/mark_graph_unknown_status_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/dynamic_single_op_reset_shape_pass.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/passes/mark_agnostic_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/mark_agnostic_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/dimension_compute_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dimension_compute_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/dimension_adjust_pass.cc" | ||||
| @@ -274,6 +289,9 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/graph_partition.cc" | ||||
| "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | "${GE_CODE_DIR}/ge/common/helper/model_cache_helper.cc" | ||||
| "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | "${GE_CODE_DIR}/ge/ir_build/ge_ir_build.cc" | ||||
| "${GE_CODE_DIR}/ge/ir_build/attr_options/utils.cc" | |||||
| "${GE_CODE_DIR}/ge/ir_build/attr_options/keep_dtype_option.cc" | |||||
| "${GE_CODE_DIR}/ge/ir_build/attr_options/weight_compress_option.cc" | |||||
| "${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" | "${GE_CODE_DIR}/ge/graph/build/label_allocator.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | "${GE_CODE_DIR}/ge/graph/passes/memcpy_addr_async_pass.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | "${GE_CODE_DIR}/ge/graph/partition/stage_partition.cc" | ||||
| @@ -312,17 +330,7 @@ set(COMMON_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/common/model_saver.cc" | "${GE_CODE_DIR}/ge/common/model_saver.cc" | ||||
| "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" | "${GE_CODE_DIR}/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc" | ||||
| "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" | "${GE_CODE_DIR}/ge/common/ge/datatype_util.cc" | ||||
| "${GE_CODE_DIR}/metadef/register/ops_kernel_builder_registry.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/op_tiling.cpp" | |||||
| "${GE_CODE_DIR}/metadef/graph/utils/tuning_utils.cc" | |||||
| "${GE_CODE_DIR}/metadef/register/op_tiling_registry.cpp" | |||||
| "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | "${GE_CODE_DIR}/ge/ge_local_engine/engine/host_cpu_engine.cc" | ||||
| "${GE_CODE_DIR}/parser/parser/common/pre_checker.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/convert/pb2json.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/parser_factory.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/model_saver.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/parser_types.cc" | |||||
| "${GE_CODE_DIR}/parser/parser/common/parser_inner_ctx.cc" | |||||
| "${GE_CODE_DIR}/ge/session/omg.cc" | "${GE_CODE_DIR}/ge/session/omg.cc" | ||||
| ) | ) | ||||
| @@ -345,6 +353,7 @@ set(COMMON_FORMAT_SRC_FILES | |||||
| "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" | "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_nhwc.cc" | ||||
| "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" | "${GE_CODE_DIR}/ge/common/formats/format_transfers/format_transfer_fracz_hwcn.cc" | ||||
| "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" | "${GE_CODE_DIR}/ge/common/formats/utils/formats_trans_utils.cc" | ||||
| "${GE_CODE_DIR}/ge/graph/manager/util/hcom_util.cc" | |||||
| ) | ) | ||||
| set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | set(GRAPH_OPTIMIZE_COMMON_SRC_FILES | ||||
| @@ -742,6 +751,7 @@ set(MULTI_PARTS_TEST_FILES | |||||
| "graph/build/logical_stream_allocator_unittest.cc" | "graph/build/logical_stream_allocator_unittest.cc" | ||||
| "graph/build/mem_assigner_unittest.cc" | "graph/build/mem_assigner_unittest.cc" | ||||
| "graph/preprocess/graph_preprocess_unittest.cc" | "graph/preprocess/graph_preprocess_unittest.cc" | ||||
| "graph/manager/hcom_util_unittest.cc" | |||||
| "session/omg_omg_unittest.cc" | "session/omg_omg_unittest.cc" | ||||
| ) | ) | ||||
| @@ -770,27 +780,59 @@ list(APPEND COMMON_SHARED_LIBRARIES | |||||
| hccl_stub | hccl_stub | ||||
| error_manager_stub | error_manager_stub | ||||
| ) | ) | ||||
| # build graph | |||||
| add_library(ge_ut_graph STATIC | |||||
| ${GRAPH_SRC_FILES} ${PARSER_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS} | |||||
| ) | |||||
| target_compile_definitions(ge_ut_graph PRIVATE | |||||
| google=ascend_private | |||||
| ) | |||||
| target_compile_options(ge_ut_graph PRIVATE | |||||
| -g | |||||
| ) | |||||
| target_link_libraries(ge_ut_graph PRIVATE | |||||
| $<BUILD_INTERFACE:intf_pub> | |||||
| c_sec | |||||
| ascend_protobuf | |||||
| json | |||||
| ) | |||||
| # build common | # build common | ||||
| add_library(ge_ut_common STATIC ${COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_ut_common STATIC ${COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_ut_common PRIVATE | target_compile_definitions(ge_ut_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_ut_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_ut_common PRIVATE | target_link_libraries(ge_ut_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| ascend_protobuf | ascend_protobuf | ||||
| json | json | ||||
| ge_ut_graph | |||||
| ) | ) | ||||
| # build common format | # build common format | ||||
| add_library(ge_ut_common_format STATIC ${COMMON_SRC_FILES} ${COMMON_FORMAT_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_ut_common_format STATIC ${COMMON_SRC_FILES} ${COMMON_FORMAT_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_ut_common_format PRIVATE | target_compile_definitions(ge_ut_common_format PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_ut_common_format PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_ut_common_format PRIVATE | target_link_libraries(ge_ut_common_format PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| @@ -799,12 +841,17 @@ target_link_libraries(ge_ut_common_format PRIVATE | |||||
| ) | ) | ||||
| # build graph prepare common | # build graph prepare common | ||||
| add_library(ge_prepare_common STATIC ${GRAPH_PREPARE_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_prepare_common STATIC ${GRAPH_PREPARE_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_prepare_common PRIVATE | target_compile_definitions(ge_prepare_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_prepare_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_prepare_common PRIVATE | target_link_libraries(ge_prepare_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| @@ -813,12 +860,17 @@ target_link_libraries(ge_prepare_common PRIVATE | |||||
| ) | ) | ||||
| # build graph optimize common | # build graph optimize common | ||||
| add_library(ge_optimize_common STATIC ${GRAPH_OPTIMIZE_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_optimize_common STATIC ${GRAPH_OPTIMIZE_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_optimize_common PRIVATE | target_compile_definitions(ge_optimize_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_optimize_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_optimize_common PRIVATE | target_link_libraries(ge_optimize_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ascend_protobuf | ascend_protobuf | ||||
| @@ -827,12 +879,17 @@ target_link_libraries(ge_optimize_common PRIVATE | |||||
| ) | ) | ||||
| # build graph partition common | # build graph partition common | ||||
| add_library(ge_partition_common STATIC ${GRAPH_PARTITION_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_partition_common STATIC ${GRAPH_PARTITION_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_partition_common PRIVATE | target_compile_definitions(ge_partition_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_partition_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_partition_common PRIVATE | target_link_libraries(ge_partition_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ascend_protobuf | ascend_protobuf | ||||
| @@ -841,12 +898,17 @@ target_link_libraries(ge_partition_common PRIVATE | |||||
| ) | ) | ||||
| # build build graph load common | # build build graph load common | ||||
| add_library(ge_load_common STATIC ${GRAPH_LOAD_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_load_common STATIC ${GRAPH_LOAD_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_load_common PRIVATE | target_compile_definitions(ge_load_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_load_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_load_common PRIVATE | target_link_libraries(ge_load_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| @@ -855,12 +917,17 @@ target_link_libraries(ge_load_common PRIVATE | |||||
| ) | ) | ||||
| # build graph execute common | # build graph execute common | ||||
| add_library(ge_execute_common STATIC ${GRAPH_EXECUTE_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_execute_common STATIC ${GRAPH_EXECUTE_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_execute_common PRIVATE | target_compile_definitions(ge_execute_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_execute_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_execute_common PRIVATE | target_link_libraries(ge_execute_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| @@ -869,12 +936,17 @@ target_link_libraries(ge_execute_common PRIVATE | |||||
| ) | ) | ||||
| # build graph build common | # build graph build common | ||||
| add_library(ge_build_common STATIC ${GRAPH_BUILD_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_build_common STATIC ${GRAPH_BUILD_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_build_common PRIVATE | target_compile_definitions(ge_build_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_build_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_build_common PRIVATE | target_link_libraries(ge_build_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| c_sec | c_sec | ||||
| @@ -883,12 +955,17 @@ target_link_libraries(ge_build_common PRIVATE | |||||
| ) | ) | ||||
| # build graph pass common | # build graph pass common | ||||
| add_library(ge_pass_common STATIC ${GRAPH_PASS_COMMON_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_pass_common STATIC ${GRAPH_PASS_COMMON_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_pass_common PRIVATE | target_compile_definitions(ge_pass_common PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_pass_common PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_pass_common PRIVATE | target_link_libraries(ge_pass_common PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ascend_protobuf | ascend_protobuf | ||||
| @@ -897,12 +974,17 @@ target_link_libraries(ge_pass_common PRIVATE | |||||
| ) | ) | ||||
| # build single_op common | # build single_op common | ||||
| add_library(ge_single_op STATIC ${SINGLE_OP_SRC_FILES} ${PROTO_SRCS} ${PROTO_HDRS}) | |||||
| add_library(ge_single_op STATIC ${SINGLE_OP_SRC_FILES} ${PROTO_HDRS}) | |||||
| target_compile_definitions(ge_single_op PRIVATE | target_compile_definitions(ge_single_op PRIVATE | ||||
| google=ascend_private | google=ascend_private | ||||
| ) | ) | ||||
| target_compile_options(ge_single_op PRIVATE | |||||
| -g --coverage -fprofile-arcs -ftest-coverage | |||||
| -Werror=format | |||||
| ) | |||||
| target_link_libraries(ge_single_op PRIVATE | target_link_libraries(ge_single_op PRIVATE | ||||
| $<BUILD_INTERFACE:intf_pub> | $<BUILD_INTERFACE:intf_pub> | ||||
| ascend_protobuf | ascend_protobuf | ||||
| @@ -921,6 +1003,7 @@ add_executable(ut_libge_multiparts_utest | |||||
| target_compile_options(ut_libge_multiparts_utest PRIVATE | target_compile_options(ut_libge_multiparts_utest PRIVATE | ||||
| -g --coverage -fprofile-arcs -ftest-coverage | -g --coverage -fprofile-arcs -ftest-coverage | ||||
| -Werror=format | |||||
| ) | ) | ||||
| target_compile_definitions(ut_libge_multiparts_utest PRIVATE | target_compile_definitions(ut_libge_multiparts_utest PRIVATE | ||||
| @@ -943,6 +1026,7 @@ add_executable(ut_libge_others_utest | |||||
| target_compile_options(ut_libge_others_utest PRIVATE | target_compile_options(ut_libge_others_utest PRIVATE | ||||
| -g --coverage -fprofile-arcs -ftest-coverage | -g --coverage -fprofile-arcs -ftest-coverage | ||||
| -Werror=format | |||||
| ) | ) | ||||
| target_link_libraries(ut_libge_others_utest | target_link_libraries(ut_libge_others_utest | ||||
| @@ -960,6 +1044,7 @@ add_executable(ut_libge_kernel_utest | |||||
| target_compile_options(ut_libge_kernel_utest PRIVATE | target_compile_options(ut_libge_kernel_utest PRIVATE | ||||
| -g --coverage -fprofile-arcs -ftest-coverage | -g --coverage -fprofile-arcs -ftest-coverage | ||||
| -Werror=format | |||||
| ) | ) | ||||
| target_link_libraries(ut_libge_kernel_utest | target_link_libraries(ut_libge_kernel_utest | ||||
| @@ -978,6 +1063,7 @@ add_executable(ut_libge_distinct_load_utest | |||||
| target_compile_options(ut_libge_distinct_load_utest PRIVATE | target_compile_options(ut_libge_distinct_load_utest PRIVATE | ||||
| -g --coverage -fprofile-arcs -ftest-coverage | -g --coverage -fprofile-arcs -ftest-coverage | ||||
| -Werror=format | |||||
| ) | ) | ||||
| target_compile_definitions(ut_libge_distinct_load_utest PRIVATE | target_compile_definitions(ut_libge_distinct_load_utest PRIVATE | ||||
| @@ -34,6 +34,10 @@ class UtestDavinciModel : public testing::Test { | |||||
| void TearDown() {} | void TearDown() {} | ||||
| }; | }; | ||||
| // Stand-in reporter: ignores every argument and always returns 0. | |||||
| int32_t MsprofReport(uint32_t moduleId, uint32_t type, void *data, uint32_t len) { | |||||
| return 0; | |||||
| } | |||||
| /* | /* | ||||
| TEST_F(UtestDavinciModel, init_success) { | TEST_F(UtestDavinciModel, init_success) { | ||||
| DavinciModel model(0, nullptr); | DavinciModel model(0, nullptr); | ||||
| @@ -853,4 +857,18 @@ TEST_F(UtestDavinciModel, LoadWithQueue_fail_with_diff_args) { | |||||
| EXPECT_EQ(model.LoadWithQueue(), INTERNAL_ERROR); | EXPECT_EQ(model.LoadWithQueue(), INTERNAL_ERROR); | ||||
| EXPECT_EQ(model.active_stream_list_.size(), 0); | EXPECT_EQ(model.active_stream_list_.size(), 0); | ||||
| } | } | ||||
| // Smoke test: installs MsprofReport as the reporter callback, gives the model one "relu" | |||||
| // profile entry plus matching op info, and calls SinkModelProfile(). It makes no assertions. | |||||
| TEST_F(UtestDavinciModel, Sink_model_profile) { | |||||
| ProfilingManager::Instance().prof_cb_.msprofReporterCallback = MsprofReport; | |||||
| ProfileInfo profile; | |||||
| profile.fusion_info.op_name = "relu"; | |||||
| DavinciModel model(0, nullptr); | |||||
| model.profile_list_.emplace_back(profile); | |||||
| std::map<std::string, std::pair<uint32_t, uint32_t>> op_info; | |||||
| op_info["relu"] = std::pair<uint32_t, uint32_t>(1, 1); | |||||
| model.profiler_report_op_info_ = op_info; | |||||
| model.SinkModelProfile(); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -140,6 +140,7 @@ TEST_F(UtestKernelExTaskInfo, kernel_ex_task_info_calculate_args) { | |||||
| TEST_F(UtestKernelExTaskInfo, kernel_ex_task_ext_info) { | TEST_F(UtestKernelExTaskInfo, kernel_ex_task_ext_info) { | ||||
| const string ext_info = {1, 1, 1, 1, 0, 0, 0, 0}; | const string ext_info = {1, 1, 1, 1, 0, 0, 0, 0}; | ||||
| const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp"); | const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp"); | ||||
| AttrUtils::SetBool(op_desc, "_AllShape", true); | |||||
| KernelExTaskInfo kernel_ex_task_info; | KernelExTaskInfo kernel_ex_task_info; | ||||
| EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS); | EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS); | ||||
| @@ -390,6 +390,7 @@ TEST_F(UtestKernelTaskInfo, init_kernel_taskInfo_with_aicpu_kernel_type_fail) { | |||||
| rtStreamCreate(&stream, 0); | rtStreamCreate(&stream, 0); | ||||
| model.stream_list_ = { stream }; | model.stream_list_ = { stream }; | ||||
| model.op_list_[0] = CreateOpDesc("FrameworkOp", "FrameworkOp"); | model.op_list_[0] = CreateOpDesc("FrameworkOp", "FrameworkOp"); | ||||
| AttrUtils::SetBool(model.op_list_[0], "_AllShape", true); | |||||
| domi::TaskDef task_def; | domi::TaskDef task_def; | ||||
| KernelTaskInfo kernel_task_info; | KernelTaskInfo kernel_task_info; | ||||
| @@ -0,0 +1,97 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include <gtest/gtest.h> | |||||
| #include <memory> | |||||
| #include "common/ge_inner_error_codes.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/passes/addn_pass.h" | |||||
| #define private public | |||||
| #define protected public | |||||
| #include "graph/manager/util/hcom_util.h" | |||||
| #include "ge/ge_api.h" | |||||
| #undef private | |||||
| #undef protected | |||||
| using namespace std; | |||||
| namespace ge { | |||||
| namespace { | |||||
| // Returns a newly allocated GeTensorDesc carrying the given shape, format and data type. | |||||
| GeTensorDescPtr CreateTensorDesc(std::initializer_list<int64_t> shape, Format format = FORMAT_NCHW, | |||||
| DataType data_type = DT_FLOAT) { | |||||
| const std::vector<int64_t> dims(shape); | |||||
| auto desc = std::make_shared<GeTensorDesc>(); | |||||
| desc->SetShape(GeShape(dims)); | |||||
| desc->SetFormat(format); | |||||
| desc->SetDataType(data_type); | |||||
| return desc; | |||||
| } | |||||
| // Chainable helper: collects tensor descriptions on an OpDesc, then adds it to a graph as a node. | |||||
| class NodeBuilder { | |||||
| public: | |||||
| NodeBuilder(const std::string &name, const std::string &type) : op_desc_(std::make_shared<OpDesc>(name, type)) {} | |||||
| // Appends a copy of a freshly built input description; returns *this for chaining. | |||||
| NodeBuilder &AddInputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, | |||||
| DataType data_type = DT_FLOAT) { | |||||
| GeTensorDescPtr input_desc = CreateTensorDesc(shape, format, data_type); | |||||
| op_desc_->AddInputDesc(input_desc->Clone()); | |||||
| return *this; | |||||
| } | |||||
| // Appends a copy of a freshly built output description; returns *this for chaining. | |||||
| NodeBuilder &AddOutputDesc(std::initializer_list<int64_t> shape = {1, 1, 224, 224}, Format format = FORMAT_NCHW, | |||||
| DataType data_type = DT_FLOAT) { | |||||
| GeTensorDescPtr output_desc = CreateTensorDesc(shape, format, data_type); | |||||
| op_desc_->AddOutputDesc(output_desc->Clone()); | |||||
| return *this; | |||||
| } | |||||
| // Appends a copy of the caller-supplied output description; returns *this for chaining. | |||||
| NodeBuilder &AddOutputDesc(GeTensorDescPtr tensor_desc) { | |||||
| op_desc_->AddOutputDesc(tensor_desc->Clone()); | |||||
| return *this; | |||||
| } | |||||
| // Adds the accumulated OpDesc to the graph and returns whatever AddNode returns. | |||||
| NodePtr Build(const ComputeGraphPtr &graph) { | |||||
| return graph->AddNode(op_desc_); | |||||
| } | |||||
| private: | |||||
| OpDescPtr op_desc_; | |||||
| }; | |||||
| } // namespace | |||||
| // Fixture for the HcomOmeUtil tests; no per-test setup or teardown is needed. | |||||
| class UtestHcomUtil : public testing::Test { | |||||
| protected: | |||||
| // Marked override so a mismatch with testing::Test's virtuals becomes a compile error. | |||||
| void SetUp() override {} | |||||
| void TearDown() override {} | |||||
| }; | |||||
| // GetHcomCount is expected to succeed for an HCOMRECEIVE node that has one | |||||
| // {1, 1, 224, 224} input and one {1, 1, 224, 224} output (builder defaults: NCHW, DT_FLOAT). | |||||
| TEST_F(UtestHcomUtil, test_GetHcomCount_succ) { | |||||
| ComputeGraphPtr graph = std::make_shared<ComputeGraph>("test"); | |||||
| NodePtr node = NodeBuilder("node", HCOMRECEIVE) | |||||
| .AddInputDesc({1, 1, 224, 224}) | |||||
| .AddOutputDesc({1, 1, 224, 224}) | |||||
| .Build(graph); | |||||
| // Stop here if the node could not be added, rather than dereferencing a null pointer below. | |||||
| ASSERT_TRUE(node != nullptr); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| HcomOmeUtil hcom_ome_util; | |||||
| int count = 0; | |||||
| auto ret = hcom_ome_util.GetHcomCount(op_desc, HCCL_DATA_TYPE_FP32, true, count); | |||||
| // Compare against the named status code rather than a bare 0. | |||||
| EXPECT_EQ(ret, SUCCESS); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,96 @@ | |||||
| /** | |||||
| * @file rt_error_codes.h | |||||
| * | |||||
| * Copyright (C) Huawei Technologies Co., Ltd. 2019-2020. All Rights Reserved. | |||||
| * | |||||
| * This program is distributed in the hope that it will be useful, | |||||
| * but WITHOUT ANY WARRANTY; without even the implied warranty of | |||||
| * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. | |||||
| */ | |||||
| #ifndef __INC_EXTERNEL_RT_ERROR_CODES_H__ | |||||
| #define __INC_EXTERNEL_RT_ERROR_CODES_H__ | |||||
| #include <stddef.h> | |||||
| #include <stdint.h> | |||||
| #ifdef __cplusplus | |||||
| extern "C" { | |||||
| #endif | |||||
| static const int32_t ACL_RT_SUCCESS = 0; // success | |||||
| static const int32_t ACL_ERROR_RT_PARAM_INVALID = 107000; // param invalid | |||||
| static const int32_t ACL_ERROR_RT_INVALID_DEVICEID = 107001; // invalid device id | |||||
| static const int32_t ACL_ERROR_RT_CONTEXT_NULL = 107002; // current context null | |||||
| static const int32_t ACL_ERROR_RT_STREAM_CONTEXT = 107003; // stream not in current context | |||||
| static const int32_t ACL_ERROR_RT_MODEL_CONTEXT = 107004; // model not in current context | |||||
| static const int32_t ACL_ERROR_RT_STREAM_MODEL = 107005; // stream not in model | |||||
| static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_INVALID = 107006; // event timestamp invalid | |||||
| static const int32_t ACL_ERROR_RT_EVENT_TIMESTAMP_REVERSAL = 107007; // event timestamp reversal | |||||
| static const int32_t ACL_ERROR_RT_ADDR_UNALIGNED = 107008; // memory address unaligned | |||||
| static const int32_t ACL_ERROR_RT_FILE_OPEN = 107009; // open file failed | |||||
| static const int32_t ACL_ERROR_RT_FILE_WRITE = 107010; // write file failed | |||||
| static const int32_t ACL_ERROR_RT_STREAM_SUBSCRIBE = 107011; // error subscribe stream | |||||
| static const int32_t ACL_ERROR_RT_THREAD_SUBSCRIBE = 107012; // error subscribe thread | |||||
| static const int32_t ACL_ERROR_RT_GROUP_NOT_SET = 107013; // group not set | |||||
| static const int32_t ACL_ERROR_RT_GROUP_NOT_CREATE = 107014; // group not create | |||||
| static const int32_t ACL_ERROR_RT_STREAM_NO_CB_REG = 107015; // callback not register to stream | |||||
| static const int32_t ACL_ERROR_RT_INVALID_MEMORY_TYPE = 107016; // invalid memory type | |||||
| static const int32_t ACL_ERROR_RT_INVALID_HANDLE = 107017; // invalid handle | |||||
| static const int32_t ACL_ERROR_RT_INVALID_MALLOC_TYPE = 107018; // invalid malloc type | |||||
| static const int32_t ACL_ERROR_RT_FEATURE_NOT_SUPPORT = 207000; // feature not support | |||||
| static const int32_t ACL_ERROR_RT_MEMORY_ALLOCATION = 207001; // memory allocation error | |||||
| static const int32_t ACL_ERROR_RT_MEMORY_FREE = 207002; // memory free error | |||||
| static const int32_t ACL_ERROR_RT_AICORE_OVER_FLOW = 207003; // aicore over flow | |||||
| static const int32_t ACL_ERROR_RT_NO_DEVICE = 207004; // no device | |||||
| static const int32_t ACL_ERROR_RT_RESOURCE_ALLOC_FAIL = 207005; // resource alloc fail | |||||
| static const int32_t ACL_ERROR_RT_NO_PERMISSION = 207006; // no permission | |||||
| static const int32_t ACL_ERROR_RT_NO_EVENT_RESOURCE = 207007; // no event resource | |||||
| static const int32_t ACL_ERROR_RT_NO_STREAM_RESOURCE = 207008; // no stream resource | |||||
| static const int32_t ACL_ERROR_RT_NO_NOTIFY_RESOURCE = 207009; // no notify resource | |||||
| static const int32_t ACL_ERROR_RT_NO_MODEL_RESOURCE = 207010; // no model resource | |||||
| static const int32_t ACL_ERROR_RT_INTERNAL_ERROR = 507000; // runtime internal error | |||||
| static const int32_t ACL_ERROR_RT_TS_ERROR = 507001; // ts internal error | |||||
| static const int32_t ACL_ERROR_RT_STREAM_TASK_FULL = 507002; // task full in stream | |||||
| static const int32_t ACL_ERROR_RT_STREAM_TASK_EMPTY = 507003; // task empty in stream | |||||
| static const int32_t ACL_ERROR_RT_STREAM_NOT_COMPLETE = 507004; // stream not complete | |||||
| static const int32_t ACL_ERROR_RT_END_OF_SEQUENCE = 507005; // end of sequence | |||||
| static const int32_t ACL_ERROR_RT_EVENT_NOT_COMPLETE = 507006; // event not complete | |||||
| static const int32_t ACL_ERROR_RT_CONTEXT_RELEASE_ERROR = 507007; // context release error | |||||
| static const int32_t ACL_ERROR_RT_SOC_VERSION = 507008; // soc version error | |||||
| static const int32_t ACL_ERROR_RT_TASK_TYPE_NOT_SUPPORT = 507009; // task type not support | |||||
| static const int32_t ACL_ERROR_RT_LOST_HEARTBEAT = 507010; // ts lost heartbeat | |||||
| static const int32_t ACL_ERROR_RT_MODEL_EXECUTE = 507011; // model execute failed | |||||
| static const int32_t ACL_ERROR_RT_REPORT_TIMEOUT = 507012; // report timeout | |||||
| static const int32_t ACL_ERROR_RT_SYS_DMA = 507013; // sys dma error | |||||
| static const int32_t ACL_ERROR_RT_AICORE_TIMEOUT = 507014; // aicore timeout | |||||
| static const int32_t ACL_ERROR_RT_AICORE_EXCEPTION = 507015; // aicore exception | |||||
| static const int32_t ACL_ERROR_RT_AICORE_TRAP_EXCEPTION = 507016; // aicore trap exception | |||||
| static const int32_t ACL_ERROR_RT_AICPU_TIMEOUT = 507017; // aicpu timeout | |||||
| static const int32_t ACL_ERROR_RT_AICPU_EXCEPTION = 507018; // aicpu exception | |||||
| static const int32_t ACL_ERROR_RT_AICPU_DATADUMP_RSP_ERR = 507019; // aicpu datadump response error | |||||
| static const int32_t ACL_ERROR_RT_AICPU_MODEL_RSP_ERR = 507020; // aicpu model operate response error | |||||
| static const int32_t ACL_ERROR_RT_PROFILING_ERROR = 507021; // profiling error | |||||
| static const int32_t ACL_ERROR_RT_IPC_ERROR = 507022; // ipc error | |||||
| static const int32_t ACL_ERROR_RT_MODEL_ABORT_NORMAL = 507023; // model abort normal | |||||
| static const int32_t ACL_ERROR_RT_KERNEL_UNREGISTERING = 507024; // kernel unregistering | |||||
| static const int32_t ACL_ERROR_RT_RINGBUFFER_NOT_INIT = 507025; // ringbuffer not init | |||||
| static const int32_t ACL_ERROR_RT_RINGBUFFER_NO_DATA = 507026; // ringbuffer no data | |||||
| static const int32_t ACL_ERROR_RT_KERNEL_LOOKUP = 507027; // kernel lookup error | |||||
| static const int32_t ACL_ERROR_RT_KERNEL_DUPLICATE = 507028; // kernel register duplicate | |||||
| static const int32_t ACL_ERROR_RT_DEBUG_REGISTER_FAIL = 507029; // debug register failed | |||||
| static const int32_t ACL_ERROR_RT_DEBUG_UNREGISTER_FAIL = 507030; // debug unregister failed | |||||
| static const int32_t ACL_ERROR_RT_LABEL_CONTEXT = 507031; // label not in current context | |||||
| static const int32_t ACL_ERROR_RT_PROGRAM_USE_OUT = 507032; // program register num use out | |||||
| static const int32_t ACL_ERROR_RT_DEV_SETUP_ERROR = 507033; // device setup error | |||||
| static const int32_t ACL_ERROR_RT_DRV_INTERNAL_ERROR = 507899; // drv internal error | |||||
| static const int32_t ACL_ERROR_RT_AICPU_INTERNAL_ERROR = 507900; // aicpu internal error | |||||
| #ifdef __cplusplus | |||||
| } | |||||
| #endif | |||||
| #endif // __INC_EXTERNEL_RT_ERROR_CODES_H__ | |||||