| @@ -173,10 +173,12 @@ set(TRAIN_SRC_LIST | |||||
| "graph/manager/graph_manager_utils.cc" | "graph/manager/graph_manager_utils.cc" | ||||
| "graph/manager/graph_mem_allocator.cc" | "graph/manager/graph_mem_allocator.cc" | ||||
| "graph/manager/graph_caching_allocator.cc" | "graph/manager/graph_caching_allocator.cc" | ||||
| "graph/manager/session_scope_mem_allocator.cc" | |||||
| "graph/manager/graph_var_manager.cc" | "graph/manager/graph_var_manager.cc" | ||||
| "graph/manager/host_mem_manager.cc" | "graph/manager/host_mem_manager.cc" | ||||
| "graph/manager/rdma_pool_allocator.cc" | "graph/manager/rdma_pool_allocator.cc" | ||||
| "graph/manager/host_mem_allocator.cc" | "graph/manager/host_mem_allocator.cc" | ||||
| "graph/manager/graph_mem_manager.cc" | |||||
| "graph/manager/memory_api.cc" | "graph/manager/memory_api.cc" | ||||
| "graph/manager/model_manager/event_manager.cc" | "graph/manager/model_manager/event_manager.cc" | ||||
| "graph/manager/trans_var_data_utils.cc" | "graph/manager/trans_var_data_utils.cc" | ||||
| @@ -270,7 +272,6 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/identity_pass.cc" | "graph/passes/identity_pass.cc" | ||||
| "graph/passes/ref_identity_delete_op_pass.cc" | "graph/passes/ref_identity_delete_op_pass.cc" | ||||
| "graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
| "graph/passes/isolated_op_remove_pass.cc" | |||||
| "graph/passes/iterator_op_pass.cc" | "graph/passes/iterator_op_pass.cc" | ||||
| "graph/passes/link_gen_mask_nodes_pass.cc" | "graph/passes/link_gen_mask_nodes_pass.cc" | ||||
| "graph/passes/merge_pass.cc" | "graph/passes/merge_pass.cc" | ||||
| @@ -317,13 +318,11 @@ set(TRAIN_SRC_LIST | |||||
| "graph/passes/transop_without_reshape_fusion_pass.cc" | "graph/passes/transop_without_reshape_fusion_pass.cc" | ||||
| "graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
| "graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
| "graph/passes/unused_op_remove_pass.cc" | |||||
| "graph/passes/var_is_initialized_op_pass.cc" | "graph/passes/var_is_initialized_op_pass.cc" | ||||
| "graph/passes/parallel_concat_start_op_pass.cc" | "graph/passes/parallel_concat_start_op_pass.cc" | ||||
| "graph/passes/cond_pass.cc" | "graph/passes/cond_pass.cc" | ||||
| "graph/passes/cond_remove_pass.cc" | "graph/passes/cond_remove_pass.cc" | ||||
| "graph/passes/for_pass.cc" | "graph/passes/for_pass.cc" | ||||
| "graph/passes/variable_format_pass.cc" | |||||
| "graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
| "graph/passes/variable_prepare_op_pass.cc" | "graph/passes/variable_prepare_op_pass.cc" | ||||
| "graph/passes/variable_ref_delete_op_pass.cc" | "graph/passes/variable_ref_delete_op_pass.cc" | ||||
| @@ -478,6 +477,8 @@ set(INFER_SRC_LIST | |||||
| "graph/manager/host_mem_allocator.cc" | "graph/manager/host_mem_allocator.cc" | ||||
| "graph/manager/graph_mem_allocator.cc" | "graph/manager/graph_mem_allocator.cc" | ||||
| "graph/manager/graph_caching_allocator.cc" | "graph/manager/graph_caching_allocator.cc" | ||||
| "graph/manager/session_scope_mem_allocator.cc" | |||||
| "graph/manager/graph_mem_manager.cc" | |||||
| "model/ge_model.cc" | "model/ge_model.cc" | ||||
| "model/ge_root_model.cc" | "model/ge_root_model.cc" | ||||
| "graph/common/transop_util.cc" | "graph/common/transop_util.cc" | ||||
| @@ -522,12 +523,10 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/dimension_adjust_pass.cc" | "graph/passes/dimension_adjust_pass.cc" | ||||
| "graph/passes/get_original_format_pass.cc" | "graph/passes/get_original_format_pass.cc" | ||||
| "graph/passes/shape_operate_op_remove_pass.cc" | "graph/passes/shape_operate_op_remove_pass.cc" | ||||
| "graph/passes/unused_op_remove_pass.cc" | |||||
| "graph/passes/assert_pass.cc" | "graph/passes/assert_pass.cc" | ||||
| "graph/passes/dropout_pass.cc" | "graph/passes/dropout_pass.cc" | ||||
| "graph/passes/infershape_pass.cc" | "graph/passes/infershape_pass.cc" | ||||
| "graph/passes/unused_const_pass.cc" | "graph/passes/unused_const_pass.cc" | ||||
| "graph/passes/isolated_op_remove_pass.cc" | |||||
| "graph/passes/permute_pass.cc" | "graph/passes/permute_pass.cc" | ||||
| "graph/passes/ctrl_edge_transfer_pass.cc" | "graph/passes/ctrl_edge_transfer_pass.cc" | ||||
| "graph/passes/end_of_sequence_add_control_pass.cc" | "graph/passes/end_of_sequence_add_control_pass.cc" | ||||
| @@ -610,7 +609,6 @@ set(INFER_SRC_LIST | |||||
| "graph/passes/switch_logic_remove_pass.cc" | "graph/passes/switch_logic_remove_pass.cc" | ||||
| "graph/passes/switch_data_edges_bypass.cc" | "graph/passes/switch_data_edges_bypass.cc" | ||||
| "graph/passes/merge_pass.cc" | "graph/passes/merge_pass.cc" | ||||
| "graph/passes/variable_format_pass.cc" | |||||
| "graph/passes/variable_op_pass.cc" | "graph/passes/variable_op_pass.cc" | ||||
| "graph/passes/cast_remove_pass.cc" | "graph/passes/cast_remove_pass.cc" | ||||
| "graph/passes/transpose_transdata_pass.cc" | "graph/passes/transpose_transdata_pass.cc" | ||||
| @@ -62,7 +62,6 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::LoadFro | |||||
| char *data = new (std::nothrow) char[len]; | char *data = new (std::nothrow) char[len]; | ||||
| if (data == nullptr) { | if (data == nullptr) { | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Load model From file failed, bad memory allocation occur. (need:%u)", len); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Load][ModelFromFile]Failed, " | GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Load][ModelFromFile]Failed, " | ||||
| "bad memory allocation occur(need %u), file %s", len, model_path); | "bad memory allocation occur(need %u), file %s", len, model_path); | ||||
| REPORT_CALL_ERROR("E19999", "Load model from file %s failed, " | REPORT_CALL_ERROR("E19999", "Load model from file %s failed, " | ||||
| @@ -90,33 +89,45 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY Status ModelParserBase::ParseMo | |||||
| GE_CHECK_NOTNULL(model.model_data); | GE_CHECK_NOTNULL(model.model_data); | ||||
| // Model length too small | // Model length too small | ||||
| GE_CHK_BOOL_RET_STATUS(model.model_len >= sizeof(ModelFileHeader), ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "Invalid model. Model data size %u must be greater than or equal to %zu.", model.model_len, | |||||
| sizeof(ModelFileHeader)); | |||||
| GE_CHK_BOOL_EXEC(model.model_len >= sizeof(ModelFileHeader), | |||||
| REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({"om", model.om_name.c_str(), "invalid om file"})); | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "[Check][Param] Invalid model. Model data size %u must be greater than or equal to %zu.", | |||||
| model.model_len, sizeof(ModelFileHeader)); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID;); | |||||
| // Get file header | // Get file header | ||||
| auto file_header = reinterpret_cast<ModelFileHeader *>(model.model_data); | auto file_header = reinterpret_cast<ModelFileHeader *>(model.model_data); | ||||
| // Determine whether the file length and magic number match | // Determine whether the file length and magic number match | ||||
| GE_CHK_BOOL_RET_STATUS( | |||||
| file_header->length == model.model_len - sizeof(ModelFileHeader) && file_header->magic == MODEL_FILE_MAGIC_NUM, | |||||
| ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "Invalid model. file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != model->model_len[%u] || " | |||||
| "MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]", | |||||
| file_header->length, sizeof(ModelFileHeader), model.model_len, MODEL_FILE_MAGIC_NUM, file_header->magic); | |||||
| GE_CHK_BOOL_EXEC(file_header->length == model.model_len - sizeof(ModelFileHeader) && | |||||
| file_header->magic == MODEL_FILE_MAGIC_NUM, | |||||
| REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({"om", model.om_name.c_str(), "invalid om file"})); | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID, | |||||
| "[Check][Param] Invalid model, file_header->length[%u] + sizeof(ModelFileHeader)[%zu] != " | |||||
| "model->model_len[%u] || MODEL_FILE_MAGIC_NUM[%u] != file_header->magic[%u]", | |||||
| file_header->length, sizeof(ModelFileHeader), model.model_len, | |||||
| MODEL_FILE_MAGIC_NUM, file_header->magic); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_DATA_SIZE_INVALID;); | |||||
| Status res = SUCCESS; | Status res = SUCCESS; | ||||
| // Get data address | // Get data address | ||||
| uint8_t *data = reinterpret_cast<uint8_t *>(model.model_data) + sizeof(ModelFileHeader); | uint8_t *data = reinterpret_cast<uint8_t *>(model.model_data) + sizeof(ModelFileHeader); | ||||
| if (file_header->is_encrypt == ModelEncryptType::UNENCRYPTED) { // Unencrypted model | if (file_header->is_encrypt == ModelEncryptType::UNENCRYPTED) { // Unencrypted model | ||||
| GE_CHK_BOOL_RET_STATUS(model.key.empty(), ACL_ERROR_GE_PARAM_INVALID, | |||||
| "Invalid param. model is unencrypted, but key is not empty."); | |||||
| if (!model.key.empty()) { | |||||
| REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({"om", model.om_name.c_str(), "invalid om file"})); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, | |||||
| "[Check][Param] Invalid param, model is unencrypted, but key is not empty."); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | |||||
| } | |||||
| model_data = data; | model_data = data; | ||||
| model_len = file_header->length; | model_len = file_header->length; | ||||
| GELOGD("Model_len is %u, model_file_head_len is %zu.", model_len, sizeof(ModelFileHeader)); | GELOGD("Model_len is %u, model_file_head_len is %zu.", model_len, sizeof(ModelFileHeader)); | ||||
| } else { | } else { | ||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Param]Invalid, model encrypt type not supported"); | GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Param]Invalid, model encrypt type not supported"); | ||||
| REPORT_CALL_ERROR("E19999","Invalid model, encrypt type not supported"); | |||||
| REPORT_INPUT_ERROR("E10003", std::vector<std::string>({"parameter", "value", "reason"}), | |||||
| std::vector<std::string>({"om", model.om_name.c_str(), "invalid om file"})); | |||||
| res = ACL_ERROR_GE_PARAM_INVALID; | res = ACL_ERROR_GE_PARAM_INVALID; | ||||
| } | } | ||||
| @@ -184,7 +184,10 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { | |||||
| if (options.find(kTrainingTrace) == std::string::npos) { | if (options.find(kTrainingTrace) == std::string::npos) { | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| const std::string training_trace = prof_options[kTrainingTrace]; | |||||
| std::string training_trace; | |||||
| if (prof_options.contains(kTrainingTrace)) { | |||||
| training_trace = prof_options[kTrainingTrace]; | |||||
| } | |||||
| if (training_trace.empty()) { | if (training_trace.empty()) { | ||||
| GELOGI("Training trace will not take effect."); | GELOGI("Training trace will not take effect."); | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| @@ -196,8 +199,12 @@ ge::Status ProfilingManager::ParseOptions(const std::string &options) { | |||||
| REPORT_INNER_ERROR("E19999", "Training trace param:%s is invalid.", training_trace.c_str()); | REPORT_INNER_ERROR("E19999", "Training trace param:%s is invalid.", training_trace.c_str()); | ||||
| return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
| } | } | ||||
| fp_point_ = prof_options[kFpPoint]; | |||||
| bp_point_ = prof_options[kBpPoint]; | |||||
| if (prof_options.contains(kFpPoint)) { | |||||
| fp_point_ = prof_options[kFpPoint]; | |||||
| } | |||||
| if (prof_options.contains(kBpPoint)) { | |||||
| bp_point_ = prof_options[kBpPoint]; | |||||
| } | |||||
| if (!fp_point_.empty() && !bp_point_.empty()) { | if (!fp_point_.empty() && !bp_point_.empty()) { | ||||
| GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); | GELOGI("Training trace bp fp is set, bp_point:%s, fp_point:%s.", bp_point_.c_str(), fp_point_.c_str()); | ||||
| } | } | ||||
| @@ -1014,10 +1021,12 @@ FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY void ProfilingManager::GetFpBpP | |||||
| if (is_profiling_valid) { | if (is_profiling_valid) { | ||||
| try { | try { | ||||
| Json prof_options = Json::parse(profiling_options); | Json prof_options = Json::parse(profiling_options); | ||||
| fp_point_ = prof_options[kFpPoint]; | |||||
| bp_point_ = prof_options[kBpPoint]; | |||||
| if (prof_options.contains(kFpPoint)) { | |||||
| fp_point_ = prof_options[kFpPoint]; | |||||
| } | |||||
| if (prof_options.contains(kBpPoint)) { | |||||
| bp_point_ = prof_options[kBpPoint]; | |||||
| } | |||||
| fp_point = fp_point_; | fp_point = fp_point_; | ||||
| bp_point = bp_point_; | bp_point = bp_point_; | ||||
| if (!fp_point_.empty() && !bp_point_.empty()) { | if (!fp_point_.empty() && !bp_point_.empty()) { | ||||
| @@ -81,7 +81,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ProfilingManager { | |||||
| Status ProfModelUnsubscribe(void *model); | Status ProfModelUnsubscribe(void *model); | ||||
| void StopProfiling(); | void StopProfiling(); | ||||
| bool ProfilingTrainingTraceOn() const { return is_training_trace_; } | bool ProfilingTrainingTraceOn() const { return is_training_trace_; } | ||||
| // report model load profiling data flag, data contain task desc info, step info, model load fusion op info | |||||
| bool ProfilingModelLoadOn() const { return is_load_profiling_; } | bool ProfilingModelLoadOn() const { return is_load_profiling_; } | ||||
| // report model execute profiling data flag, data contain model execute time info | |||||
| bool ProfilingModelExecuteOn() const; | bool ProfilingModelExecuteOn() const; | ||||
| // is_execute_profiling_ only used by ge option and env | // is_execute_profiling_ only used by ge option and env | ||||
| bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } | bool ProfilingOn() const { return is_load_profiling_ && is_execute_profiling_; } | ||||
| @@ -28,6 +28,8 @@ set(SRC_LIST | |||||
| "../graph/manager/graph_var_manager.cc" | "../graph/manager/graph_var_manager.cc" | ||||
| "../graph/manager/graph_mem_allocator.cc" | "../graph/manager/graph_mem_allocator.cc" | ||||
| "../graph/manager/graph_caching_allocator.cc" | "../graph/manager/graph_caching_allocator.cc" | ||||
| "../graph/manager/session_scope_mem_allocator.cc" | |||||
| "../graph/manager/graph_mem_manager.cc" | |||||
| "../graph/manager/trans_var_data_utils.cc" | "../graph/manager/trans_var_data_utils.cc" | ||||
| "../graph/manager/util/debug.cc" | "../graph/manager/util/debug.cc" | ||||
| "../graph/manager/rdma_pool_allocator.cc" | "../graph/manager/rdma_pool_allocator.cc" | ||||
| @@ -26,7 +26,7 @@ | |||||
| #include "graph/execute/graph_execute.h" | #include "graph/execute/graph_execute.h" | ||||
| #include "graph/load/graph_loader.h" | #include "graph/load/graph_loader.h" | ||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| #include "single_op/single_op_manager.h" | #include "single_op/single_op_manager.h" | ||||
| #include "graph/load/model_manager/davinci_model.h" | #include "graph/load/model_manager/davinci_model.h" | ||||
| #include "opskernel_manager/ops_kernel_builder_manager.h" | #include "opskernel_manager/ops_kernel_builder_manager.h" | ||||
| @@ -122,12 +122,10 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/dimension_adjust_pass.cc \ | graph/passes/dimension_adjust_pass.cc \ | ||||
| graph/passes/get_original_format_pass.cc \ | graph/passes/get_original_format_pass.cc \ | ||||
| graph/passes/shape_operate_op_remove_pass.cc \ | graph/passes/shape_operate_op_remove_pass.cc \ | ||||
| graph/passes/unused_op_remove_pass.cc \ | |||||
| graph/passes/assert_pass.cc \ | graph/passes/assert_pass.cc \ | ||||
| graph/passes/dropout_pass.cc \ | graph/passes/dropout_pass.cc \ | ||||
| graph/passes/infershape_pass.cc \ | graph/passes/infershape_pass.cc \ | ||||
| graph/passes/unused_const_pass.cc \ | graph/passes/unused_const_pass.cc \ | ||||
| graph/passes/isolated_op_remove_pass.cc \ | |||||
| graph/passes/permute_pass.cc \ | graph/passes/permute_pass.cc \ | ||||
| graph/passes/ctrl_edge_transfer_pass.cc \ | graph/passes/ctrl_edge_transfer_pass.cc \ | ||||
| graph/passes/end_of_sequence_add_control_pass.cc \ | graph/passes/end_of_sequence_add_control_pass.cc \ | ||||
| @@ -209,7 +207,6 @@ OMG_HOST_SRC_FILES := \ | |||||
| graph/passes/switch_logic_remove_pass.cc \ | graph/passes/switch_logic_remove_pass.cc \ | ||||
| graph/passes/switch_data_edges_bypass.cc \ | graph/passes/switch_data_edges_bypass.cc \ | ||||
| graph/passes/merge_pass.cc \ | graph/passes/merge_pass.cc \ | ||||
| graph/passes/variable_format_pass.cc \ | |||||
| graph/passes/variable_op_pass.cc \ | graph/passes/variable_op_pass.cc \ | ||||
| graph/passes/cast_remove_pass.cc \ | graph/passes/cast_remove_pass.cc \ | ||||
| graph/passes/transpose_transdata_pass.cc \ | graph/passes/transpose_transdata_pass.cc \ | ||||
| @@ -187,7 +187,6 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/passes/identity_pass.cc \ | graph/passes/identity_pass.cc \ | ||||
| graph/passes/ref_identity_delete_op_pass.cc \ | graph/passes/ref_identity_delete_op_pass.cc \ | ||||
| graph/passes/infershape_pass.cc \ | graph/passes/infershape_pass.cc \ | ||||
| graph/passes/isolated_op_remove_pass.cc \ | |||||
| graph/passes/iterator_op_pass.cc \ | graph/passes/iterator_op_pass.cc \ | ||||
| graph/passes/link_gen_mask_nodes_pass.cc \ | graph/passes/link_gen_mask_nodes_pass.cc \ | ||||
| graph/passes/merge_pass.cc \ | graph/passes/merge_pass.cc \ | ||||
| @@ -233,13 +232,11 @@ LIBGE_LOCAL_SRC_FILES := \ | |||||
| graph/passes/transop_without_reshape_fusion_pass.cc \ | graph/passes/transop_without_reshape_fusion_pass.cc \ | ||||
| graph/passes/transpose_transdata_pass.cc \ | graph/passes/transpose_transdata_pass.cc \ | ||||
| graph/passes/unused_const_pass.cc \ | graph/passes/unused_const_pass.cc \ | ||||
| graph/passes/unused_op_remove_pass.cc \ | |||||
| graph/passes/var_is_initialized_op_pass.cc \ | graph/passes/var_is_initialized_op_pass.cc \ | ||||
| graph/passes/parallel_concat_start_op_pass.cc \ | graph/passes/parallel_concat_start_op_pass.cc \ | ||||
| graph/passes/cond_pass.cc \ | graph/passes/cond_pass.cc \ | ||||
| graph/passes/cond_remove_pass.cc \ | graph/passes/cond_remove_pass.cc \ | ||||
| graph/passes/for_pass.cc \ | graph/passes/for_pass.cc \ | ||||
| graph/passes/variable_format_pass.cc \ | |||||
| graph/passes/variable_op_pass.cc \ | graph/passes/variable_op_pass.cc \ | ||||
| graph/passes/variable_prepare_op_pass.cc \ | graph/passes/variable_prepare_op_pass.cc \ | ||||
| graph/passes/variable_ref_delete_op_pass.cc \ | graph/passes/variable_ref_delete_op_pass.cc \ | ||||
| @@ -31,6 +31,7 @@ | |||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/manager/graph_manager.h" | #include "graph/manager/graph_manager.h" | ||||
| #include "graph/manager/util/rt_context_util.h" | #include "graph/manager/util/rt_context_util.h" | ||||
| #include "graph/operator_factory_impl.h" | |||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -803,6 +804,41 @@ Status GeGenerator::CheckForSingleOp(OpDescPtr &op_desc, const vector<GeTensor> | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeGenerator::InferFormatForSingleOp(OpDescPtr &op_desc) { | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| if (OperatorFactoryImpl::GetInferFormatFunc(op_desc->GetType()) != nullptr) { | |||||
| auto node_op = ge::OperatorFactoryImpl::CreateOperator("node_op", op_desc->GetType()); | |||||
| if (node_op.IsEmpty()) { | |||||
| GELOGW("get op from OperatorFactory fail. op type: %s", op_desc->GetType().c_str()); | |||||
| } else { | |||||
| GELOGD("get op from OperatorFactory success. op type: %s", op_desc->GetType().c_str()); | |||||
| auto temp_op_desc = ge::OpDescUtils::GetOpDescFromOperator(node_op); | |||||
| if (temp_op_desc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "GetOpDescFromOperator failed, as return nullptr, type:%s", | |||||
| op_desc->GetType().c_str()); | |||||
| GELOGE(FAILED, "[Get][OpDesc] temp op desc is null, type:%s", op_desc->GetType().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (!op_desc->UpdateInputName(temp_op_desc->GetAllInputName())) { | |||||
| GELOGW("InferFormatForSingleOp UpdateInputName failed"); | |||||
| } | |||||
| if (!op_desc->UpdateOutputName(temp_op_desc->GetAllOutputName())) { | |||||
| GELOGW("InferFormatForSingleOp UpdateOutputName failed"); | |||||
| } | |||||
| } | |||||
| node_op.BreakConnect(); | |||||
| } | |||||
| auto op = OpDescUtils::CreateOperatorFromOpDesc(op_desc); | |||||
| auto ret = op_desc->CallInferFormatFunc(op); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| REPORT_INNER_ERROR("E19999", "call InferFormatFunc for single op:%s fail", | |||||
| op_desc->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Call][InferFormatFunc] for single op:%s fail.", op_desc->GetName().c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &inputs, const vector<GeTensor> &outputs, | ||||
| const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | const string &model_file_name, OpEngineType engine_type, ModelBufferData &model_buff, | ||||
| bool is_offline, int32_t compile_flag) { | bool is_offline, int32_t compile_flag) { | ||||
| @@ -843,6 +879,7 @@ Status GeGenerator::BuildSingleOp(OpDescPtr &op_desc, const vector<GeTensor> &in | |||||
| Graph graph; | Graph graph; | ||||
| GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | GE_CHK_STATUS(BuildSingleOpGraph(op_desc, inputs, outputs, name, graph), | ||||
| "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | "[Build][Graph] for single op:%s fail.", op_desc->GetName().c_str()); | ||||
| GE_CHK_STATUS_RET_NOLOG(InferFormatForSingleOp(op_desc)); | |||||
| // 2. check engine type when compile online | // 2. check engine type when compile online | ||||
| if (model_file_name == kFileNameSuffix) { | if (model_file_name == kFileNameSuffix) { | ||||
| @@ -500,6 +500,7 @@ string MemoryBlock::String() { | |||||
| ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << " "; | ss << "Block size: " << Size() << " from " << HeadOffset() << " to " << TailOffset() << " "; | ||||
| ss << "real_size_list: " << ToString(real_size_list_) << " "; | ss << "real_size_list: " << ToString(real_size_list_) << " "; | ||||
| ss << "ref_count: " << ref_count_ << " "; | ss << "ref_count: " << ref_count_ << " "; | ||||
| ss << "reuse_mem_: " << reuse_mem_ << " "; | |||||
| ss << "members: "; | ss << "members: "; | ||||
| for (auto x : NodeTypeIndexList()) { | for (auto x : NodeTypeIndexList()) { | ||||
| ss << "__node: " << ToString(x) << " "; | ss << "__node: " << ToString(x) << " "; | ||||
| @@ -513,8 +514,8 @@ string MemoryBlock::String() { | |||||
| BlockMemAssigner::BlockMemAssigner(ComputeGraphPtr compute_graph, const map<string, string> &anchor_to_symbol, | BlockMemAssigner::BlockMemAssigner(ComputeGraphPtr compute_graph, const map<string, string> &anchor_to_symbol, | ||||
| const map<string, list<NodeIndexIO>> &symbol_to_anchors) | const map<string, list<NodeIndexIO>> &symbol_to_anchors) | ||||
| : mem_offset_(0), p2p_mem_offset_(0), compute_graph_(std::move(compute_graph)), | |||||
| symbol_to_anchors_(symbol_to_anchors), anchor_to_symbol_(anchor_to_symbol), life_time_(0) {} | |||||
| : compute_graph_(std::move(compute_graph)), symbol_to_anchors_(symbol_to_anchors), | |||||
| anchor_to_symbol_(anchor_to_symbol), life_time_(0) {} | |||||
| BlockMemAssigner::~BlockMemAssigner() { | BlockMemAssigner::~BlockMemAssigner() { | ||||
| GELOGD("[Destruct][BlockMemAssigner]blocks_store_ size : %lu", blocks_store_.size()); | GELOGD("[Destruct][BlockMemAssigner]blocks_store_ size : %lu", blocks_store_.size()); | ||||
| @@ -1123,7 +1124,7 @@ bool BlockMemAssigner::IsZeroCopyBlock(const NodePtr &node, bool continuous) { | |||||
| MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, | MemoryBlock *BlockMemAssigner::ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, | ||||
| OpMemoryType mem_type, const NodePtr &n, uint32_t out_index, | OpMemoryType mem_type, const NodePtr &n, uint32_t out_index, | ||||
| const vector<bool> &workspace_reuse_flag, const bool is_op_reuse_mem, | const vector<bool> &workspace_reuse_flag, const bool is_op_reuse_mem, | ||||
| const bool continuous, int64_t memory_type) { | |||||
| const bool continuous, uint64_t memory_type) { | |||||
| GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | GE_CHK_BOOL_TRUE_EXEC_WITH_LOG( | ||||
| n == nullptr, | n == nullptr, | ||||
| REPORT_INNER_ERROR("E19999", "Input parameter n(type:node_ptr) is null, apply memory failed"); | REPORT_INNER_ERROR("E19999", "Input parameter n(type:node_ptr) is null, apply memory failed"); | ||||
| @@ -1824,8 +1825,8 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| zero_memory_list_.emplace_back(n, kWorkspace, static_cast<uint32_t>(i), false); | zero_memory_list_.emplace_back(n, kWorkspace, static_cast<uint32_t>(i), false); | ||||
| continue; | continue; | ||||
| } | } | ||||
| int64_t memory_type = RT_MEMORY_HBM; | |||||
| if (!GetWorkSpaceMemoryType(n, i, memory_type)) { | |||||
| uint64_t memory_type = RT_MEMORY_HBM; | |||||
| if (!GetWorkSpaceMemoryType(n, i, memory_type, workspace_reuse_flag)) { | |||||
| GELOGW("Get workspace memory type failed."); | GELOGW("Get workspace memory type failed."); | ||||
| return; | return; | ||||
| } | } | ||||
| @@ -1860,7 +1861,7 @@ void BlockMemAssigner::AssignMemoryWithReuse(vector<int64_t> &ranges) { | |||||
| } | } | ||||
| void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_flag, uint32_t index, int64_t stream_id, | void BlockMemAssigner::CheckWorkspaceReuse(const vector<bool> &workspace_reuse_flag, uint32_t index, int64_t stream_id, | ||||
| MemoryBlock *mem_block, int64_t memory_type) { | |||||
| MemoryBlock *mem_block, uint64_t memory_type) { | |||||
| bool reuse_mem_flag = | bool reuse_mem_flag = | ||||
| ((workspace_reuse_flag.size() > index) && (workspace_reuse_flag[index] == false)) ? false : true; | ((workspace_reuse_flag.size() > index) && (workspace_reuse_flag[index] == false)) ? false : true; | ||||
| if (reuse_mem_flag) { | if (reuse_mem_flag) { | ||||
| @@ -1992,24 +1993,29 @@ void BlockMemAssigner::ReuseBlocksByLifeTime(size_t range_size) { | |||||
| } | } | ||||
| } | } | ||||
| void AddBlockMemOffset(size_t &mem_offset, size_t &p2p_mem_offset, MemoryBlock &block) { | |||||
| if (block.memory_type_ == RT_MEMORY_HBM) { | |||||
| if (block.first_continuous_block_) { | |||||
| mem_offset += MEM_ALIGN_SIZE; | |||||
| } | |||||
| block.Resize(); | |||||
| block.SetHeadOffset(mem_offset); | |||||
| mem_offset += block.Size(); | |||||
| block.SetTailOffset(mem_offset - 1); | |||||
| } else if (block.memory_type_ == RT_MEMORY_P2P_DDR) { | |||||
| if (block.first_continuous_block_) { | |||||
| p2p_mem_offset += MEM_ALIGN_SIZE; | |||||
| void AddBlockMemOffset(std::map<uint64_t, size_t> &mem_offsets, MemoryBlock &block) { | |||||
| auto it = mem_offsets.find(block.memory_type_); | |||||
| if (it == mem_offsets.end()) { | |||||
| auto result = mem_offsets.insert(std::pair<int64_t, size_t>(block.memory_type_, 0)); | |||||
| // Insert failure is unlikely | |||||
| if (!result.second) { | |||||
| return; | |||||
| } | } | ||||
| block.Resize(); | |||||
| block.SetHeadOffset(p2p_mem_offset); | |||||
| p2p_mem_offset += block.Size(); | |||||
| block.SetTailOffset(p2p_mem_offset - 1); | |||||
| it = result.first; | |||||
| } | |||||
| if (it == mem_offsets.end()) { | |||||
| return; | |||||
| } | |||||
| auto &mem_offset = it->second; | |||||
| if (block.first_continuous_block_) { | |||||
| mem_offset += MEM_ALIGN_SIZE; | |||||
| } | } | ||||
| block.Resize(); | |||||
| block.SetHeadOffset(mem_offset); | |||||
| mem_offset += block.Size(); | |||||
| block.SetTailOffset(mem_offset - 1); | |||||
| } | } | ||||
| bool DynamicBatchBlockReuse(MemoryBlock &block) { | bool DynamicBatchBlockReuse(MemoryBlock &block) { | ||||
| @@ -2036,27 +2042,27 @@ void BlockMemAssigner::ResizeDynamicBatchBlocks() { | |||||
| } | } | ||||
| } | } | ||||
| size_t max_mem_offset = mem_offset_; | |||||
| size_t max_p2p_mem_offset = p2p_mem_offset_; | |||||
| std::map<uint64_t, size_t> max_mem_offsets = mem_offsets_; | |||||
| for (auto &batch_blocks : dynamic_batch_blocks) { | for (auto &batch_blocks : dynamic_batch_blocks) { | ||||
| size_t mem_offset = mem_offset_; | |||||
| size_t p2p_mem_offset = p2p_mem_offset_; | |||||
| std::map<uint64_t, size_t> mem_offsets = mem_offsets_; | |||||
| for (auto block : batch_blocks.second) { | for (auto block : batch_blocks.second) { | ||||
| if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) { | if (block == nullptr || block->deleted_block_ || block->is_zero_copy_) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| AddBlockMemOffset(mem_offset, p2p_mem_offset, *block); | |||||
| AddBlockMemOffset(mem_offsets, *block); | |||||
| } | } | ||||
| if (mem_offset > max_mem_offset) { | |||||
| max_mem_offset = mem_offset; | |||||
| } | |||||
| if (p2p_mem_offset > max_p2p_mem_offset) { | |||||
| max_p2p_mem_offset = p2p_mem_offset; | |||||
| for (auto &it : mem_offsets) { | |||||
| auto itmax = max_mem_offsets.find(it.first); | |||||
| if (itmax == max_mem_offsets.end()) { | |||||
| max_mem_offsets[it.first] = it.second; | |||||
| } else if (it.second > itmax->second) { | |||||
| itmax->second = it.second; | |||||
| } | |||||
| GELOGI("Batch:%s memory type:%ld offset:%zu", batch_blocks.first.c_str(), it.first, it.second); | |||||
| } | } | ||||
| GELOGI("Batch[%s] offset[%zu] p2p_offset[%zu]", batch_blocks.first.c_str(), mem_offset, p2p_mem_offset); | |||||
| } | } | ||||
| mem_offset_ = max_mem_offset; | |||||
| p2p_mem_offset_ = max_p2p_mem_offset; | |||||
| mem_offsets_ = max_mem_offsets; | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -2074,11 +2080,13 @@ void BlockMemAssigner::ResizeMemoryBlocks() { | |||||
| continue; | continue; | ||||
| } | } | ||||
| AddBlockMemOffset(mem_offset_, p2p_mem_offset_, *memory_block); | |||||
| AddBlockMemOffset(mem_offsets_, *memory_block); | |||||
| } | } | ||||
| ResizeDynamicBatchBlocks(); | ResizeDynamicBatchBlocks(); | ||||
| GELOGI("mem_offset_ exclude zero_copy_memory is %zu, p2p_mem_offset_ exclude zero_copy_memory is %zu," | |||||
| "theory_min_memory_size %zu", mem_offset_, p2p_mem_offset_, theory_min_memory_size_); | |||||
| for (auto it : mem_offsets_) { | |||||
| GELOGI("Memory type:%ld mem_offset exclude zero_copy_memory:%zu, theory_min_memory_size:%zu", it.first, it.second, | |||||
| theory_min_memory_size_); | |||||
| } | |||||
| } | } | ||||
| /// | /// | ||||
| @@ -2217,7 +2225,8 @@ bool BlockMemAssigner::CheckIsZeroMemNodeType(const string &node_type) const { | |||||
| (node_type == CONSTANTOP) || (node_type == HVDWAIT); | (node_type == CONSTANTOP) || (node_type == HVDWAIT); | ||||
| } | } | ||||
| bool BlockMemAssigner::GetWorkSpaceMemoryType(const NodePtr &node, size_t index, int64_t &memory_type) { | |||||
| bool BlockMemAssigner::GetWorkSpaceMemoryType(const NodePtr &node, size_t index, uint64_t &memory_type, | |||||
| vector<bool> &workspace_reuse_flag) { | |||||
| memory_type = RT_MEMORY_HBM; | memory_type = RT_MEMORY_HBM; | ||||
| vector<int64_t> workspace_memory_type; | vector<int64_t> workspace_memory_type; | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| @@ -2233,6 +2242,20 @@ bool BlockMemAssigner::GetWorkSpaceMemoryType(const NodePtr &node, size_t index, | |||||
| return false; | return false; | ||||
| } | } | ||||
| memory_type = has_workspace_mem_type_attr ? workspace_memory_type[index] : RT_MEMORY_HBM; | memory_type = has_workspace_mem_type_attr ? workspace_memory_type[index] : RT_MEMORY_HBM; | ||||
| vector<int32_t> workspace_no_reuse_scope; | |||||
| bool has_workspace_no_reuse_scope = | |||||
| ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_MEMORY_NO_REUSE_SCOPE, workspace_no_reuse_scope); | |||||
| if (has_workspace_no_reuse_scope && (index < workspace_no_reuse_scope.size()) | |||||
| && (workspace_no_reuse_scope[index] == kSessionNoReuse)) { | |||||
| memory_type |= kSessionScopeMemory; | |||||
| if (workspace_reuse_flag.empty()) { | |||||
| workspace_reuse_flag.assign(workspace_no_reuse_scope.size(), true); | |||||
| } | |||||
| // set to no reuse | |||||
| workspace_reuse_flag[index] = false; | |||||
| GELOGI("%s's workspace is session scope no reuse, memory type:%lu.", node->GetName().c_str(), memory_type); | |||||
| } | |||||
| return true; | return true; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -34,6 +34,10 @@ | |||||
| namespace ge { | namespace ge { | ||||
| const size_t kMaxLifeTime = 0xffffffff; | const size_t kMaxLifeTime = 0xffffffff; | ||||
| const int32_t kInvalidThreadScopeId = -1; | const int32_t kInvalidThreadScopeId = -1; | ||||
| const uint64_t kSessionScopeMemory = 0x100000000; | |||||
| const uint64_t kMemoryTypeMask = 0xffffffff; | |||||
| enum MemoryNoReuseScope { kReuse, kSessionNoReuse, kGraphNoReuse }; | |||||
| using DependStreamLife = std::map<int64_t, std::map<int64_t, size_t>>; | using DependStreamLife = std::map<int64_t, std::map<int64_t, size_t>>; | ||||
| @@ -224,9 +228,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| Status Assign() override; | Status Assign() override; | ||||
| size_t GetMemOffset() const { return mem_offset_; } | |||||
| size_t GetP2PMemOffset() const { return p2p_mem_offset_; } | |||||
| const std::map<uint64_t, size_t> &GetMemOffsets() const { return mem_offsets_; } | |||||
| int64_t GetAtomicAddrCleanId() const { return atomic_addr_clean_id_; } | int64_t GetAtomicAddrCleanId() const { return atomic_addr_clean_id_; } | ||||
| @@ -329,14 +331,10 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// | /// | ||||
| void UpdateOpTensorMemType(std::list<NodeIndexIO> node_index_io_list, int64_t memory_type); | void UpdateOpTensorMemType(std::list<NodeIndexIO> node_index_io_list, int64_t memory_type); | ||||
| size_t mem_offset_; | |||||
| size_t p2p_mem_offset_; | |||||
| std::map<uint64_t, size_t> mem_offsets_; | |||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| std::vector<MemoryBlock *> memory_blocks_; | std::vector<MemoryBlock *> memory_blocks_; | ||||
| std::vector<MemoryBlock *> blocks_store_; | std::vector<MemoryBlock *> blocks_store_; | ||||
| std::vector<NodeTypeIndex> zero_memory_list_; | std::vector<NodeTypeIndex> zero_memory_list_; | ||||
| // ref mapping | // ref mapping | ||||
| @@ -380,7 +378,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// | /// | ||||
| MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, OpMemoryType mem_type, | MemoryBlock *ApplyMemory(size_t block_size, size_t real_size, size_t no_align_size, OpMemoryType mem_type, | ||||
| const ge::NodePtr &n, uint32_t out_index, const std::vector<bool> &workspace_reuse_flag, | const ge::NodePtr &n, uint32_t out_index, const std::vector<bool> &workspace_reuse_flag, | ||||
| const bool is_op_reuse_mem, const bool continuous, int64_t memory_type); | |||||
| const bool is_op_reuse_mem, const bool continuous, uint64_t memory_type); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -394,7 +392,7 @@ class BlockMemAssigner : public MemAssigner { | |||||
| /// @author | /// @author | ||||
| /// | /// | ||||
| void CheckWorkspaceReuse(const vector<bool> &workspace_reuse_flag, uint32_t index, int64_t stream_id, | void CheckWorkspaceReuse(const vector<bool> &workspace_reuse_flag, uint32_t index, int64_t stream_id, | ||||
| MemoryBlock *mem_block, int64_t memory_type); | |||||
| MemoryBlock *mem_block, uint64_t memory_type); | |||||
| /// | /// | ||||
| /// @ingroup GE | /// @ingroup GE | ||||
| @@ -457,7 +455,8 @@ class BlockMemAssigner : public MemAssigner { | |||||
| bool IsContinuousOutput(const NodePtr &n); | bool IsContinuousOutput(const NodePtr &n); | ||||
| bool GetWorkSpaceMemoryType(const NodePtr &node, size_t index, int64_t &memory_type); | |||||
| bool GetWorkSpaceMemoryType(const NodePtr &node, size_t index, uint64_t &memory_type, | |||||
| vector<bool> &workspace_reuse_flag); | |||||
| void ContinuousOutRefCheck(bool &isAllOutputRef, bool &isOutputHasRef, const NodePtr &n); | void ContinuousOutRefCheck(bool &isAllOutputRef, bool &isOutputHasRef, const NodePtr &n); | ||||
| @@ -69,6 +69,10 @@ int64_t GetSymbolOutputOffset(const std::map<std::string, std::string> &anchor_t | |||||
| } | } | ||||
| return ge::kInvalidOffset; | return ge::kInvalidOffset; | ||||
| } | } | ||||
| bool isVariableMemoryNode(const ge::NodePtr &node) { | |||||
| return (node->GetType() == ge::VARIABLE) || (node->GetType() == ge::CONSTANTOP); | |||||
| } | |||||
| } // namespace | } // namespace | ||||
| namespace ge { | namespace ge { | ||||
| Status VariableMemoryAssigner::Assign() { | Status VariableMemoryAssigner::Assign() { | ||||
| @@ -107,11 +111,22 @@ Status GraphMemoryAssigner::AssignMemory() { | |||||
| compute_graph_->GetGraphID(), compute_graph_->GetName().c_str()); | compute_graph_->GetGraphID(), compute_graph_->GetName().c_str()); | ||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| MemoryOffset memory_offset(RT_MEMORY_HBM, mem_assigner->GetMemOffset()); | |||||
| memory_offset_.emplace(RT_MEMORY_HBM, memory_offset); | |||||
| if (mem_assigner->GetP2PMemOffset() >= 0) { | |||||
| MemoryOffset p2p_memory_offset(RT_MEMORY_P2P_DDR, mem_assigner->GetP2PMemOffset()); | |||||
| for (auto pair : mem_assigner->GetMemOffsets()) { | |||||
| MemoryOffset offset(pair.first, pair.second); | |||||
| memory_offset_.emplace(pair.first, offset); | |||||
| } | |||||
| // base memtype offset must be exist | |||||
| auto it = mem_assigner->GetMemOffsets().find(RT_MEMORY_HBM); | |||||
| if (it == mem_assigner->GetMemOffsets().end()) { | |||||
| MemoryOffset memory_offset(RT_MEMORY_HBM, 0); | |||||
| memory_offset_.emplace(RT_MEMORY_HBM, memory_offset); | |||||
| } | |||||
| it = mem_assigner->GetMemOffsets().find(RT_MEMORY_P2P_DDR); | |||||
| if (it == mem_assigner->GetMemOffsets().end()) { | |||||
| MemoryOffset p2p_memory_offset(RT_MEMORY_P2P_DDR, 0); | |||||
| memory_offset_.emplace(RT_MEMORY_P2P_DDR, p2p_memory_offset); | memory_offset_.emplace(RT_MEMORY_P2P_DDR, p2p_memory_offset); | ||||
| } | } | ||||
| @@ -224,7 +239,7 @@ ge::Status CalculateTensorRealSizeAndOutSize(const ge::ConstGeTensorDescPtr &out | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_type_to_offset) { | |||||
| Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map<uint64_t, size_t> &mem_type_to_offset) { | |||||
| if (memory_offset_.empty()) { | if (memory_offset_.empty()) { | ||||
| REPORT_INNER_ERROR("E19999", "InnerData memory_offset_ empty, not expected, graph_id:%u, graph_name:%s", | REPORT_INNER_ERROR("E19999", "InnerData memory_offset_ empty, not expected, graph_id:%u, graph_name:%s", | ||||
| compute_graph_->GetGraphID(), compute_graph_->GetName().c_str()); | compute_graph_->GetGraphID(), compute_graph_->GetName().c_str()); | ||||
| @@ -264,7 +279,7 @@ Status GraphMemoryAssigner::ReAssignMemory(bool is_loop_graph, map<int64_t, size | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphMemoryAssigner::AssignZeroCopyMemory(map<int64_t, size_t> &mem_offset, size_t &zero_mem_copy_size) { | |||||
| Status GraphMemoryAssigner::AssignZeroCopyMemory(map<uint64_t, size_t> &mem_offset, size_t &zero_mem_copy_size) { | |||||
| BlockMemAssignerPtr priority_assigner = std::move(mem_assigner_->GetPriorityAssinger()); | BlockMemAssignerPtr priority_assigner = std::move(mem_assigner_->GetPriorityAssinger()); | ||||
| if (priority_assigner == nullptr) { | if (priority_assigner == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "InnerData priority_assigner nullptr, not expected, graph_id:%u, graph_name:%s", | REPORT_INNER_ERROR("E19999", "InnerData priority_assigner nullptr, not expected, graph_id:%u, graph_name:%s", | ||||
| @@ -436,22 +451,31 @@ bool IsContinuousInputConflict(const ge::NodePtr &node, const OpDescPtr &peer_op | |||||
| /// op1 -> node -> op2 | /// op1 -> node -> op2 | ||||
| /// return true when node is ref from input, and op1 or op2 is reuse input from output | /// return true when node is ref from input, and op1 or op2 is reuse input from output | ||||
| bool GraphMemoryAssigner::IsRefFromInputOpCascade(const NodePtr &node) { | bool GraphMemoryAssigner::IsRefFromInputOpCascade(const NodePtr &node) { | ||||
| bool ref_from_input = false; | |||||
| std::unordered_set<int32_t> ref_input_index; | |||||
| int32_t reuse_in_index = -1; | int32_t reuse_in_index = -1; | ||||
| for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | for (const auto &out_anchor : node->GetAllOutDataAnchors()) { | ||||
| ref_from_input = GraphUtils::IsRefFromInput(out_anchor, reuse_in_index); | |||||
| if (ref_from_input) { | |||||
| bool reuse_input = GraphUtils::IsRefFromInput(out_anchor, reuse_in_index); | |||||
| if (reuse_input) { | |||||
| GELOGD("IsRefFromInputOpCascade: cur node:%s:%d is ref", node->GetName().c_str(), reuse_in_index); | GELOGD("IsRefFromInputOpCascade: cur node:%s:%d is ref", node->GetName().c_str(), reuse_in_index); | ||||
| break; | |||||
| ref_input_index.insert(reuse_in_index); | |||||
| } | } | ||||
| } | } | ||||
| bool ref_from_input = !ref_input_index.empty(); | |||||
| if (!ref_from_input) { | |||||
| return false; | |||||
| } | |||||
| for (const auto &in_anchor : node->GetAllInDataAnchors()) { | for (const auto &in_anchor : node->GetAllInDataAnchors()) { | ||||
| const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | const auto &peer_out_anchor = in_anchor->GetPeerOutAnchor(); | ||||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
| auto in_node = peer_out_anchor->GetOwnerNode(); | |||||
| if (isVariableMemoryNode(in_node) && (ref_input_index.count(in_anchor->GetIdx()) > 0)) { | |||||
| GELOGD("Reuse variable memory, input node:%s, type:%s.", in_node->GetName().c_str(), in_node->GetType().c_str()); | |||||
| return false; | |||||
| } | |||||
| if (ref_from_input && GraphUtils::IsRefFromInput(peer_out_anchor, reuse_in_index)) { | if (ref_from_input && GraphUtils::IsRefFromInput(peer_out_anchor, reuse_in_index)) { | ||||
| GELOGD("IsRefFromInputOpCascade: in node[%s] is ref, reuse index is:%d", | GELOGD("IsRefFromInputOpCascade: in node[%s] is ref, reuse index is:%d", | ||||
| peer_out_anchor->GetOwnerNode()->GetName().c_str(), reuse_in_index); | |||||
| in_node->GetName().c_str(), reuse_in_index); | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -489,6 +513,11 @@ Status GraphMemoryAssigner::UpdateRefOpOffsetReverse(const NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(peer_out_anchor); | GE_CHECK_NOTNULL(peer_out_anchor); | ||||
| auto peer_node = peer_out_anchor->GetOwnerNode(); | auto peer_node = peer_out_anchor->GetOwnerNode(); | ||||
| GE_CHECK_NOTNULL(peer_node); | GE_CHECK_NOTNULL(peer_node); | ||||
| if (isVariableMemoryNode(peer_node)) { | |||||
| GELOGW("Peer node to update is %s, skip it. Node name:%s.", | |||||
| peer_node->GetType().c_str(), peer_node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| auto peer_op_desc = peer_node->GetOpDesc(); | auto peer_op_desc = peer_node->GetOpDesc(); | ||||
| GE_CHECK_NOTNULL(peer_op_desc); | GE_CHECK_NOTNULL(peer_op_desc); | ||||
| vector<int64_t> peer_output_list = peer_op_desc->GetOutputOffset(); | vector<int64_t> peer_output_list = peer_op_desc->GetOutputOffset(); | ||||
| @@ -1398,6 +1427,9 @@ ge::Status GraphMemoryAssigner::SetInputOffset() { | |||||
| "graph_id:%u, graph_name:%s", compute_graph_->GetGraphID(), compute_graph_->GetName().c_str()); | "graph_id:%u, graph_name:%s", compute_graph_->GetGraphID(), compute_graph_->GetName().c_str()); | ||||
| } | } | ||||
| for (auto pair : memory_offset_) { | for (auto pair : memory_offset_) { | ||||
| if ((pair.first != RT_MEMORY_HBM) && (pair.second.mem_offset_ == 0)) { | |||||
| continue; | |||||
| } | |||||
| GEEVENT("[IMAS]AfterAssignMemory : %s memoffset[%zu], memtype[%ld]", compute_graph_->GetName().c_str(), | GEEVENT("[IMAS]AfterAssignMemory : %s memoffset[%zu], memtype[%ld]", compute_graph_->GetName().c_str(), | ||||
| pair.second.mem_offset_, pair.first); | pair.second.mem_offset_, pair.first); | ||||
| } | } | ||||
| @@ -103,9 +103,9 @@ class GraphMemoryAssigner { | |||||
| ge::Status AssignMemory2HasRefAttrNode(); | ge::Status AssignMemory2HasRefAttrNode(); | ||||
| ge::Status ReAssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_type_to_offset); | |||||
| ge::Status ReAssignMemory(bool is_loop_graph, map<uint64_t, size_t> &mem_type_to_offset); | |||||
| ge::Status AssignZeroCopyMemory(map<int64_t, size_t> &mem_offset, size_t &zero_mem_copy_size); | |||||
| ge::Status AssignZeroCopyMemory(map<uint64_t, size_t> &mem_offset, size_t &zero_mem_copy_size); | |||||
| ge::Status SetInputOffset(); | ge::Status SetInputOffset(); | ||||
| @@ -23,7 +23,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| HybridMemAssigner::HybridMemAssigner(ge::ComputeGraphPtr compute_graph) | HybridMemAssigner::HybridMemAssigner(ge::ComputeGraphPtr compute_graph) | ||||
| : mem_offset_(0), p2p_mem_offset_(0), compute_graph_(std::move(compute_graph)), priority_assigner_(nullptr) {} | |||||
| : compute_graph_(std::move(compute_graph)), priority_assigner_(nullptr) {} | |||||
| Status HybridMemAssigner::AssignMemory(std::unique_ptr<BlockMemAssigner> &block_assigner, size_t &mem_size) { | Status HybridMemAssigner::AssignMemory(std::unique_ptr<BlockMemAssigner> &block_assigner, size_t &mem_size) { | ||||
| vector<int64_t> ranges; | vector<int64_t> ranges; | ||||
| @@ -36,7 +36,10 @@ Status HybridMemAssigner::AssignMemory(std::unique_ptr<BlockMemAssigner> &block_ | |||||
| block_assigner->AssignMemoryWithReuse(ranges); | block_assigner->AssignMemoryWithReuse(ranges); | ||||
| mem_size = block_assigner->GetMemOffset(); | |||||
| // total size | |||||
| for (auto it : block_assigner->GetMemOffsets()) { | |||||
| mem_size += it.second; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -73,8 +76,7 @@ Status HybridMemAssigner::Assign() { | |||||
| } | } | ||||
| priority_assigner->SetOpMemOffset(false); | priority_assigner->SetOpMemOffset(false); | ||||
| mem_offset_ = priority_assigner->GetMemOffset(); | |||||
| p2p_mem_offset_ = priority_assigner->GetP2PMemOffset(); | |||||
| mem_offsets_ = priority_assigner->GetMemOffsets(); | |||||
| priority_assigner_ = std::move(priority_assigner); | priority_assigner_ = std::move(priority_assigner); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -42,16 +42,14 @@ class HybridMemAssigner : public MemAssigner { | |||||
| Status Assign() override; | Status Assign() override; | ||||
| size_t GetMemOffset() const { return mem_offset_; } | |||||
| size_t GetP2PMemOffset() const { return p2p_mem_offset_; } | |||||
| const std::map<uint64_t, size_t> &GetMemOffsets() const { return mem_offsets_; } | |||||
| BlockMemAssignerPtr GetPriorityAssinger() const { return priority_assigner_; } | BlockMemAssignerPtr GetPriorityAssinger() const { return priority_assigner_; } | ||||
| private: | private: | ||||
| Status AssignMemory(std::unique_ptr<BlockMemAssigner> &block_assigner, size_t &mem_size); | Status AssignMemory(std::unique_ptr<BlockMemAssigner> &block_assigner, size_t &mem_size); | ||||
| size_t mem_offset_; | |||||
| size_t p2p_mem_offset_; | |||||
| std::map<uint64_t, size_t> mem_offsets_; | |||||
| ge::ComputeGraphPtr compute_graph_; | ge::ComputeGraphPtr compute_graph_; | ||||
| @@ -20,7 +20,7 @@ | |||||
| #include "graph/build/memory/graph_mem_assigner.h" | #include "graph/build/memory/graph_mem_assigner.h" | ||||
| namespace ge { | namespace ge { | ||||
| Status MemoryAssigner::AssignMemory(bool is_loop_graph, map<int64_t, size_t> &mem_offset, size_t &zero_copy_mem_size) { | |||||
| Status MemoryAssigner::AssignMemory(bool is_loop_graph, map<uint64_t, size_t> &mem_offset, size_t &zero_copy_mem_size) { | |||||
| GraphMemoryAssigner graph_mem_assigner(compute_graph_); | GraphMemoryAssigner graph_mem_assigner(compute_graph_); | ||||
| if (graph_mem_assigner.AssignMemory() != ge::SUCCESS) { | if (graph_mem_assigner.AssignMemory() != ge::SUCCESS) { | ||||
| @@ -47,6 +47,7 @@ | |||||
| #include "omg/version.h" | #include "omg/version.h" | ||||
| #include "register/op_registry.h" | #include "register/op_registry.h" | ||||
| #include "graph/passes/set_input_output_offset_pass.h" | #include "graph/passes/set_input_output_offset_pass.h" | ||||
| #include "graph/build/memory/block_mem_assigner.h" | |||||
| using std::map; | using std::map; | ||||
| using std::set; | using std::set; | ||||
| @@ -398,9 +399,21 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { | |||||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", ATTR_MODEL_MEMORY_SIZE.c_str()); | REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", ATTR_MODEL_MEMORY_SIZE.c_str()); | ||||
| GELOGE(FAILED, "[Set][Attr] %s in model failed", ATTR_MODEL_MEMORY_SIZE.c_str()); | GELOGE(FAILED, "[Set][Attr] %s in model failed", ATTR_MODEL_MEMORY_SIZE.c_str()); | ||||
| return FAILED); | return FAILED); | ||||
| auto mem_type_session_scope = (kSessionScopeMemory | RT_MEMORY_HBM); | |||||
| size_t session_scope_mem_offset = 0; | |||||
| auto it = mem_type_to_mem_offset_.find(mem_type_session_scope); | |||||
| if (it != mem_type_to_mem_offset_.end()) { | |||||
| session_scope_mem_offset = it->second; | |||||
| } | |||||
| if (mem_type_to_mem_offset_.find(RT_MEMORY_P2P_DDR) != mem_type_to_mem_offset_.end()) { | if (mem_type_to_mem_offset_.find(RT_MEMORY_P2P_DDR) != mem_type_to_mem_offset_.end()) { | ||||
| p2p_mem_offset_ = mem_type_to_mem_offset_[RT_MEMORY_P2P_DDR]; | p2p_mem_offset_ = mem_type_to_mem_offset_[RT_MEMORY_P2P_DDR]; | ||||
| } | } | ||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_SESSION_SCOPE_MEMORY_SIZE, session_scope_mem_offset), | |||||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", | |||||
| ATTR_MODEL_SESSION_SCOPE_MEMORY_SIZE.c_str()); | |||||
| GELOGE(FAILED, "SetInt of ATTR_NAME_SESSION_SCOPE_MEMORY_SIZE failed."); | |||||
| return FAILED); | |||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_P2P_MEMORY_SIZE, p2p_mem_offset_), | GE_CHK_BOOL_EXEC(ge::AttrUtils::SetInt(&model, ATTR_MODEL_P2P_MEMORY_SIZE, p2p_mem_offset_), | ||||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", ATTR_MODEL_P2P_MEMORY_SIZE.c_str()); | REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", ATTR_MODEL_P2P_MEMORY_SIZE.c_str()); | ||||
| GELOGE(FAILED, "[Set][Attr] %s in model failed", ATTR_MODEL_P2P_MEMORY_SIZE.c_str()); | GELOGE(FAILED, "[Set][Attr] %s in model failed", ATTR_MODEL_P2P_MEMORY_SIZE.c_str()); | ||||
| @@ -434,8 +447,8 @@ Status ModelBuilder::BuildModelDef(ge::Model &model) { | |||||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", ATTR_MODEL_OUT_NODES_NAME.c_str()); | REPORT_INNER_ERROR("E19999", "Set Attr:%s in model failed", ATTR_MODEL_OUT_NODES_NAME.c_str()); | ||||
| GELOGE(FAILED, "[Set][Str] %s in model failed.", ATTR_MODEL_OUT_NODES_NAME.c_str()); | GELOGE(FAILED, "[Set][Str] %s in model failed.", ATTR_MODEL_OUT_NODES_NAME.c_str()); | ||||
| return FAILED); | return FAILED); | ||||
| GELOGI("For model, max_mem_offset_: %zu, p2p_mem_size: %zu, zero_copy_mem_size_: %zu", max_mem_offset_, | |||||
| p2p_mem_offset_, zero_copy_mem_size_); | |||||
| GELOGI("For model, max_mem_offset: %zu, p2p_mem_size: %zu, zero_copy_mem_size: %zu, session_scope_mem_size: %zu", | |||||
| max_mem_offset_, p2p_mem_offset_, zero_copy_mem_size_, session_scope_mem_offset); | |||||
| string fp_ceiling_mode; | string fp_ceiling_mode; | ||||
| if (ge::GetContext().GetOption("ge.fpCeilingMode", fp_ceiling_mode) == SUCCESS) { | if (ge::GetContext().GetOption("ge.fpCeilingMode", fp_ceiling_mode) == SUCCESS) { | ||||
| if (!ge::AttrUtils::SetStr(&model, ATTR_FP_CEILING_MODE, fp_ceiling_mode)) { | if (!ge::AttrUtils::SetStr(&model, ATTR_FP_CEILING_MODE, fp_ceiling_mode)) { | ||||
| @@ -93,7 +93,7 @@ class ModelBuilder { | |||||
| uint64_t session_id_; | uint64_t session_id_; | ||||
| map<int64_t, size_t> mem_type_to_mem_offset_; | |||||
| map<uint64_t, size_t> mem_type_to_mem_offset_; | |||||
| size_t weight_offset_; | size_t weight_offset_; | ||||
| @@ -905,6 +905,7 @@ Status StreamAllocator::SplitStreams(vector<set<int64_t>> &split_streams) { | |||||
| added_stream_num_vec[stream_id]++; | added_stream_num_vec[stream_id]++; | ||||
| new_stream_id_vec[stream_id] = last_stream_id; | new_stream_id_vec[stream_id] = last_stream_id; | ||||
| split_streams[stream_id].emplace(last_stream_id); | split_streams[stream_id].emplace(last_stream_id); | ||||
| split_ori_stream_map_[last_stream_id] = stream_id; | |||||
| node_split_stream_map_[cur_node] = last_stream_id; | node_split_stream_map_[cur_node] = last_stream_id; | ||||
| // Add the send/recv event to the first and last nodes of the split stream. | // Add the send/recv event to the first and last nodes of the split stream. | ||||
| @@ -1104,7 +1105,7 @@ Status StreamAllocator::UpdateActiveStreamsForActiveNode(const vector<set<int64_ | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { | |||||
| Status StreamAllocator::UpdateActiveStreamsForSubgraphs() { | |||||
| // Update active stream list for active nodes | // Update active stream list for active nodes | ||||
| for (auto &node_stream_pair : node_split_stream_map_) { | for (auto &node_stream_pair : node_split_stream_map_) { | ||||
| auto node = node_stream_pair.first; | auto node = node_stream_pair.first; | ||||
| @@ -1134,6 +1135,7 @@ Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { | |||||
| if (IsActivated(new_split_stream)) { | if (IsActivated(new_split_stream)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| specific_activated_streams_.emplace(new_split_stream); | |||||
| new_active_streams.emplace(static_cast<uint32_t>(new_split_stream)); | new_active_streams.emplace(static_cast<uint32_t>(new_split_stream)); | ||||
| active_streams.assign(new_active_streams.begin(), new_active_streams.end()); | active_streams.assign(new_active_streams.begin(), new_active_streams.end()); | ||||
| if (!AttrUtils::SetListInt(active_op, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { | if (!AttrUtils::SetListInt(active_op, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { | ||||
| @@ -1148,13 +1150,21 @@ Status StreamAllocator::UpdateActiveStreamsForSubgraphs() const { | |||||
| } | } | ||||
| bool StreamAllocator::IsActivated(int64_t stream_id) const { | bool StreamAllocator::IsActivated(int64_t stream_id) const { | ||||
| const auto &iter = split_ori_stream_map_.find(stream_id); | |||||
| if (iter == split_ori_stream_map_.end()) { | |||||
| REPORT_INNER_ERROR("E19999", "Find original stream_id failed, split_stream_id=%ld", stream_id); | |||||
| GELOGE(INTERNAL_ERROR, "[CheckActivated][Check] Find original stream_id failed, split_stream_id=%ld", stream_id); | |||||
| return false; | |||||
| } | |||||
| int64_t ori_stream_id = iter->second; | |||||
| for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | for (const auto &node : whole_graph_->GetNodes(whole_graph_->GetGraphUnknownFlag())) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| vector<uint32_t> active_streams; | vector<uint32_t> active_streams; | ||||
| if (op_desc == nullptr || !AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { | if (op_desc == nullptr || !AttrUtils::GetListInt(op_desc, ATTR_NAME_ACTIVE_STREAM_LIST, active_streams)) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (std::find(active_streams.begin(), active_streams.end(), stream_id) != active_streams.end()) { | |||||
| if (std::find(active_streams.begin(), active_streams.end(), stream_id) != active_streams.end() || | |||||
| std::find(active_streams.begin(), active_streams.end(), ori_stream_id) != active_streams.end()) { | |||||
| return true; | return true; | ||||
| } | } | ||||
| } | } | ||||
| @@ -66,7 +66,7 @@ class StreamAllocator { | |||||
| Status UpdateActiveStreamsForSwitchNode(NodePtr &switch_node); | Status UpdateActiveStreamsForSwitchNode(NodePtr &switch_node); | ||||
| Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector<NodePtr> &switch_active_nodes); | Status InsertActiveNodesAfterSwitch(NodePtr &switch_nodes, std::vector<NodePtr> &switch_active_nodes); | ||||
| Status UpdateActiveStreamsForActiveNode(const std::vector<std::set<int64_t>> &split_streams, NodePtr &node); | Status UpdateActiveStreamsForActiveNode(const std::vector<std::set<int64_t>> &split_streams, NodePtr &node); | ||||
| Status UpdateActiveStreamsForSubgraphs() const; | |||||
| Status UpdateActiveStreamsForSubgraphs(); | |||||
| bool IsActivated(int64_t stream_id) const; | bool IsActivated(int64_t stream_id) const; | ||||
| Status SetActiveStreamsForLoop(); | Status SetActiveStreamsForLoop(); | ||||
| Status CheckStreamActived() const; | Status CheckStreamActived() const; | ||||
| @@ -114,6 +114,7 @@ class StreamAllocator { | |||||
| std::map<int64_t, std::set<NodePtr>> specific_activated_streams_nodes_map_; | std::map<int64_t, std::set<NodePtr>> specific_activated_streams_nodes_map_; | ||||
| std::map<NodePtr, int64_t> node_split_stream_map_; | std::map<NodePtr, int64_t> node_split_stream_map_; | ||||
| std::map<int64_t, int64_t> split_ori_stream_map_; | |||||
| std::map<ComputeGraphPtr, NodePtr> subgraph_first_active_node_map_; | std::map<ComputeGraphPtr, NodePtr> subgraph_first_active_node_map_; | ||||
| // send events corresponding to the node | // send events corresponding to the node | ||||
| @@ -123,4 +124,4 @@ class StreamAllocator { | |||||
| std::map<NodePtr, std::vector<uint32_t>> node_to_recv_events_; | std::map<NodePtr, std::vector<uint32_t>> node_to_recv_events_; | ||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_BUILD_STREAM_ALLOCATOR_H_ | |||||
| #endif // GE_GRAPH_BUILD_STREAM_ALLOCATOR_H_ | |||||
| @@ -272,20 +272,32 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc) { | |||||
| /// @brief Set Op _force_unknown_shape flag | /// @brief Set Op _force_unknown_shape flag | ||||
| /// @param [in] node | /// @param [in] node | ||||
| /// @param [in] force_unknown, set attribute if true | /// @param [in] force_unknown, set attribute if true | ||||
| /// @param [in] group_index, condition group index of node. | |||||
| /// @return | /// @return | ||||
| /// | /// | ||||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown) { | |||||
| GE_RT_VOID_CHECK_NOTNULL(node); | |||||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index) { | |||||
| if (!force_unknown) { | if (!force_unknown) { | ||||
| return; | return; | ||||
| } | } | ||||
| GELOGD("[%s] mark as force unknown shape node", node->GetName().c_str()); | |||||
| if (!AttrUtils::SetBool(node->GetOpDesc(), ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { | |||||
| GE_RT_VOID_CHECK_NOTNULL(node); | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| GE_RT_VOID_CHECK_NOTNULL(op_desc); | |||||
| // op_desc as AttrHolderAdapter valid, Set attribute always success, just log for check. | |||||
| GELOGD("Mark [%s] as force unknown shape node, group index: %ld", node->GetName().c_str(), group_index); | |||||
| if (!AttrUtils::SetBool(op_desc, ATTR_NAME_FORCE_UNKNOWN_SHAPE, force_unknown)) { | |||||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | ||||
| node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
| GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_FORCE_UNKNOWN_SHAPE.c_str(), | ||||
| node->GetName().c_str(), node->GetType().c_str()); | node->GetName().c_str(), node->GetType().c_str()); | ||||
| } | } | ||||
| if (!AttrUtils::SetInt(op_desc, ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
| REPORT_INNER_ERROR("E19999", "Set Attr:%s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||||
| node->GetName().c_str(), node->GetType().c_str()); | |||||
| GELOGE(FAILED, "[Set][Attr] %s fail for op:%s(%s)", ATTR_NAME_CONTROL_FLOW_GROUP.c_str(), | |||||
| node->GetName().c_str(), node->GetType().c_str()); | |||||
| } | |||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -129,9 +129,10 @@ bool IsUnknownShapeTensor(const GeTensorDesc &tensor_desc); | |||||
| /// @brief Set Op _force_unknown_shape flag | /// @brief Set Op _force_unknown_shape flag | ||||
| /// @param [in] node | /// @param [in] node | ||||
| /// @param [in] force_unknown, set attribute if true | /// @param [in] force_unknown, set attribute if true | ||||
| /// @param [in] group_index, condition group index of node. | |||||
| /// @return | /// @return | ||||
| /// | /// | ||||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown); | |||||
| void MarkForceUnknownShape(const NodePtr &node, bool force_unknown, int64_t group_index); | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ | #endif // GE_GRAPH_COMMON_OMG_UTIL_H_ | ||||
| @@ -33,12 +33,12 @@ Status GraphLoader::UnloadModel(uint32_t model_id) { | |||||
| Status ret = model_manager->Stop(model_id); | Status ret = model_manager->Stop(model_id); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "UnloadModel: Stop failed. model id:%u", model_id); | |||||
| GELOGE(ret, "[Stop][Model] failed. model id:%u", model_id); | |||||
| } | } | ||||
| ret = model_manager->Unload(model_id); | ret = model_manager->Unload(model_id); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "UnloadModel: Unload failed. model id:%u", model_id); | |||||
| GELOGE(ret, "[Unload][Model] failed. model id:%u", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GELOGI("UnLoad model success, model id:%u.", model_id); | GELOGI("UnLoad model success, model id:%u.", model_id); | ||||
| @@ -50,14 +50,13 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||||
| GELOGI("Load model online begin."); | GELOGI("Load model online begin."); | ||||
| rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| if (ge_root_model_ptr == nullptr) { | if (ge_root_model_ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Check param ge_root_model_ptr nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Check param ge_root_model_ptr nullptr, check invalid"); | ||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph] GE load graph model_ptr is nullptr."); | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph][Check][Param] GE load graph model_ptr is nullptr."); | |||||
| return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
| } | } | ||||
| @@ -65,12 +64,12 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener); | Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "LoadModel: Load failed. ret = %u", ret); | |||||
| GELOGE(ret, "[Load][Model] Online failed. ret = %u, model_id:%u", ret, model_id); | |||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | ||||
| GetContext().DeviceId(), rt_ret); | GetContext().DeviceId(), rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| } | } | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -81,31 +80,31 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | ||||
| GetContext().DeviceId(), rt_ret); | GetContext().DeviceId(), rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ret = model_manager->Start(model_id); | ret = model_manager->Start(model_id); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| if (model_manager->Unload(model_id) != SUCCESS) { | if (model_manager->Unload(model_id) != SUCCESS) { | ||||
| GELOGE(ret, "LoadModel: Unload failed while trying to unload after a failed start."); | |||||
| GELOGE(ret, "[Unload][Model] failed while trying to unload after a failed start, model_id:%u.", model_id); | |||||
| } | } | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | ||||
| GetContext().DeviceId(), rt_ret); | GetContext().DeviceId(), rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| } | } | ||||
| GELOGE(ret, "LoadModel: Start failed."); | |||||
| GELOGE(ret, "[Start][Model] failed, model_id:%u.", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | ||||
| GetContext().DeviceId(), rt_ret); | GetContext().DeviceId(), rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| GELOGI("Load model online success, model_id:%u.", model_id); | GELOGI("Load model online success, model_id:%u.", model_id); | ||||
| @@ -118,7 +117,7 @@ Status GraphLoader::GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size) { | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->GetMaxUsedMemory(model_id, max_size); | Status ret = model_manager->GetMaxUsedMemory(model_id, max_size); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "GetMaxUsedMemory: GetMaxUsedMemory failed."); | |||||
| GELOGE(ret, "[Call][GetMaxUsedMemory] failed, model_id:%u.", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -127,21 +126,20 @@ Status GraphLoader::GetMaxUsedMemory(uint32_t model_id, uint64_t &max_size) { | |||||
| Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string &key_path, int32_t priority, | Status GraphLoader::LoadDataFromFile(const std::string &path, const std::string &key_path, int32_t priority, | ||||
| ModelData &model_data) { | ModelData &model_data) { | ||||
| if (!CheckInputPathValid(path)) { | if (!CheckInputPathValid(path)) { | ||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "model path is invalid: %s", path.c_str()); | |||||
| GELOGE(ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID, "[Check][Param] model path is invalid:%s", path.c_str()); | |||||
| return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID; | return ACL_ERROR_GE_EXEC_MODEL_PATH_INVALID; | ||||
| } | } | ||||
| GELOGI("Load model begin, model path is: %s", path.c_str()); | GELOGI("Load model begin, model path is: %s", path.c_str()); | ||||
| if (!key_path.empty() && !CheckInputPathValid(key_path)) { | if (!key_path.empty() && !CheckInputPathValid(key_path)) { | ||||
| REPORT_INNER_ERROR("E19999", "Param key_path:%s empty or invalid", | |||||
| key_path.c_str()); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "decrypt_key path is invalid: %s", key_path.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Param key_path:%s empty or invalid", key_path.c_str()); | |||||
| GELOGE(ACL_ERROR_GE_PARAM_INVALID, "[Check][Param] decrypt_key path is invalid:%s", key_path.c_str()); | |||||
| return ACL_ERROR_GE_PARAM_INVALID; | return ACL_ERROR_GE_PARAM_INVALID; | ||||
| } | } | ||||
| Status ret = ModelParserBase::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); | Status ret = ModelParserBase::LoadFromFile(path.c_str(), key_path.c_str(), priority, model_data); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "LoadModelFromFile: Load failed. ret = %u", ret); | |||||
| GELOGE(ret, "[Call][LoadFromFile] failed. ret = %u, path:%s, key path:%s", ret, path.c_str(), key_path.c_str()); | |||||
| if (model_data.model_data != nullptr) { | if (model_data.model_data != nullptr) { | ||||
| delete[] static_cast<char *>(model_data.model_data); | delete[] static_cast<char *>(model_data.model_data); | ||||
| model_data.model_data = nullptr; | model_data.model_data = nullptr; | ||||
| @@ -156,18 +154,19 @@ Status GraphLoader::CommandHandle(const Command &command) { | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->HandleCommand(command); | Status ret = model_manager->HandleCommand(command); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "CommandHandle: Command Handle failed."); | |||||
| GELOGE(ret, "[Handle][Command] failed, module_index:%lu.", command.module_index); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| } catch (std::bad_alloc &) { | } catch (std::bad_alloc &) { | ||||
| REPORT_INNER_ERROR("E19999", "Bad memory allocation occur"); | REPORT_INNER_ERROR("E19999", "Bad memory allocation occur"); | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Command handle failed, bad memory allocation occur !"); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Handle][Command] failed, " | |||||
| "bad memory allocation occur, module_index:%lu.", command.module_index); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| } catch (...) { | } catch (...) { | ||||
| REPORT_INNER_ERROR("E19999", "Some exceptions occur"); | REPORT_INNER_ERROR("E19999", "Some exceptions occur"); | ||||
| GELOGE(FAILED, "Command handle failed, some exceptions occur !"); | |||||
| GELOGE(FAILED, "[Handle][Command] failed, some exceptions occur, module_index:%lu.", command.module_index); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -184,7 +183,7 @@ Status GraphLoader::LoadModelFromData(uint32_t &model_id, const ModelData &model | |||||
| Status ret = model_manager->LoadModelOffline( | Status ret = model_manager->LoadModelOffline( | ||||
| model_id, model_data, nullptr, dev_ptr, mem_size, weight_ptr, weight_size); | model_id, model_data, nullptr, dev_ptr, mem_size, weight_ptr, weight_size); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Load model failed, model_id:%u.", model_id); | |||||
| GELOGE(ret, "[Load][Model] failed, model_id:%u.", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GELOGI("Load model success, model_id:%u.", model_id); | GELOGI("Load model success, model_id:%u.", model_id); | ||||
| @@ -210,7 +209,7 @@ Status GraphLoader::LoadModelWithQ(uint32_t &model_id, const ModelData &model_da | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids); | Status ret = model_manager->LoadModelWithQ(model_id, model_data, input_queue_ids, output_queue_ids); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Load model with queue failed, model_id:%u.", model_id); | |||||
| GELOGE(ret, "[Load][Model] with queue failed, model_id:%u.", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -237,7 +236,7 @@ Status GraphLoader::ExecuteModel(uint32_t model_id, rtStream_t stream, bool asyn | |||||
| Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, | Status ret = model_manager->ExecuteModel(model_id, stream, async_mode, | ||||
| input_data, input_desc, output_data, output_desc); | input_data, input_desc, output_data, output_desc); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Execute model failed, model_id:%u.", model_id); | |||||
| GELOGE(ret, "[Execute][Model] failed, model_id:%u.", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -250,7 +249,7 @@ Status GraphLoader::GetMemoryInfo(int64_t &free) { | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", | ||||
| GetContext().DeviceId(), rt_ret); | GetContext().DeviceId(), rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| size_t total_mem = 0; | size_t total_mem = 0; | ||||
| @@ -258,14 +257,14 @@ Status GraphLoader::GetMemoryInfo(int64_t &free) { | |||||
| rt_ret = rtMemGetInfo(&free_mem, &total_mem); | rt_ret = rtMemGetInfo(&free_mem, &total_mem); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret); | REPORT_CALL_ERROR("E19999", "Call rtMemGetInfo failed, ret:0x%X", rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtMemGetInfo] failed, ret:0x%X", rt_ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | rt_ret = rtDeviceReset(GetContext().DeviceId()); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | ||||
| GetContext().DeviceId(), rt_ret); | GetContext().DeviceId(), rt_ret); | ||||
| GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| // Add small page memory size | // Add small page memory size | ||||
| @@ -280,7 +279,8 @@ Status GraphLoader::DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, u | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->DestroyAicpuKernel(session_id, model_id, sub_model_id); | Status ret = model_manager->DestroyAicpuKernel(session_id, model_id, sub_model_id); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Destroy aicpu kernel failed."); | |||||
| GELOGE(ret, "[Destroy][AicpuKernel] failed, session_id:%lu, model_id:%u, sub_model_id:%u.", | |||||
| session_id, model_id, sub_model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -291,7 +291,7 @@ Status GraphLoader::DestroyAicpuSessionForInfer(uint32_t model_id) { | |||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->DestroyAicpuSessionForInfer(model_id); | Status ret = model_manager->DestroyAicpuSessionForInfer(model_id); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Destroy aicpu serrion for infer failed."); | |||||
| GELOGE(ret, "[Call][DestroyAicpuSessionForInfer] failed, model_id:%u.", model_id); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -40,7 +40,7 @@ | |||||
| #include "graph/load/model_manager/cpu_queue_schedule.h" | #include "graph/load/model_manager/cpu_queue_schedule.h" | ||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "graph/load/model_manager/tbe_handle_store.h" | #include "graph/load/model_manager/tbe_handle_store.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/manager/trans_var_data_utils.h" | #include "graph/manager/trans_var_data_utils.h" | ||||
| #include "graph/manager/util/debug.h" | #include "graph/manager/util/debug.h" | ||||
| @@ -60,6 +60,8 @@ | |||||
| #include "graph/common/local_context.h" | #include "graph/common/local_context.h" | ||||
| #include "common/formats/utils/formats_trans_utils.h" | #include "common/formats/utils/formats_trans_utils.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/build/memory/block_mem_assigner.h" | |||||
| #include "graph/manager/session_scope_mem_allocator.h" | |||||
| // create std::thread, catch exceptions using try/catch | // create std::thread, catch exceptions using try/catch | ||||
| #define CREATE_STD_THREAD(thread_id, func, args) \ | #define CREATE_STD_THREAD(thread_id, func, args) \ | ||||
| @@ -168,7 +170,6 @@ DavinciModel::DavinciModel(int32_t priority, const std::shared_ptr<ModelListener | |||||
| mem_base_(nullptr), | mem_base_(nullptr), | ||||
| is_inner_mem_base_(false), | is_inner_mem_base_(false), | ||||
| is_inner_weight_base_(false), | is_inner_weight_base_(false), | ||||
| is_inner_p2p_mem_base_(false), | |||||
| data_inputer_(nullptr), | data_inputer_(nullptr), | ||||
| load_begin_time_(0), | load_begin_time_(0), | ||||
| load_end_time_(0), | load_end_time_(0), | ||||
| @@ -236,7 +237,7 @@ DavinciModel::~DavinciModel() { | |||||
| FreeFeatureMapMem(); | FreeFeatureMapMem(); | ||||
| FreeP2PMem(); | |||||
| FreeExMem(); | |||||
| OpDebugUnRegister(); | OpDebugUnRegister(); | ||||
| @@ -389,7 +390,6 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| is_feature_map_mem_has_inited_ = true; | is_feature_map_mem_has_inited_ = true; | ||||
| std::size_t data_size = TotalMemSize(); | std::size_t data_size = TotalMemSize(); | ||||
| std::size_t p2p_data_size = P2PMemInfos().at(RT_MEMORY_P2P_DDR).memory_size; | |||||
| if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { | if ((dev_ptr != nullptr) && (mem_size < TotalMemSize())) { | ||||
| REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr or mem_size:%zu < ge_model.mem_size:%zu, " | REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr or mem_size:%zu < ge_model.mem_size:%zu, " | ||||
| @@ -400,7 +400,6 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| } | } | ||||
| mem_base_ = static_cast<uint8_t *>(dev_ptr); | mem_base_ = static_cast<uint8_t *>(dev_ptr); | ||||
| p2p_mem_base_ = static_cast<uint8_t *>(dev_ptr); | |||||
| is_inner_mem_base_ = false; | is_inner_mem_base_ = false; | ||||
| if (TotalMemSize() && mem_base_ == nullptr) { | if (TotalMemSize() && mem_base_ == nullptr) { | ||||
| @@ -422,24 +421,13 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| is_inner_mem_base_ = true; | is_inner_mem_base_ = true; | ||||
| } | } | ||||
| if (p2p_data_size != 0) { | |||||
| p2p_mem_base_ = MallocP2PMem(p2p_data_size); | |||||
| if (p2p_mem_base_ == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "MallocFeatureMapMem fail, p2p_data_size:%zu, model_id:%u, check invalid", | |||||
| p2p_data_size, model_id_); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Alloc][Memory] for p2p failed, size:%zu, model_id:%u", | |||||
| p2p_data_size, model_id_); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| GELOGI("InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | |||||
| p2p_mem_base_, p2p_data_size); | |||||
| is_inner_p2p_mem_base_ = true; | |||||
| if (!runtime_param_.memory_infos.empty()) { | |||||
| GE_CHK_STATUS_RET(MallocExMem(), "MallocExMem failed."); | |||||
| } | } | ||||
| GE_CHK_STATUS_RET(InitVariableMem(), "[Init][VariableMemory] failed, model_id:%u", model_id_); | GE_CHK_STATUS_RET(InitVariableMem(), "[Init][VariableMemory] failed, model_id:%u", model_id_); | ||||
| runtime_param_.mem_base = mem_base_; | runtime_param_.mem_base = mem_base_; | ||||
| runtime_param_.weight_base = weights_mem_base_; | runtime_param_.weight_base = weights_mem_base_; | ||||
| runtime_param_.memory_infos[RT_MEMORY_P2P_DDR].memory_base = p2p_mem_base_; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -465,7 +453,6 @@ Status DavinciModel::InitVariableMem() { | |||||
| void DavinciModel::InitRuntimeParams() { | void DavinciModel::InitRuntimeParams() { | ||||
| int64_t value = 0; | int64_t value = 0; | ||||
| bool ret; | bool ret; | ||||
| MemInfo p2p_mem_info; | |||||
| ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_MEMORY_SIZE, value); | ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_MEMORY_SIZE, value); | ||||
| runtime_param_.mem_size = ret ? (uint64_t)value : 0; | runtime_param_.mem_size = ret ? (uint64_t)value : 0; | ||||
| ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_WEIGHT_SIZE, value); | ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_WEIGHT_SIZE, value); | ||||
| @@ -490,16 +477,18 @@ void DavinciModel::InitRuntimeParams() { | |||||
| runtime_param_.var_size = ret ? (uint64_t)value : 0; | runtime_param_.var_size = ret ? (uint64_t)value : 0; | ||||
| session_id_ = runtime_param_.session_id; | session_id_ = runtime_param_.session_id; | ||||
| ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_P2P_MEMORY_SIZE, value); | ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_P2P_MEMORY_SIZE, value); | ||||
| p2p_mem_info.memory_size = ret ? (uint64_t)value : 0; | |||||
| MemInfo p2p_mem_info; | |||||
| p2p_mem_info.memory_size = static_cast<size_t>(ret ? value : 0); | |||||
| p2p_mem_info.memory_type = RT_MEMORY_P2P_DDR; | |||||
| p2p_mem_info.memory_key = "_p"; | |||||
| runtime_param_.memory_infos[RT_MEMORY_P2P_DDR] = std::move(p2p_mem_info); | runtime_param_.memory_infos[RT_MEMORY_P2P_DDR] = std::move(p2p_mem_info); | ||||
| GELOGI( | |||||
| "InitRuntimeParams(), session_id:%lu, stream_num:%u, event_num:%u, label_num:%u, " | |||||
| "logic_mem_base:0x%lx, logic_weight_base:0x%lx, logic_var_base:0x%lx, " | |||||
| "memory_size:%lu, weight_size:%lu, var_size:%lu", | |||||
| runtime_param_.session_id, runtime_param_.stream_num, runtime_param_.event_num, runtime_param_.label_num, | |||||
| runtime_param_.logic_mem_base, runtime_param_.logic_weight_base, runtime_param_.logic_var_base, | |||||
| runtime_param_.mem_size, runtime_param_.weight_size, runtime_param_.var_size); | |||||
| ret = ge::AttrUtils::GetInt(ge_model_, ATTR_MODEL_SESSION_SCOPE_MEMORY_SIZE, value); | |||||
| MemInfo session_scope_mem_info; | |||||
| session_scope_mem_info.memory_size = static_cast<size_t>(ret ? value : 0); | |||||
| runtime_param_.memory_infos[kSessionScopeMemory | RT_MEMORY_HBM] = std::move(session_scope_mem_info); | |||||
| GELOGI("InitRuntimeParams(), %s.", runtime_param_.ToString().c_str()); | |||||
| } | } | ||||
| void DavinciModel::CheckHasHcomOp(const ComputeGraphPtr &compute_graph) { | void DavinciModel::CheckHasHcomOp(const ComputeGraphPtr &compute_graph) { | ||||
| @@ -4089,14 +4078,15 @@ Status DavinciModel::InitEntryTask() { | |||||
| uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { | uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { | ||||
| uint8_t *mem_base = nullptr; | uint8_t *mem_base = nullptr; | ||||
| const string purpose("feature map,used for op input and output."); | const string purpose("feature map,used for op input and output."); | ||||
| char ge_static_mem_env[MMPA_MAX_PATH] = { 0x00 }; | |||||
| char ge_static_mem_env[MMPA_MAX_PATH] = {0x00}; | |||||
| INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | ||||
| if (res == EN_OK) { | if (res == EN_OK) { | ||||
| data_size = static_cast<size_t>(VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()); | data_size = static_cast<size_t>(VarManager::Instance(session_id_)->GetGraphMemoryMaxSize()); | ||||
| string memory_key = std::to_string(0) + "_f"; | string memory_key = std::to_string(0) + "_f"; | ||||
| mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, data_size, GetDeviceId()); | |||||
| mem_base = | |||||
| MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, data_size, GetDeviceId()); | |||||
| } else { | } else { | ||||
| mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, data_size, GetDeviceId()); | |||||
| mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, data_size, GetDeviceId()); | |||||
| } | } | ||||
| if (mem_base != nullptr) { | if (mem_base != nullptr) { | ||||
| @@ -4105,83 +4095,119 @@ uint8_t *DavinciModel::MallocFeatureMapMem(size_t data_size) { | |||||
| return mem_base; | return mem_base; | ||||
| } | } | ||||
| uint8_t *DavinciModel::MallocP2PMem(size_t p2p_data_size) { | |||||
| uint8_t *p2p_mem_base = nullptr; | |||||
| const string purpose("p2p memory, used for some op related to hcom"); | |||||
| if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { | |||||
| string p2p_memory_key = std::to_string(0) + "_p"; | |||||
| p2p_mem_base = | |||||
| MemManager::Instance(RT_MEMORY_P2P_DDR)->MallocMemory(purpose, p2p_memory_key, p2p_data_size, GetDeviceId()); | |||||
| } else { | |||||
| p2p_mem_base = MemManager::Instance(RT_MEMORY_P2P_DDR)->MallocMemory(purpose, p2p_data_size, GetDeviceId()); | |||||
| Status DavinciModel::MallocExMem() { | |||||
| char ge_static_mem_env[MMPA_MAX_PATH] = {0x00}; | |||||
| INT32 res_static_memory = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | |||||
| for (auto it : runtime_param_.memory_infos) { | |||||
| auto mem_size = it.second.memory_size; | |||||
| if (mem_size == 0) { | |||||
| continue; | |||||
| } | |||||
| bool sessoion_scope = ((kSessionScopeMemory & it.first) == kSessionScopeMemory); | |||||
| auto mem_type = it.first & kMemoryTypeMask; | |||||
| uint8_t *mem_base = nullptr; | |||||
| const string purpose("p2p memory, used for some op related to hcom or session scope memory"); | |||||
| if (sessoion_scope) { | |||||
| mem_base = MemManager::Instance().SessionScopeMemInstance(mem_type).Malloc(mem_size, runtime_param_.session_id); | |||||
| } else if (res_static_memory == EN_OK) { | |||||
| string memory_key = std::to_string(0) + it.second.memory_key; | |||||
| mem_base = | |||||
| MemManager::Instance().MemInstance(mem_type).MallocMemory(purpose, memory_key, mem_size, GetDeviceId()); | |||||
| } else { | |||||
| mem_base = MemManager::Instance().MemInstance(mem_type).MallocMemory(purpose, mem_size, GetDeviceId()); | |||||
| } | |||||
| if (mem_base == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "MallocExMem fail, type:%ld size:%zu, model_id:%u, check invalid", | |||||
| mem_type, mem_size, model_id_); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc ex memory failed, type:%ld size: %zu", mem_type, mem_size); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } | |||||
| it.second.memory_base = mem_base; | |||||
| GELOGI("InitFeatureMapAndP2PMem graph_%u MallocMemory type[F] mem_type[%ld] mem_addr[%p] mem_size[%zu]", | |||||
| runtime_param_.graph_id, mem_type, mem_base, mem_size); | |||||
| } | } | ||||
| return p2p_mem_base; | |||||
| return SUCCESS; | |||||
| } | } | ||||
| uint8_t *DavinciModel::MallocWeightsMem(size_t weights_size) { | uint8_t *DavinciModel::MallocWeightsMem(size_t weights_size) { | ||||
| uint8_t *weights_mem_base = nullptr; | uint8_t *weights_mem_base = nullptr; | ||||
| const string purpose("weights memory in inference network."); | const string purpose("weights memory in inference network."); | ||||
| char ge_static_mem_env[MMPA_MAX_PATH] = { 0x00 }; | |||||
| char ge_static_mem_env[MMPA_MAX_PATH] = {0x00}; | |||||
| INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | ||||
| if (res == EN_OK) { | if (res == EN_OK) { | ||||
| string weight_memory_key = std::to_string(0) + "_w"; | string weight_memory_key = std::to_string(0) + "_w"; | ||||
| weights_mem_base = | |||||
| MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, weight_memory_key, weights_size, GetDeviceId()); | |||||
| weights_mem_base = MemManager::Instance() | |||||
| .MemInstance(RT_MEMORY_HBM) | |||||
| .MallocMemory(purpose, weight_memory_key, weights_size, GetDeviceId()); | |||||
| } else { | } else { | ||||
| weights_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, weights_size, GetDeviceId()); | |||||
| weights_mem_base = | |||||
| MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, weights_size, GetDeviceId()); | |||||
| } | } | ||||
| return weights_mem_base; | return weights_mem_base; | ||||
| } | } | ||||
| void DavinciModel::FreeFeatureMapMem() { | void DavinciModel::FreeFeatureMapMem() { | ||||
| char ge_static_mem_env[MMPA_MAX_PATH] = { 0x00 }; | |||||
| char ge_static_mem_env[MMPA_MAX_PATH] = {0x00}; | |||||
| INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | ||||
| if (res == EN_OK && is_inner_mem_base_) { | if (res == EN_OK && is_inner_mem_base_) { | ||||
| string weight_memory_key = std::to_string(0) + "_f"; | string weight_memory_key = std::to_string(0) + "_f"; | ||||
| if (MemManager::Instance(RT_MEMORY_HBM)->GetMemoryAddr(weight_memory_key) != nullptr) { | |||||
| GE_CHK_STATUS(MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(weight_memory_key, GetDeviceId()), | |||||
| "[Free][Memory] failed, model_id:%u", model_id_); | |||||
| if (MemManager::Instance().MemInstance(RT_MEMORY_HBM).GetMemoryAddr(weight_memory_key) != nullptr) { | |||||
| GE_CHK_STATUS(MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(weight_memory_key, GetDeviceId()), | |||||
| "failed to free weight memory"); | |||||
| } | } | ||||
| mem_base_ = nullptr; | mem_base_ = nullptr; | ||||
| } else { | } else { | ||||
| GE_IF_BOOL_EXEC(mem_base_ != nullptr && is_inner_mem_base_, | |||||
| GE_CHK_STATUS(MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(mem_base_, GetDeviceId()), | |||||
| "[Free][Memory] failed, model_id:%u", model_id_); | |||||
| mem_base_ = nullptr); | |||||
| GE_IF_BOOL_EXEC( | |||||
| mem_base_ != nullptr && is_inner_mem_base_, | |||||
| GE_CHK_STATUS(MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(mem_base_, GetDeviceId()), | |||||
| "failed to free feature_map memory"); | |||||
| mem_base_ = nullptr); | |||||
| } | } | ||||
| } | } | ||||
| void DavinciModel::FreeP2PMem() { | |||||
| if (std::getenv(kEnvGeuseStaticMemory) != nullptr) { | |||||
| std::string p2p_memory_key = std::to_string(0) + "_p"; | |||||
| if (MemManager::Instance(RT_MEMORY_P2P_DDR)->GetMemoryAddr(p2p_memory_key) != nullptr) { | |||||
| GE_CHK_STATUS(MemManager::Instance(RT_MEMORY_P2P_DDR)->FreeMemory(p2p_memory_key, GetDeviceId()), | |||||
| "[Free][Memory] failed, model_id:%u", model_id_); | |||||
| void DavinciModel::FreeExMem() { | |||||
| char ge_static_mem_env[MMPA_MAX_PATH] = {0x00}; | |||||
| INT32 res_static_memory = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | |||||
| for (auto it : runtime_param_.memory_infos) { | |||||
| // free when session destory | |||||
| if ((kSessionScopeMemory & it.first) == kSessionScopeMemory) { | |||||
| continue; | |||||
| } | |||||
| auto mem_type = it.first & kMemoryTypeMask; | |||||
| if (res_static_memory == EN_OK) { | |||||
| std::string memory_key = std::to_string(0) + it.second.memory_key; | |||||
| if (MemManager::Instance().MemInstance(mem_type).GetMemoryAddr(memory_key) != nullptr) { | |||||
| GE_CHK_STATUS(MemManager::Instance().MemInstance(mem_type).FreeMemory(memory_key, GetDeviceId()), | |||||
| "failed to free memory"); | |||||
| } | |||||
| it.second.memory_base = nullptr; | |||||
| } else { | |||||
| GE_IF_BOOL_EXEC( | |||||
| it.second.memory_base != nullptr, | |||||
| GE_CHK_STATUS(MemManager::Instance().MemInstance(mem_type).FreeMemory(it.second.memory_base, GetDeviceId()), | |||||
| "failed to free memory"); | |||||
| it.second.memory_base = nullptr); | |||||
| } | } | ||||
| p2p_mem_base_ = nullptr; | |||||
| } else { | |||||
| GE_IF_BOOL_EXEC(p2p_mem_base_ != nullptr && is_inner_mem_base_, | |||||
| GE_CHK_STATUS(MemManager::Instance(RT_MEMORY_P2P_DDR)->FreeMemory(p2p_mem_base_, GetDeviceId()), | |||||
| "[Free][Memory] failed, model_id:%u", model_id_); | |||||
| p2p_mem_base_ = nullptr); | |||||
| } | } | ||||
| } | } | ||||
| void DavinciModel::FreeWeightsMem() { | void DavinciModel::FreeWeightsMem() { | ||||
| char ge_static_mem_env[MMPA_MAX_PATH] = { 0x00 }; | |||||
| char ge_static_mem_env[MMPA_MAX_PATH] = {0x00}; | |||||
| INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | INT32 res = mmGetEnv(kEnvGeuseStaticMemory, ge_static_mem_env, MMPA_MAX_PATH); | ||||
| if (res == EN_OK) { | if (res == EN_OK) { | ||||
| string memory_key = std::to_string(0) + "_w"; | string memory_key = std::to_string(0) + "_w"; | ||||
| if (MemManager::Instance(RT_MEMORY_HBM)->GetMemoryAddr(memory_key) != nullptr) { | |||||
| GE_CHK_STATUS(MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(memory_key, GetDeviceId()), | |||||
| "[Free][Memory] failed, model_id:%u", model_id_); | |||||
| if (MemManager::Instance().MemInstance(RT_MEMORY_HBM).GetMemoryAddr(memory_key) != nullptr) { | |||||
| GE_CHK_STATUS(MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(memory_key, GetDeviceId()), | |||||
| "failed to free feature_map memory"); | |||||
| } | } | ||||
| weights_mem_base_ = nullptr; | weights_mem_base_ = nullptr; | ||||
| } else { | } else { | ||||
| GE_IF_BOOL_EXEC(weights_mem_base_ != nullptr && weights_mem_base_ != mem_base_ && is_inner_weight_base_, | |||||
| GE_CHK_STATUS(MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(weights_mem_base_, GetDeviceId()), | |||||
| "[Free][Memory] failed, model_id:%u", model_id_); | |||||
| weights_mem_base_ = nullptr); | |||||
| GE_IF_BOOL_EXEC( | |||||
| weights_mem_base_ != nullptr && weights_mem_base_ != mem_base_ && is_inner_weight_base_, | |||||
| GE_CHK_STATUS(MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(weights_mem_base_, GetDeviceId()), | |||||
| "failed to free weight memory"); | |||||
| weights_mem_base_ = nullptr); | |||||
| } | } | ||||
| } | } | ||||
| @@ -248,8 +248,6 @@ class DavinciModel { | |||||
| // get total mem size | // get total mem size | ||||
| size_t TotalMemSize() const { return runtime_param_.mem_size; } | size_t TotalMemSize() const { return runtime_param_.mem_size; } | ||||
| const map<uint32_t, MemInfo> &P2PMemInfos() const { return runtime_param_.memory_infos; } | |||||
| // model name | // model name | ||||
| string Name() const { return name_; } | string Name() const { return name_; } | ||||
| @@ -586,10 +584,8 @@ class DavinciModel { | |||||
| // memory address of model | // memory address of model | ||||
| uintptr_t fixed_mem_base_; // Initial of mem_base_, keep forever. | uintptr_t fixed_mem_base_; // Initial of mem_base_, keep forever. | ||||
| uint8_t *mem_base_; | uint8_t *mem_base_; | ||||
| uint8_t *p2p_mem_base_; | |||||
| bool is_inner_mem_base_; | bool is_inner_mem_base_; | ||||
| bool is_inner_weight_base_; | bool is_inner_weight_base_; | ||||
| bool is_inner_p2p_mem_base_; | |||||
| // input data manager | // input data manager | ||||
| DataInputer *data_inputer_; | DataInputer *data_inputer_; | ||||
| int64_t load_begin_time_; | int64_t load_begin_time_; | ||||
| @@ -668,13 +664,13 @@ class DavinciModel { | |||||
| uint8_t *MallocWeightsMem(size_t weights_size); | uint8_t *MallocWeightsMem(size_t weights_size); | ||||
| uint8_t *MallocP2PMem(size_t p2p_data_size); | |||||
| Status MallocExMem(); | |||||
| void FreeFeatureMapMem(); | void FreeFeatureMapMem(); | ||||
| void FreeWeightsMem(); | void FreeWeightsMem(); | ||||
| void FreeP2PMem(); | |||||
| void FreeExMem(); | |||||
| void ReleaseTask(); | void ReleaseTask(); | ||||
| @@ -310,7 +310,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| std::lock_guard<std::mutex> lock(exeception_infos_mutex_); | std::lock_guard<std::mutex> lock(exeception_infos_mutex_); | ||||
| auto instance = ModelManager::GetInstance(); | auto instance = ModelManager::GetInstance(); | ||||
| if (instance == nullptr) { | if (instance == nullptr) { | ||||
| GELOGE(FAILED, "Instance is nullptr"); | |||||
| GELOGE(FAILED, "[Get][Instance] failed, as ret is nullptr"); | |||||
| return; | return; | ||||
| } | } | ||||
| instance->AddExceptionInfo(*rt_exception_info); | instance->AddExceptionInfo(*rt_exception_info); | ||||
| @@ -21,14 +21,15 @@ | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/types.h" | #include "graph/types.h" | ||||
| #include "graph/build/memory/block_mem_assigner.h" | |||||
| #define VALIDATE_MEM_RANGE(OP, SIZE, OFFSET) \ | #define VALIDATE_MEM_RANGE(OP, SIZE, OFFSET) \ | ||||
| do { \ | do { \ | ||||
| if (SIZE <= static_cast<uint64_t>(OFFSET)) { \ | if (SIZE <= static_cast<uint64_t>(OFFSET)) { \ | ||||
| REPORT_INNER_ERROR("E19999", \ | |||||
| "Node:%s(%s) offset:%ld out of range size:%lu, check invalid", \ | |||||
| REPORT_INNER_ERROR("E19999", "Node:%s(%s) offset:%ld out of range size:%lu, check invalid", \ | |||||
| OP->GetName().c_str(), OP->GetType().c_str(), OFFSET, SIZE); \ | OP->GetName().c_str(), OP->GetType().c_str(), OFFSET, SIZE); \ | ||||
| GELOGE(OUT_OF_MEMORY, "Node: %s, memory out of range[%lu: %ld]", OP->GetName().c_str(), SIZE, OFFSET); \ | |||||
| GELOGE(OUT_OF_MEMORY, "[Check][Param]Node: %s, memory out of range[%lu: %ld]", \ | |||||
| OP->GetName().c_str(), SIZE, OFFSET); \ | |||||
| return {}; \ | return {}; \ | ||||
| } \ | } \ | ||||
| } while (0) | } while (0) | ||||
| @@ -311,8 +312,9 @@ vector<void *> ModelUtils::GetInputDataAddrs(const RuntimeParam &model_param, Co | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s, memory_type.size:%zu != input_desc.size:%zu, op:%s(%s), check invalid", | REPORT_INNER_ERROR("E19999", "Attr:%s, memory_type.size:%zu != input_desc.size:%zu, op:%s(%s), check invalid", | ||||
| ATTR_NAME_INPUT_MEM_TYPE_LIST.c_str(), v_memory_type.size(), inputs_size, | ATTR_NAME_INPUT_MEM_TYPE_LIST.c_str(), v_memory_type.size(), inputs_size, | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| GELOGE(PARAM_INVALID, "Fusion: check input size failed, op: %s, input v_memory_type size: %zu input numbers: %zu", | |||||
| op_desc->GetName().c_str(), v_memory_type.size(), inputs_size); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s, memory_type.size:%zu != input_desc.size:%zu, op:%s(%s)", | |||||
| ATTR_NAME_INPUT_MEM_TYPE_LIST.c_str(), v_memory_type.size(), inputs_size, | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return v_input_data_addr; | return v_input_data_addr; | ||||
| } | } | ||||
| for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | for (size_t i = 0; i < op_desc->GetAllInputsSize(); ++i) { | ||||
| @@ -394,8 +396,7 @@ Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDesc | |||||
| case RT_MEMORY_RDMA_HBM: | case RT_MEMORY_RDMA_HBM: | ||||
| if (offset < 0) { | if (offset < 0) { | ||||
| REPORT_INNER_ERROR("E19999", "Param offset:%ld < 0, check invalid", offset); | REPORT_INNER_ERROR("E19999", "Param offset:%ld < 0, check invalid", offset); | ||||
| GELOGE(PARAM_INVALID, "rdma var addr is invalid, addr=%p", | |||||
| reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(offset))); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Param offset:%ld cannot be negative", offset); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| var_addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(offset)); | var_addr = reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(offset)); | ||||
| @@ -405,9 +406,9 @@ Status ModelUtils::GetVarAddr(const RuntimeParam &model_param, const ConstOpDesc | |||||
| var_addr = model_param.var_base + offset - model_param.logic_var_base; | var_addr = model_param.var_base + offset - model_param.logic_var_base; | ||||
| break; | break; | ||||
| default: | default: | ||||
| REPORT_INNER_ERROR("E19999", "Get mem_type:%d for offset:%ld is unsupported, check invalid", | |||||
| mem_type, offset); | |||||
| GELOGE(PARAM_INVALID, "unsupported memory type %u", mem_type); | |||||
| REPORT_INNER_ERROR("E19999", "Get mem_type:%d for offset:%ld is unsupported, check invalid", mem_type, offset); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Get mem_type:%d for offset:%ld is unsupported, check invalid", | |||||
| mem_type, offset); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(var_addr); | GE_CHECK_NOTNULL(var_addr); | ||||
| @@ -435,9 +436,9 @@ vector<void *> ModelUtils::GetOutputDataAddrs(const RuntimeParam &model_param, C | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s, memory_type.size:%zu != output_desc.size:%zu, op:%s(%s), check invalid", | REPORT_INNER_ERROR("E19999", "Attr:%s, memory_type.size:%zu != output_desc.size:%zu, op:%s(%s), check invalid", | ||||
| ATTR_NAME_OUTPUT_MEM_TYPE_LIST.c_str(), v_memory_type.size(), outputs_size, | ATTR_NAME_OUTPUT_MEM_TYPE_LIST.c_str(), v_memory_type.size(), outputs_size, | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| GELOGE(PARAM_INVALID, | |||||
| "Fusion: check output size failed, op: %s, output v_memory_type size: %lu output numbers: %zu", | |||||
| op_desc->GetName().c_str(), v_memory_type.size(), outputs_size); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s, memory_type.size:%zu != output_desc.size:%zu, op:%s(%s)", | |||||
| ATTR_NAME_OUTPUT_MEM_TYPE_LIST.c_str(), v_memory_type.size(), outputs_size, | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return v_output_data_addr; | return v_output_data_addr; | ||||
| } | } | ||||
| for (size_t i = 0; i < outputs_size; ++i) { | for (size_t i = 0; i < outputs_size; ++i) { | ||||
| @@ -520,10 +521,16 @@ vector<void *> ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param | |||||
| bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, v_memory_type); | bool has_mem_type_attr = ge::AttrUtils::GetListInt(op_desc, TVM_ATTR_NAME_WORKSPACE_TYPE, v_memory_type); | ||||
| bool has_mem_type_workspace = | bool has_mem_type_workspace = | ||||
| ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_TYPE_LIST, workspace_memory_type); | ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_TYPE_LIST, workspace_memory_type); | ||||
| vector<int32_t> workspace_no_reuse_scope; | |||||
| bool has_workspace_no_reuse_scope = | |||||
| ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_MEMORY_NO_REUSE_SCOPE, workspace_no_reuse_scope); | |||||
| for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { | for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { | ||||
| // Temporary solution, the aicpu workspace of multiple images cannot be shared. | // Temporary solution, the aicpu workspace of multiple images cannot be shared. | ||||
| if (has_workspace_reuse && i < workspace_reuse_flag.size() && !workspace_reuse_flag[i] && | |||||
| !model_param.is_single_op) { | |||||
| bool aicpu_work_space = (has_workspace_reuse && i < workspace_reuse_flag.size() && !workspace_reuse_flag[i] && | |||||
| !model_param.is_single_op); | |||||
| if (aicpu_work_space) { | |||||
| void *mem_addr = model_param.aicpu_mem_mall->Acquire(v_workspace_offset[i], v_workspace_bytes[i]); | void *mem_addr = model_param.aicpu_mem_mall->Acquire(v_workspace_offset[i], v_workspace_bytes[i]); | ||||
| v_workspace_data_addr.push_back(mem_addr); | v_workspace_data_addr.push_back(mem_addr); | ||||
| GELOGI( | GELOGI( | ||||
| @@ -554,7 +561,13 @@ vector<void *> ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param | |||||
| model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i]); | model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i]); | ||||
| } else { | } else { | ||||
| VALIDATE_MEM_RANGE(op_desc, model_param.mem_size, v_workspace_offset[i]); | VALIDATE_MEM_RANGE(op_desc, model_param.mem_size, v_workspace_offset[i]); | ||||
| uint8_t *mem_addr = model_param.mem_base + v_workspace_offset[i]; | |||||
| uint8_t *mem_addr = nullptr; | |||||
| bool session_scope_memory = (has_workspace_no_reuse_scope) && (i < workspace_no_reuse_scope.size()); | |||||
| if (session_scope_memory) { | |||||
| mem_addr = model_param.memory_infos.at(kSessionScopeMemory | RT_MEMORY_HBM).memory_base + v_workspace_offset[i]; | |||||
| } else { | |||||
| mem_addr = model_param.mem_base + v_workspace_offset[i]; | |||||
| } | |||||
| v_workspace_data_addr.push_back(mem_addr); | v_workspace_data_addr.push_back(mem_addr); | ||||
| GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", | GELOGI("[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] workspace[%zu] offset[%ld] bytes[%ld] memaddr[%p]", | ||||
| model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i], | model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i], | ||||
| @@ -587,7 +600,7 @@ Status ModelUtils::GetRtAddress(const RuntimeParam ¶m, uintptr_t logic_addr, | |||||
| } else if (logic_addr != 0) { | } else if (logic_addr != 0) { | ||||
| mem_addr = nullptr; | mem_addr = nullptr; | ||||
| REPORT_INNER_ERROR("E19999", "Check param logic addr:0x%lx abnormal", logic_addr); | REPORT_INNER_ERROR("E19999", "Check param logic addr:0x%lx abnormal", logic_addr); | ||||
| GELOGE(PARAM_INVALID, "The logic addr:0x%lx is abnormal", logic_addr); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] The logic addr:0x%lx is abnormal", logic_addr); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -195,7 +195,8 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);) | return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
| SetIoAddrs(op_desc); | SetIoAddrs(op_desc); | ||||
| InitDumpTask(input_output_addr, op_desc); | |||||
| InitDumpFlag(op_desc); | |||||
| InitDumpArgs(input_output_addr, op_desc); | |||||
| GELOGI("KernelExTaskInfo knonw node Init Success."); | GELOGI("KernelExTaskInfo knonw node Init Success."); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -237,7 +238,8 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| GELOGE(RT_FAILED, "[Call][RtMemcpy] failed, ret:0x%X, size:%lu", rt_ret, addrs_size); | GELOGE(RT_FAILED, "[Call][RtMemcpy] failed, ret:0x%X, size:%lu", rt_ret, addrs_size); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret);) | return RT_ERROR_TO_GE_STATUS(rt_ret);) | ||||
| InitDumpTask(input_output_addr_, op_desc); | |||||
| InitDumpFlag(op_desc); | |||||
| InitDumpArgs(input_output_addr_, op_desc); | |||||
| } | } | ||||
| uint64_t input_output_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_output_addr_)); | uint64_t input_output_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(input_output_addr_)); | ||||
| @@ -269,10 +271,16 @@ Status KernelExTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davin | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void KernelExTaskInfo::InitDumpTask(void *addr, const OpDescPtr &op_desc) { | |||||
| if (davinci_model_->OpNeedDump(op_desc->GetName()) || davinci_model_->GetOpDugReg()) { | |||||
| GELOGD("Op %s need dump in kernel ex task info", op_desc->GetName().c_str()); | |||||
| void KernelExTaskInfo::InitDumpFlag(const OpDescPtr &op_desc) { | |||||
| if (davinci_model_->OpNeedDump(op_desc->GetName())) { | |||||
| GELOGD("Op %s need init dump flag in kernel ex task info", op_desc->GetName().c_str()); | |||||
| dump_flag_ = RT_KERNEL_DUMPFLAG; | dump_flag_ = RT_KERNEL_DUMPFLAG; | ||||
| } | |||||
| } | |||||
| void KernelExTaskInfo::InitDumpArgs(void *addr, const OpDescPtr &op_desc) { | |||||
| if (davinci_model_->OpNeedDump(op_desc->GetName())) { | |||||
| GELOGD("Op %s need dump in kernel ex task info", op_desc->GetName().c_str()); | |||||
| dump_args_ = addr; | dump_args_ = addr; | ||||
| } | } | ||||
| if (davinci_model_->GetOpDugReg()) { | if (davinci_model_->GetOpDugReg()) { | ||||
| @@ -61,7 +61,8 @@ class KernelExTaskInfo : public TaskInfo { | |||||
| Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); | Status CopyTaskInfo(const domi::KernelExDef &kernel_def, const RuntimeParam &rts_param, const OpDescPtr &op_desc); | ||||
| void SetIoAddrs(const OpDescPtr &op_desc); | void SetIoAddrs(const OpDescPtr &op_desc); | ||||
| void InitDumpTask(void *addr, const OpDescPtr &op_desc); | |||||
| void InitDumpFlag(const OpDescPtr &op_desc); | |||||
| void InitDumpArgs(void *addr, const OpDescPtr &op_desc); | |||||
| Status InitTaskExtInfo(const std::string &ext_info, const OpDescPtr &op_desc); | Status InitTaskExtInfo(const std::string &ext_info, const OpDescPtr &op_desc); | ||||
| uint32_t task_id_; | uint32_t task_id_; | ||||
| @@ -129,6 +129,7 @@ Status KernelTaskInfo::Init(const domi::TaskDef &task_def, DavinciModel *davinci | |||||
| ctx_.opIndex2[i] = context.origin_op_index(i); | ctx_.opIndex2[i] = context.origin_op_index(i); | ||||
| } | } | ||||
| ctx_.opCount = context.origin_op_index_size(); | ctx_.opCount = context.origin_op_index_size(); | ||||
| InitDumpFlag(); | |||||
| if (kernel_type_ == ccKernelType::TE) { | if (kernel_type_ == ccKernelType::TE) { | ||||
| ctx_.opIndex = context.op_index(); | ctx_.opIndex = context.op_index(); | ||||
| uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())); | uint16_t *args_offset_tmp = reinterpret_cast<uint16_t *>(const_cast<char *>(context.args_offset().data())); | ||||
| @@ -660,7 +661,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
| if (davinci_model_->IsKnownNode()) { | if (davinci_model_->IsKnownNode()) { | ||||
| args_ = l2_buffer_on_ ? davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_) | args_ = l2_buffer_on_ ? davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_) | ||||
| : davinci_model_->GetCurrentArgsAddr(args_offset_); | : davinci_model_->GetCurrentArgsAddr(args_offset_); | ||||
| InitDumpTask(offset); | |||||
| InitDumpArgs(offset); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -726,7 +727,7 @@ Status KernelTaskInfo::InitTVMTask(uint16_t offset, const domi::KernelDef &kerne | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| skt_dump_args_ = static_cast<char *>(args_) + offset; | skt_dump_args_ = static_cast<char *>(args_) + offset; | ||||
| InitDumpTask(offset); | |||||
| InitDumpArgs(offset); | |||||
| vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | vector<void *> virtual_io_addrs; // use virtual address for zero copy key. | ||||
| virtual_io_addrs.insert(virtual_io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | virtual_io_addrs.insert(virtual_io_addrs.end(), input_data_addrs.begin(), input_data_addrs.end()); | ||||
| @@ -1022,7 +1023,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| if (davinci_model_->IsKnownNode()) { | if (davinci_model_->IsKnownNode()) { | ||||
| args_ = davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_); | args_ = davinci_model_->GetCurrentHybridArgsAddr(hybrid_args_offset_); | ||||
| InitDumpTask(sizeof(aicpu::AicpuParamHead)); | |||||
| InitDumpArgs(sizeof(aicpu::AicpuParamHead)); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | const RuntimeParam &rts_param = davinci_model_->GetRuntimeParam(); | ||||
| @@ -1063,7 +1064,7 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), args_size_, rt_ret); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), args_size_, rt_ret); | ||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
| } | } | ||||
| InitDumpTask(sizeof(aicpu::AicpuParamHead)); | |||||
| InitDumpArgs(sizeof(aicpu::AicpuParamHead)); | |||||
| if (kernel_type_ == ccKernelType::CUST_AI_CPU) { | if (kernel_type_ == ccKernelType::CUST_AI_CPU) { | ||||
| dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; | dump_flag_ |= RT_KERNEL_CUSTOM_AICPU; | ||||
| @@ -1074,14 +1075,20 @@ Status KernelTaskInfo::InitAicpuTask(uint32_t op_index, const domi::KernelDef &k | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void KernelTaskInfo::InitDumpTask(uint32_t offset) { | |||||
| void KernelTaskInfo::InitDumpFlag() { | |||||
| if (davinci_model_->OpNeedDump(op_desc_->GetName())) { | if (davinci_model_->OpNeedDump(op_desc_->GetName())) { | ||||
| GELOGD("Op %s need dump in task info", op_desc_->GetName().c_str()); | |||||
| GELOGD("Op %s init dump flag", op_desc_->GetName().c_str()); | |||||
| if (IsL1FusionOp(op_desc_)) { | if (IsL1FusionOp(op_desc_)) { | ||||
| dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; | dump_flag_ = RT_FUSION_KERNEL_DUMPFLAG; | ||||
| } else { | } else { | ||||
| dump_flag_ = RT_KERNEL_DUMPFLAG; | dump_flag_ = RT_KERNEL_DUMPFLAG; | ||||
| } | } | ||||
| } | |||||
| } | |||||
| void KernelTaskInfo::InitDumpArgs(uint32_t offset) { | |||||
| if (davinci_model_->OpNeedDump(op_desc_->GetName())) { | |||||
| GELOGD("Op %s need dump in task info", op_desc_->GetName().c_str()); | |||||
| dump_args_ = static_cast<char *>(args_) + offset; | dump_args_ = static_cast<char *>(args_) + offset; | ||||
| } | } | ||||
| if (davinci_model_->GetOpDugReg()) { | if (davinci_model_->GetOpDugReg()) { | ||||
| @@ -128,7 +128,8 @@ class KernelTaskInfo : public TaskInfo { | |||||
| Status SuperKernelDistribute(); | Status SuperKernelDistribute(); | ||||
| bool IsL1FusionOp(const OpDescPtr &op_desc); | bool IsL1FusionOp(const OpDescPtr &op_desc); | ||||
| void SetIoAddrs(const OpDescPtr &op_desc); | void SetIoAddrs(const OpDescPtr &op_desc); | ||||
| void InitDumpTask(uint32_t offset); | |||||
| void InitDumpFlag(); | |||||
| void InitDumpArgs(uint32_t offset); | |||||
| void SetContinuousArgs(uint32_t args_size, DavinciModel *davinci_model); | void SetContinuousArgs(uint32_t args_size, DavinciModel *davinci_model); | ||||
| void SetNoncontinuousArgs(uint32_t args_size, DavinciModel *davinci_model); | void SetNoncontinuousArgs(uint32_t args_size, DavinciModel *davinci_model); | ||||
| Status CopyNoncontinuousArgs(uint16_t offset); | Status CopyNoncontinuousArgs(uint16_t offset); | ||||
| @@ -26,7 +26,7 @@ Status TaskInfo::SetStream(uint32_t stream_id, const std::vector<rtStream_t> &st | |||||
| stream_ = stream_list[stream_id]; | stream_ = stream_list[stream_id]; | ||||
| } else { | } else { | ||||
| REPORT_INNER_ERROR("E19999", "stream_id:%u >= stream_list.size(): %zu, check invalid", | REPORT_INNER_ERROR("E19999", "stream_id:%u >= stream_list.size(): %zu, check invalid", | ||||
| stream_id, stream_list.size()); | |||||
| stream_id, stream_list.size()); | |||||
| GELOGE(FAILED, "[Check][Param] index:%u >= stream_list.size():%zu.", stream_id, stream_list.size()); | GELOGE(FAILED, "[Check][Param] index:%u >= stream_list.size():%zu.", stream_id, stream_list.size()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -18,6 +18,7 @@ | |||||
| #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_TASK_INFO_H_ | #define GE_GRAPH_LOAD_NEW_MODEL_MANAGER_TASK_INFO_TASK_INFO_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include <sstream> | |||||
| #include "cce/customize.h" | #include "cce/customize.h" | ||||
| #include "framework/common/taskdown_common.h" | #include "framework/common/taskdown_common.h" | ||||
| @@ -28,9 +29,11 @@ | |||||
| namespace ge { | namespace ge { | ||||
| struct MemInfo { | struct MemInfo { | ||||
| uint64_t memory_size = 0; | |||||
| size_t memory_size = 0; | |||||
| uint64_t logic_memory_base = 0; | uint64_t logic_memory_base = 0; | ||||
| uint8_t *memory_base = nullptr; | uint8_t *memory_base = nullptr; | ||||
| uint32_t memory_type = RT_MEMORY_HBM; | |||||
| std::string memory_key = ""; | |||||
| }; | }; | ||||
| struct RuntimeParam { | struct RuntimeParam { | ||||
| @@ -40,6 +43,19 @@ struct RuntimeParam { | |||||
| } | } | ||||
| ~RuntimeParam() = default; | ~RuntimeParam() = default; | ||||
| std::string ToString() { | |||||
| std::stringstream ss; | |||||
| ss << "session_id:" << session_id << ", stream_num:" << stream_num << ", event_num:" << event_num | |||||
| << ", label_num:" << label_num << ", logic_mem_base:" << logic_mem_base | |||||
| << ", logic_weight_base:" << logic_weight_base << ", logic_var_base:" << logic_var_base | |||||
| << ", memory_size:" << mem_size << ", weight_size:" << weight_size << ", var_size:" << var_size | |||||
| << ", ex_memory_info:"; | |||||
| for (auto it : memory_infos) { | |||||
| ss << "[memory_type:" << it.first << ", memory_size:" << it.second.memory_size << "]"; | |||||
| } | |||||
| return ss.str(); | |||||
| } | |||||
| uint64_t mem_size = 0; | uint64_t mem_size = 0; | ||||
| uint64_t logic_mem_base = 0; | uint64_t logic_mem_base = 0; | ||||
| uint8_t *mem_base = nullptr; | uint8_t *mem_base = nullptr; | ||||
| @@ -49,7 +65,7 @@ struct RuntimeParam { | |||||
| uint64_t var_size = 0; | uint64_t var_size = 0; | ||||
| uint64_t logic_var_base = 0; | uint64_t logic_var_base = 0; | ||||
| uint8_t *var_base = nullptr; | uint8_t *var_base = nullptr; | ||||
| std::map<uint32_t, MemInfo> memory_infos; | |||||
| std::map<uint64_t, MemInfo> memory_infos; | |||||
| uint32_t batch_num = 0; | uint32_t batch_num = 0; | ||||
| uint32_t stream_num = 0; | uint32_t stream_num = 0; | ||||
| uint32_t event_num = 0; | uint32_t event_num = 0; | ||||
| @@ -24,7 +24,7 @@ namespace ge { | |||||
| void TbeHandleInfo::used_inc(uint32_t num) { | void TbeHandleInfo::used_inc(uint32_t num) { | ||||
| if (used_ > std::numeric_limits<uint32_t>::max() - num) { | if (used_ > std::numeric_limits<uint32_t>::max() - num) { | ||||
| REPORT_INNER_ERROR("E19999", "Used:%u reach numeric max", used_); | REPORT_INNER_ERROR("E19999", "Used:%u reach numeric max", used_); | ||||
| GELOGE(INTERNAL_ERROR, "Used[%u] reach numeric max.", used_); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Used[%u] reach numeric max.", used_); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -34,7 +34,7 @@ void TbeHandleInfo::used_inc(uint32_t num) { | |||||
| void TbeHandleInfo::used_dec(uint32_t num) { | void TbeHandleInfo::used_dec(uint32_t num) { | ||||
| if (used_ < std::numeric_limits<uint32_t>::min() + num) { | if (used_ < std::numeric_limits<uint32_t>::min() + num) { | ||||
| REPORT_INNER_ERROR("E19999", "Used:%u reach numeric min", used_); | REPORT_INNER_ERROR("E19999", "Used:%u reach numeric min", used_); | ||||
| GELOGE(INTERNAL_ERROR, "Used[%u] reach numeric min.", used_); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Used[%u] reach numeric min.", used_); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -107,9 +107,8 @@ void TBEHandleStore::ReferTBEHandle(const std::string &name) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | std::lock_guard<std::mutex> lock(mutex_); | ||||
| auto it = kernels_.find(name); | auto it = kernels_.find(name); | ||||
| if (it == kernels_.end()) { | if (it == kernels_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Kernel:%s not found in stored check invalid", | |||||
| name.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Kernel[%s] not found in stored.", name.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Kernel:%s not found in stored check invalid", name.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Kernel[%s] not found in stored.", name.c_str()); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -128,9 +127,8 @@ void TBEHandleStore::EraseTBEHandle(const std::map<std::string, uint32_t> &names | |||||
| for (auto &item : names) { | for (auto &item : names) { | ||||
| auto it = kernels_.find(item.first); | auto it = kernels_.find(item.first); | ||||
| if (it == kernels_.end()) { | if (it == kernels_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Kernel:%s not found in stored check invalid", | |||||
| item.first.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Kernel[%s] not found in stored.", item.first.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Kernel:%s not found in stored check invalid", item.first.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Kernel[%s] not found in stored.", item.first.c_str()); | |||||
| continue; | continue; | ||||
| } | } | ||||
| @@ -142,7 +140,8 @@ void TBEHandleStore::EraseTBEHandle(const std::map<std::string, uint32_t> &names | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_INNER_ERROR("E19999", "Call rtDevBinaryUnRegister failed for Kernel:%s fail, ret:0x%X", | REPORT_INNER_ERROR("E19999", "Call rtDevBinaryUnRegister failed for Kernel:%s fail, ret:0x%X", | ||||
| item.first.c_str(), rt_ret); | item.first.c_str(), rt_ret); | ||||
| GELOGE(INTERNAL_ERROR, "Kernel[%s] UnRegister handle fail:%u.", item.first.c_str(), rt_ret); | |||||
| GELOGE(INTERNAL_ERROR, "[Call][RtDevBinaryUnRegister] Kernel[%s] UnRegister handle fail:%u.", | |||||
| item.first.c_str(), rt_ret); | |||||
| } | } | ||||
| kernels_.erase(it); | kernels_.erase(it); | ||||
| } | } | ||||
| @@ -43,7 +43,7 @@ class TsMemMall { | |||||
| for (auto it : mem_store_size_) { | for (auto it : mem_store_size_) { | ||||
| rtError_t ret = rtFree(it.second); | rtError_t ret = rtFree(it.second); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rtFree failed, ret: 0x%X", ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtFree] failed, ret:0x%X", ret); | |||||
| } | } | ||||
| } | } | ||||
| mem_store_size_.clear(); | mem_store_size_.clear(); | ||||
| @@ -52,7 +52,7 @@ class TsMemMall { | |||||
| void *Acquire(int64_t offset, uint64_t size) { | void *Acquire(int64_t offset, uint64_t size) { | ||||
| if (size == 0) { | if (size == 0) { | ||||
| GELOGE(RT_FAILED, "Acquire mem block failed, size: %lu", size); | |||||
| GELOGE(RT_FAILED, "[Check][Param] Acquire mem block failed, size:%lu", size); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -71,7 +71,7 @@ class TsMemMall { | |||||
| void *addr = nullptr; | void *addr = nullptr; | ||||
| rtError_t rt_ret = rtMalloc(&addr, bytes, mem_type_); | rtError_t rt_ret = rtMalloc(&addr, bytes, mem_type_); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rtMalloc failed, ret: 0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtMalloc] failed, size:%lu, ret:0x%X", bytes, rt_ret); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -94,7 +94,7 @@ class TsMemMall { | |||||
| mem_store_addr_.erase(it); | mem_store_addr_.erase(it); | ||||
| rtError_t ret = rtFree(addr); | rtError_t ret = rtFree(addr); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| GELOGE(RT_FAILED, "Call rtFree failed, ret: 0x%X", ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtFree] failed, ret:0x%X", ret); | |||||
| } | } | ||||
| } | } | ||||
| @@ -38,8 +38,13 @@ Status ZeroCopyOffset::InitInputDataInfo(int64_t output_size, void *virtual_addr | |||||
| op_name_ = op_desc->GetName(); | op_name_ = op_desc->GetName(); | ||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | ||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | ||||
| GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, | |||||
| "basic_offset_size should be equal to relative_offset_size"); | |||||
| GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), | |||||
| REPORT_INNER_ERROR("E19999", "basic_offset_size:%zu not equal to relative_offset_size:%zu, " | |||||
| "check invalid", zero_copy_basic_offset_.size(), | |||||
| zero_copy_relative_offset_.size()); | |||||
| return PARAM_INVALID, | |||||
| "[Check][Param] basic_offset_size:%zu should be equal to relative_offset_size:%zu", | |||||
| zero_copy_basic_offset_.size(), zero_copy_relative_offset_.size()); | |||||
| GELOGD("[ZCPY] zero_copy_basic_offset size is %zu", zero_copy_basic_offset_.size()); | GELOGD("[ZCPY] zero_copy_basic_offset size is %zu", zero_copy_basic_offset_.size()); | ||||
| int64_t virtual_addr_offset = op_desc->GetOutputOffset().at(kDataIndex); | int64_t virtual_addr_offset = op_desc->GetOutputOffset().at(kDataIndex); | ||||
| @@ -78,7 +83,8 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector<int64_t> &input_size_list | |||||
| if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, size) != GRAPH_SUCCESS) { | if (TensorUtils::GetTensorSizeInBytes(*tensor_desc, size) != GRAPH_SUCCESS) { | ||||
| REPORT_INNER_ERROR("E19999", "Get input TensorSize in op:%s(%s) failed, input_index:%zu", | REPORT_INNER_ERROR("E19999", "Get input TensorSize in op:%s(%s) failed, input_index:%zu", | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), idx); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), idx); | ||||
| GELOGE(FAILED, "GetTensorSizeInBytes failed!"); | |||||
| GELOGE(FAILED, "[Get][InputTensorSize] in op:%s(%s) failed, input_index:%zu", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), idx); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -88,8 +94,13 @@ Status ZeroCopyOffset::InitOutputDataInfo(const vector<int64_t> &input_size_list | |||||
| op_name_ = op_desc->GetName(); | op_name_ = op_desc->GetName(); | ||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_BASIC_OFFSET, zero_copy_basic_offset_); | ||||
| (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | (void)ge::AttrUtils::GetListInt(op_desc, ATTR_ZERO_COPY_RELATIVE_OFFSET, zero_copy_relative_offset_); | ||||
| GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), return PARAM_INVALID, | |||||
| "basic_offset_size should be equal to relative_offset_size"); | |||||
| GE_CHK_BOOL_EXEC(zero_copy_basic_offset_.size() == zero_copy_relative_offset_.size(), | |||||
| REPORT_INNER_ERROR("E19999", "basic_offset_size:%zu not equal to relative_offset_size:%zu, " | |||||
| "check invalid", | |||||
| zero_copy_basic_offset_.size(), zero_copy_relative_offset_.size()); | |||||
| return PARAM_INVALID, | |||||
| "[Check][Param] basic_offset_size:%zu should be equal to relative_offset_size:%zu", | |||||
| zero_copy_basic_offset_.size(), zero_copy_relative_offset_.size()); | |||||
| int64_t virtual_addr_offset = op_desc->GetInputOffset().at(idx); | int64_t virtual_addr_offset = op_desc->GetInputOffset().at(idx); | ||||
| IsL2Fusion(zero_copy_basic_offset_, virtual_addr_offset, fusion_flag); | IsL2Fusion(zero_copy_basic_offset_, virtual_addr_offset, fusion_flag); | ||||
| @@ -194,7 +205,8 @@ void ZeroCopyOffset::SetOutsideAddrsValue(ZeroCopyTask &zero_copy_task, void *ou | |||||
| for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) { | for (uint32_t out_count = 0; out_count < GetAddrCount(); ++out_count) { | ||||
| auto args_addrs = outside_addrs_[out_count].find(outside_addr); | auto args_addrs = outside_addrs_[out_count].find(outside_addr); | ||||
| if (args_addrs != outside_addrs_[out_count].end()) { | if (args_addrs != outside_addrs_[out_count].end()) { | ||||
| GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), "Input args invalid."); | |||||
| GE_CHK_STATUS(zero_copy_task.SetTaskArgsOffset(addr_val, offset), | |||||
| "[Set][TaskArgsOffset] failed, Input args invalid, offset:%zu.", offset); | |||||
| void *args_val = static_cast<uint8_t *>(args) + offset; | void *args_val = static_cast<uint8_t *>(args) + offset; | ||||
| args_addrs->second.push_back(args_val); | args_addrs->second.push_back(args_val); | ||||
| GELOGD("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val, | GELOGD("[ZCPY] set copy input: virtual_addr: 0x%lx, task_addr: %p, args: %p, offset: %zu.", addr_val, args_val, | ||||
| @@ -36,9 +36,9 @@ ZeroCopyTask::~ZeroCopyTask() { args_addr_ = nullptr; } | |||||
| */ | */ | ||||
| Status ZeroCopyTask::SetTaskArgsOffset(uintptr_t addr, size_t offset) { | Status ZeroCopyTask::SetTaskArgsOffset(uintptr_t addr, size_t offset) { | ||||
| if (offset + sizeof(uintptr_t) > args_size_) { | if (offset + sizeof(uintptr_t) > args_size_) { | ||||
| REPORT_INNER_ERROR("E19999", "Param offset:%zu + 8 > args_size_:%zu, check invalid", | |||||
| offset, args_size_); | |||||
| GELOGE(FAILED, "[ZCPY] %s set task args failed, args size: %zu, offset: %zu", name_.c_str(), args_size_, offset); | |||||
| REPORT_INNER_ERROR("E19999", "Param offset:%zu + 8 > args_size_:%zu, check invalid", offset, args_size_); | |||||
| GELOGE(FAILED, "[Check][Param] [ZCPY] %s set task args failed, args size:%zu, offset:%zu", | |||||
| name_.c_str(), args_size_, offset); | |||||
| return FAILED; // unexpected error, need fix. | return FAILED; // unexpected error, need fix. | ||||
| } | } | ||||
| @@ -118,9 +118,8 @@ Status ZeroCopyTask::DistributeParam(bool async_mode, rtStream_t stream) { | |||||
| } | } | ||||
| if (rt_err != RT_ERROR_NONE) { | if (rt_err != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpyAsync or rtMemcpy failed, size:%zu, ret: 0x%X", | |||||
| args_size_, rt_err); | |||||
| GELOGE(RT_FAILED, "[ZCPY] %s distribute task param failed, error=0x%x", name_.c_str(), rt_err); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpyAsync or rtMemcpy failed, size:%zu, ret:0x%X", args_size_, rt_err); | |||||
| GELOGE(RT_FAILED, "[Distribute][TaskParam] for %s failed, error = 0x%x", name_.c_str(), rt_err); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_err); | return RT_ERROR_TO_GE_STATUS(rt_err); | ||||
| } | } | ||||
| @@ -21,7 +21,7 @@ | |||||
| #include <utility> | #include <utility> | ||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| namespace ge { | namespace ge { | ||||
| const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | const size_t bin_ranges[kNumBins] = {kRoundBlockSize * kKByteSize, | ||||
| @@ -112,12 +112,12 @@ Status CachingAllocator::Initialize(uint32_t device_id) { | |||||
| auto bin_ptr = new (std::nothrow) BlockBin(BlockComparator); | auto bin_ptr = new (std::nothrow) BlockBin(BlockComparator); | ||||
| if (bin_ptr == nullptr) { | if (bin_ptr == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "New BlockBin fail, device_id:%u", device_id); | REPORT_CALL_ERROR("E19999", "New BlockBin fail, device_id:%u", device_id); | ||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc BlockBin failed."); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "[Alloc][BlockBin] failed, device_id:%u", device_id); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | return ACL_ERROR_GE_MEMORY_ALLOCATION; | ||||
| } | } | ||||
| free_block_bins_[i] = bin_ptr; | free_block_bins_[i] = bin_ptr; | ||||
| } | } | ||||
| memory_allocator_ = MemManager::Instance(memory_type_); | |||||
| memory_allocator_ = &MemManager::Instance().MemInstance(memory_type_); | |||||
| if (memory_allocator_ == nullptr) { | if (memory_allocator_ == nullptr) { | ||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | return ACL_ERROR_GE_INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -137,6 +137,7 @@ uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device | |||||
| uint8_t *ptr = nullptr; | uint8_t *ptr = nullptr; | ||||
| Block *block = FindFreeBlock(size, org_ptr, device_id); | Block *block = FindFreeBlock(size, org_ptr, device_id); | ||||
| if (block == nullptr) { | if (block == nullptr) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| if (ge::SUCCESS == TryExtendCache(size, device_id)) { | if (ge::SUCCESS == TryExtendCache(size, device_id)) { | ||||
| block = FindFreeBlock(size, org_ptr, device_id); | block = FindFreeBlock(size, org_ptr, device_id); | ||||
| if (block != nullptr) { | if (block != nullptr) { | ||||
| @@ -147,9 +148,8 @@ uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device | |||||
| ptr = block->ptr; | ptr = block->ptr; | ||||
| } | } | ||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "FindFreeBlock fail, size:%zu, device_id:%u", | |||||
| size, device_id); | |||||
| GELOGE(FAILED, "Malloc failed device id = %u, size= %zu", device_id, size); | |||||
| REPORT_INNER_ERROR("E19999", "FindFreeBlock fail, size:%zu, device_id:%u", size, device_id); | |||||
| GELOGE(FAILED, "[Check][Param] FindFreeBlock failed device id = %u, size= %zu", device_id, size); | |||||
| } | } | ||||
| return ptr; | return ptr; | ||||
| } | } | ||||
| @@ -157,18 +157,16 @@ uint8_t *CachingAllocator::Malloc(size_t size, uint8_t *org_ptr, uint32_t device | |||||
| Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) { | Status CachingAllocator::Free(uint8_t *ptr, uint32_t device_id) { | ||||
| GELOGI("Free device id = %u", device_id); | GELOGI("Free device id = %u", device_id); | ||||
| if (ptr == nullptr) { | if (ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param ptr is nullptr, device_id:%u, check invalid", | |||||
| device_id); | |||||
| GELOGE(PARAM_INVALID, "Invalid memory pointer"); | |||||
| REPORT_INNER_ERROR("E19999", "Param ptr is nullptr, device_id:%u, check invalid", device_id); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer, device_id:%u", device_id); | |||||
| return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
| } | } | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| auto it = allocated_blocks_.find(ptr); | auto it = allocated_blocks_.find(ptr); | ||||
| if (it == allocated_blocks_.end()) { | if (it == allocated_blocks_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Param ptr not allocated before, device_id:%u, check invalid", | |||||
| device_id); | |||||
| GELOGE(PARAM_INVALID, "Invalid memory pointer: %p", ptr); | |||||
| REPORT_INNER_ERROR("E19999", "Param ptr not allocated before, device_id:%u, check invalid", device_id); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Param ptr not allocated before, device_id:%u", device_id); | |||||
| return ge::PARAM_INVALID; | return ge::PARAM_INVALID; | ||||
| } | } | ||||
| Block *block = it->second; | Block *block = it->second; | ||||
| @@ -225,9 +223,8 @@ Block *CachingAllocator::FindFreeBlock(size_t size, uint8_t *org_ptr, uint32_t d | |||||
| Block key(device_id, size, org_ptr); | Block key(device_id, size, org_ptr); | ||||
| BlockBin *bin = GetBlockBin(size); | BlockBin *bin = GetBlockBin(size); | ||||
| if (bin == nullptr) { | if (bin == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "GetBlockBin fail, size:%zu, device_id:%u", | |||||
| size, device_id); | |||||
| GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); | |||||
| REPORT_INNER_ERROR("E19999", "GetBlockBin fail, size:%zu, device_id:%u", size, device_id); | |||||
| GELOGE(ge::FAILED, "[Get][BlockBin] failed, size:%zu, device_id:%u", size, device_id); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| @@ -258,9 +255,8 @@ Block *CachingAllocator::SplitBlock(Block *block, size_t size, BlockBin &bin, ui | |||||
| Block *remaining = block; | Block *remaining = block; | ||||
| Block *new_block = new (std::nothrow) Block(device_id, size, &bin, block->ptr); | Block *new_block = new (std::nothrow) Block(device_id, size, &bin, block->ptr); | ||||
| if (new_block == nullptr) { | if (new_block == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "New Block fail, size:%zu, device_id:%u", | |||||
| size, device_id); | |||||
| GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); | |||||
| REPORT_CALL_ERROR("E19999", "New Block fail, size:%zu, device_id:%u", size, device_id); | |||||
| GELOGE(ge::FAILED, "[Alloc][Block] failed, size:%zu, device_id:%u", size, device_id); | |||||
| return block; | return block; | ||||
| } | } | ||||
| new_block->prev = remaining->prev; | new_block->prev = remaining->prev; | ||||
| @@ -285,7 +281,7 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | |||||
| size_t free_cached_memory_size = FreeCachedBlocks(); | size_t free_cached_memory_size = FreeCachedBlocks(); | ||||
| memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); | memory_addr = memory_allocator_->MallocMemory(purpose, memory_size, device_id); | ||||
| if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
| GELOGE(ge::FAILED, "TryExtendCache failed, no enough memory for size = %zu, device_id = %u", memory_size, | |||||
| GELOGE(ge::FAILED, "[Malloc][Memory] failed, no enough memory for size = %zu, device_id = %u", memory_size, | |||||
| device_id); | device_id); | ||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| @@ -304,16 +300,14 @@ Status CachingAllocator::TryExtendCache(size_t size, uint32_t device_id) { | |||||
| Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id) { | Status CachingAllocator::AddToBlockBin(uint8_t *ptr, size_t size, uint32_t device_id) { | ||||
| BlockBin *bin = GetBlockBin(size); | BlockBin *bin = GetBlockBin(size); | ||||
| if (bin == nullptr) { | if (bin == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "GetBlockBin fail, size:%zu, device_id:%u", | |||||
| size, device_id); | |||||
| GELOGE(ge::FAILED, "Get block bin failed size = %zu", size); | |||||
| REPORT_INNER_ERROR("E19999", "GetBlockBin fail, size:%zu, device_id:%u", size, device_id); | |||||
| GELOGE(ge::FAILED, "[Get][BlockBin] failed, size:%zu, device_id:%u", size, device_id); | |||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| Block *block = new (std::nothrow) Block(device_id, size, bin, nullptr); | Block *block = new (std::nothrow) Block(device_id, size, bin, nullptr); | ||||
| if (block == nullptr) { | if (block == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "New Block fail, size:%zu, device_id:%u", | |||||
| size, device_id); | |||||
| GELOGE(ge::FAILED, "Alloc block failed size = %zu", size); | |||||
| REPORT_CALL_ERROR("E19999", "New Block fail, size:%zu, device_id:%u", size, device_id); | |||||
| GELOGE(ge::FAILED, "[Alloc][Block] failed, size:%zu, device_id:%u", size, device_id); | |||||
| return ge::FAILED; | return ge::FAILED; | ||||
| } | } | ||||
| @@ -88,8 +88,8 @@ class CachingAllocator { | |||||
| /// | /// | ||||
| /// @ingroup ge_graph | /// @ingroup ge_graph | ||||
| /// @brief free memory | /// @brief free memory | ||||
| /// @param [in] memory_ptr memory address ptr | |||||
| /// @param [in] device_id device id | /// @param [in] device_id device id | ||||
| /// @param [out] memory_ptr memory address ptr | |||||
| /// @return Status result of function | /// @return Status result of function | ||||
| /// | /// | ||||
| Status Free(uint8_t *memory_addr, uint32_t device_id = 0); | Status Free(uint8_t *memory_addr, uint32_t device_id = 0); | ||||
| @@ -33,7 +33,7 @@ GraphContext::GraphContext(const GraphNodePtr &graph_node) { | |||||
| if (compute_graph_ == nullptr) { | if (compute_graph_ == nullptr) { | ||||
| std::shared_ptr<const ge::Graph> graph = graph_node->GetGraph(); | std::shared_ptr<const ge::Graph> graph = graph_node->GetGraph(); | ||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "compute_graph by graphNode is NULL!"); | |||||
| GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[Get][Graph] failed, compute_graph by graphNode is NULL!"); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -45,7 +45,7 @@ GraphContext::GraphContext(const GraphNodePtr &graph_node) { | |||||
| Status GraphContext::SetComputeGraph(const GraphNodePtr &graph_node) { | Status GraphContext::SetComputeGraph(const GraphNodePtr &graph_node) { | ||||
| if (graph_node == nullptr) { | if (graph_node == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param graph_node is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param graph_node is nullptr, check invalid"); | ||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "graphNode is NULL!"); | |||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[Check][Param] graphNode is NULL!"); | |||||
| return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
| } | } | ||||
| @@ -56,7 +56,7 @@ Status GraphContext::SetComputeGraph(const GraphNodePtr &graph_node) { | |||||
| std::shared_ptr<const ge::Graph> graph = graph_node->GetGraph(); | std::shared_ptr<const ge::Graph> graph = graph_node->GetGraph(); | ||||
| if (graph == nullptr) { | if (graph == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param graph in graph_node is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param graph in graph_node is nullptr, check invalid"); | ||||
| GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "compute_graph by graphNode is NULL!"); | |||||
| GELOGE(GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL, "[Get][Graph] failed, compute_graph by graphNode is NULL!"); | |||||
| return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | return GE_GRAPH_OPTIMIZE_COMPUTE_GRAPH_NULL; | ||||
| } | } | ||||
| @@ -73,14 +73,15 @@ Status GraphContext::Finalize() const { return SUCCESS; } | |||||
| Status GraphContext::GetVariableTensor(const std::string &var_data_name, GeTensor &returned_tensor) { | Status GraphContext::GetVariableTensor(const std::string &var_data_name, GeTensor &returned_tensor) { | ||||
| if (var_data_name.empty()) { | if (var_data_name.empty()) { | ||||
| REPORT_INNER_ERROR("E19999", "Param var_data_name is empty, check invalid"); | REPORT_INNER_ERROR("E19999", "Param var_data_name is empty, check invalid"); | ||||
| GELOGE(GE_GRAPH_EMPTY_STRING_NAME, "Variable data name is empty!"); | |||||
| GELOGE(GE_GRAPH_EMPTY_STRING_NAME, "[Check][Param] Variable data name is empty!"); | |||||
| return GE_GRAPH_EMPTY_STRING_NAME; | return GE_GRAPH_EMPTY_STRING_NAME; | ||||
| } | } | ||||
| if (GetVarNodeTensorTable().empty()) { | if (GetVarNodeTensorTable().empty()) { | ||||
| REPORT_INNER_ERROR("E19999", "VarNodeTensorTable is empty, var_data_name:%s, check invalid", | REPORT_INNER_ERROR("E19999", "VarNodeTensorTable is empty, var_data_name:%s, check invalid", | ||||
| var_data_name.c_str()); | var_data_name.c_str()); | ||||
| GELOGE(GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE, "VarNodeTensorTable is empty!"); | |||||
| GELOGE(GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE, "[Check][Param] VarNodeTensorTable is empty, var_data_name:%s", | |||||
| var_data_name.c_str()); | |||||
| return GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE; | return GE_GRAPH_EMPTY_VARIABLE_TENSOR_TABLE; | ||||
| } | } | ||||
| for (auto &var_record : GetVarNodeTensorTable()) { | for (auto &var_record : GetVarNodeTensorTable()) { | ||||
| @@ -88,9 +89,8 @@ Status GraphContext::GetVariableTensor(const std::string &var_data_name, GeTenso | |||||
| returned_tensor.SetTensorDesc(var_record.second.GetTensorDesc()); | returned_tensor.SetTensorDesc(var_record.second.GetTensorDesc()); | ||||
| auto ret = returned_tensor.SetData(var_record.second.GetData()); | auto ret = returned_tensor.SetData(var_record.second.GetData()); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_INNER_ERROR("E19999", "SetData to tensor fail, var_data_name:%s", | |||||
| var_data_name.c_str()); | |||||
| GELOGE(ret, "Set Tensor data failed!"); | |||||
| REPORT_INNER_ERROR("E19999", "SetData to tensor fail, var_data_name:%s", var_data_name.c_str()); | |||||
| GELOGE(ret, "[Set][Data] to Tensor failed, var_data_name:%s", var_data_name.c_str()); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -100,7 +100,8 @@ Status GraphContext::GetVariableTensor(const std::string &var_data_name, GeTenso | |||||
| REPORT_INNER_ERROR("E19999", "VarRecord with data_name:%s does not exist, check invalid", | REPORT_INNER_ERROR("E19999", "VarRecord with data_name:%s does not exist, check invalid", | ||||
| var_data_name.c_str()); | var_data_name.c_str()); | ||||
| GELOGE(GE_GRAPH_VARIABLE_DOES_NOT_EXIST, "VarRecord with data_name %s does NOT exist!", var_data_name.c_str()); | |||||
| GELOGE(GE_GRAPH_VARIABLE_DOES_NOT_EXIST, "[Check][Param] VarRecord with data_name %s does NOT exist!", | |||||
| var_data_name.c_str()); | |||||
| return GE_GRAPH_VARIABLE_DOES_NOT_EXIST; | return GE_GRAPH_VARIABLE_DOES_NOT_EXIST; | ||||
| } | } | ||||
| @@ -427,6 +427,8 @@ class GraphManager { | |||||
| void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); | void SetSessionGraphId(ComputeGraphPtr compute_graph, uint32_t graph_id); | ||||
| Status ModifyDataIndex(const Graph &graph, const std::map<std::string, std::string> &graph_option); | |||||
| static Status CheckGraphAdded(const GraphId &graph_id, const Graph &graph); | static Status CheckGraphAdded(const GraphId &graph_id, const Graph &graph); | ||||
| std::atomic_bool thread_run_flag_; | std::atomic_bool thread_run_flag_; | ||||
| @@ -46,7 +46,7 @@ GraphNode::GraphNode(GraphId graph_id) | |||||
| sem_(1) { | sem_(1) { | ||||
| graph_run_async_listener_ = MakeShared<RunAsyncListener>(); | graph_run_async_listener_ = MakeShared<RunAsyncListener>(); | ||||
| if (graph_run_async_listener_ == nullptr) { | if (graph_run_async_listener_ == nullptr) { | ||||
| GELOGE(MEMALLOC_FAILED, "Make shared failed"); | |||||
| GELOGE(MEMALLOC_FAILED, "[New][RunAsyncListener] failed"); | |||||
| } | } | ||||
| } | } | ||||
| @@ -82,7 +82,8 @@ SubGraphInfo::~SubGraphInfo() { | |||||
| rt_ret = rtFreeHost(buffer_addr); | rt_ret = rtFreeHost(buffer_addr); | ||||
| buffer_addr = nullptr; | buffer_addr = nullptr; | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| GELOGE(rt_ret, "[GraphManager] subgraph free buffer failed, modelId = %u", model_id_info_.model_id); | |||||
| GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", | |||||
| model_id_info_.model_id); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -94,8 +95,8 @@ Status SubGraphInfo::FreeInOutBuffer() { | |||||
| rtError_t rt_ret; | rtError_t rt_ret; | ||||
| rt_ret = rtFreeHost(*iter); | rt_ret = rtFreeHost(*iter); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtFreeHost fail"); | |||||
| GELOGE(rt_ret, "[GraphManager] subgraph free buffer failed, modelId = %u", model_id_info_.model_id); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtFreeHost fail, ret:%d", rt_ret); | |||||
| GELOGE(rt_ret, "[Call][RtFreeHost] subgraph free buffer failed, modelId = %u", model_id_info_.model_id); | |||||
| buffer_addr_.erase(buffer_addr_.begin(), iter); | buffer_addr_.erase(buffer_addr_.begin(), iter); | ||||
| return GE_GRAPH_FREE_FAILED; | return GE_GRAPH_FREE_FAILED; | ||||
| } | } | ||||
| @@ -131,7 +132,7 @@ Status GraphModelListener::OnComputeDone(uint32_t model_id, uint32_t task_id, ui | |||||
| uint32_t GraphModelListener::GetResultCode() const { | uint32_t GraphModelListener::GetResultCode() const { | ||||
| if (!is_finished_) { | if (!is_finished_) { | ||||
| REPORT_CALL_ERROR("E19999", "Model not run finish"); | REPORT_CALL_ERROR("E19999", "Model not run finish"); | ||||
| GELOGE(INTERNAL_ERROR, "[GraphManager] model not run finish."); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] model not run finish."); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| return result_code_; | return result_code_; | ||||
| @@ -170,7 +171,9 @@ bool HasCalcOp(const ComputeGraphPtr &graph) { | |||||
| for (const auto &node : graph->GetAllNodes()) { | for (const auto &node : graph->GetAllNodes()) { | ||||
| OpDescPtr op_desc = node->GetOpDesc(); | OpDescPtr op_desc = node->GetOpDesc(); | ||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, GELOGE(FAILED, "Node GetOpDesc is nullptr"); return false); | |||||
| GE_IF_BOOL_EXEC(op_desc == nullptr, | |||||
| REPORT_INNER_ERROR("E19999", "GetOpDesc failed, Node GetOpDesc is nullptr"); | |||||
| GELOGE(FAILED, "[Get][OpDesc] failed, Node GetOpDesc is nullptr"); return false); | |||||
| if (calc_op_type.find(op_desc->GetType()) != calc_op_type.end()) { | if (calc_op_type.find(op_desc->GetType()) != calc_op_type.end()) { | ||||
| return true; | return true; | ||||
| } | } | ||||
| @@ -17,11 +17,9 @@ | |||||
| #include "graph/manager/graph_mem_allocator.h" | #include "graph/manager/graph_mem_allocator.h" | ||||
| #include <string> | #include <string> | ||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/rdma_pool_allocator.h" | |||||
| #include "graph/manager/host_mem_allocator.h" | |||||
| namespace ge { | namespace ge { | ||||
| void MemoryAllocator::Initialize(uint32_t device_id) { | |||||
| Status MemoryAllocator::Initialize(uint32_t device_id) { | |||||
| GELOGI("MemoryAllocator::Initialize"); | GELOGI("MemoryAllocator::Initialize"); | ||||
| // when redo Initialize free memory | // when redo Initialize free memory | ||||
| @@ -31,6 +29,7 @@ void MemoryAllocator::Initialize(uint32_t device_id) { | |||||
| } | } | ||||
| } | } | ||||
| memory_base_map_.clear(); | memory_base_map_.clear(); | ||||
| return SUCCESS; | |||||
| } | } | ||||
| void MemoryAllocator::Finalize(uint32_t device_id) { | void MemoryAllocator::Finalize(uint32_t device_id) { | ||||
| @@ -51,9 +50,7 @@ uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size | |||||
| if (rtMalloc(reinterpret_cast<void **>(&memory_addr), memory_size, memory_type_) != RT_ERROR_NONE) { | if (rtMalloc(reinterpret_cast<void **>(&memory_addr), memory_size, memory_type_) != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, purpose:%s, size:%zu, device_id:%u", | REPORT_CALL_ERROR("E19999", "Call rtMalloc fail, purpose:%s, size:%zu, device_id:%u", | ||||
| purpose.c_str(), memory_size, device_id); | purpose.c_str(), memory_size, device_id); | ||||
| GELOGE(ge::INTERNAL_ERROR, | |||||
| "MemoryAllocator::MallocMemory device_id = %u," | |||||
| " size= %lu", | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Malloc][Memory] failed, device_id = %u, size= %lu", | |||||
| device_id, memory_size); | device_id, memory_size); | ||||
| return nullptr; | return nullptr; | ||||
| @@ -69,7 +66,7 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con | |||||
| auto rtRet = rtFree(memory_addr); | auto rtRet = rtFree(memory_addr); | ||||
| if (rtRet != RT_ERROR_NONE) { | if (rtRet != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtFree fail, device_id:%u", device_id); | REPORT_CALL_ERROR("E19999", "Call rtFree fail, device_id:%u", device_id); | ||||
| GELOGE(rtRet, "MemoryAllocator::MallocMemory device_id = %u", device_id); | |||||
| GELOGE(rtRet, "[Call][RtFree] failed, device_id = %u", device_id); | |||||
| return RT_ERROR_TO_GE_STATUS(rtRet); | return RT_ERROR_TO_GE_STATUS(rtRet); | ||||
| } | } | ||||
| memory_addr = nullptr; | memory_addr = nullptr; | ||||
| @@ -89,10 +86,8 @@ uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memo | |||||
| if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "Malloc Memory fail, purpose:%s, memory_key:%s, memory_size:%zu, device_id:%u", | REPORT_CALL_ERROR("E19999", "Malloc Memory fail, purpose:%s, memory_key:%s, memory_size:%zu, device_id:%u", | ||||
| purpose.c_str(), memory_key.c_str(), memory_size, device_id); | purpose.c_str(), memory_key.c_str(), memory_size, device_id); | ||||
| GELOGE(ge::INTERNAL_ERROR, | |||||
| "MemoryAllocator::MallocMemory failed," | |||||
| " memory_key[%s], size = %lu.", | |||||
| memory_key.c_str(), memory_size); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Malloc][Memory] failed, memory_key[%s], size = %lu, device_id:%u.", | |||||
| memory_key.c_str(), memory_size, device_id); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -127,10 +122,8 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) | |||||
| if (FreeMemory(it->second.memory_addr_, device_id) != ge::SUCCESS) { | if (FreeMemory(it->second.memory_addr_, device_id) != ge::SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Free Memory fail, memory_key:%s, device_id:%u", | REPORT_CALL_ERROR("E19999", "Free Memory fail, memory_key:%s, device_id:%u", | ||||
| memory_key.c_str(), device_id); | memory_key.c_str(), device_id); | ||||
| GELOGE(ge::INTERNAL_ERROR, | |||||
| "MemoryAllocator::FreeMemory rtFree failed," | |||||
| " memory_key[%s]", | |||||
| memory_key.c_str()); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Free][Memory] failed, memory_key[%s], device_id:%u", | |||||
| memory_key.c_str(), device_id); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -152,113 +145,4 @@ uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t devic | |||||
| return it->second.memory_addr_; | return it->second.memory_addr_; | ||||
| } | } | ||||
| MemManager::MemManager() {} | |||||
| MemManager::~MemManager() { Finalize(); } | |||||
| MemManager &MemManager::Instance() { | |||||
| static MemManager mem_manager; | |||||
| return mem_manager; | |||||
| } | |||||
| MemoryAllocator *MemManager::Instance(rtMemType_t memory_type) { return Instance().GetMemoryAllocator(memory_type); } | |||||
| Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| MemoryAllocator *memory_allocator = nullptr; | |||||
| for (unsigned int index : memory_type) { | |||||
| auto it = memory_allocator_map_.find(index); | |||||
| if (it == memory_allocator_map_.end()) { | |||||
| memory_allocator = new (std::nothrow) MemoryAllocator(index); | |||||
| if (memory_allocator != nullptr) { | |||||
| memory_allocator_map_[index] = memory_allocator; | |||||
| GELOGI("Create MemoryAllocator memory type[%u] success.", index); | |||||
| } else { | |||||
| REPORT_CALL_ERROR("E19999", "New MemoryAllocator fail, index:%u", index); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc MemoryAllocator failed."); | |||||
| } | |||||
| } else { | |||||
| memory_allocator = it->second; | |||||
| } | |||||
| if (memory_allocator == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create MemoryAllocator failed."); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } else { | |||||
| memory_allocator->Initialize(0); | |||||
| } | |||||
| } | |||||
| auto ret = InitAllocator(memory_type, caching_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create CachingAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = InitAllocator(memory_type, rdma_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create RdmaAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = InitAllocator(memory_type, host_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create HostMemAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| template <typename T> | |||||
| void FinalizeAllocatorMap(std::map<rtMemType_t, T *> &allocate_map) { | |||||
| for (auto &allocator : allocate_map) { | |||||
| if (allocator.second != nullptr) { | |||||
| allocator.second->Finalize(); | |||||
| delete allocator.second; | |||||
| allocator.second = nullptr; | |||||
| } | |||||
| } | |||||
| allocate_map.clear(); | |||||
| } | |||||
| void MemManager::Finalize() noexcept { | |||||
| GELOGI("Finalize."); | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| // caching and rdma allocator use memory allocator, so finalize them first | |||||
| FinalizeAllocatorMap(caching_allocator_map_); | |||||
| FinalizeAllocatorMap(rdma_allocator_map_); | |||||
| FinalizeAllocatorMap(host_allocator_map_); | |||||
| FinalizeAllocatorMap(memory_allocator_map_); | |||||
| } | |||||
| MemoryAllocator *MemManager::GetMemoryAllocator(rtMemType_t memory_type) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| MemoryAllocator *memory_allocator = nullptr; | |||||
| auto it = memory_allocator_map_.find(memory_type); | |||||
| if (it != memory_allocator_map_.end()) { | |||||
| memory_allocator = it->second; | |||||
| } | |||||
| // Usually impossible | |||||
| if (memory_allocator == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_INTERNAL_ERROR, "GetMemoryAllocator failed, memory type is %u.", memory_type); | |||||
| static MemoryAllocator default_memory_allocator(RT_MEMORY_RESERVED); | |||||
| return &default_memory_allocator; | |||||
| } | |||||
| return memory_allocator; | |||||
| } | |||||
| CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { | |||||
| return Instance().GetAllocator(memory_type, caching_allocator_map_); | |||||
| } | |||||
| RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { | |||||
| return Instance().GetAllocator(memory_type, rdma_allocator_map_); | |||||
| } | |||||
| HostMemAllocator &MemManager::HostMemInstance(rtMemType_t memory_type) { | |||||
| return Instance().GetAllocator(memory_type, host_allocator_map_); | |||||
| } | |||||
| } // namespace ge | } // namespace ge | ||||
| @@ -26,7 +26,6 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "framework/common/ge_inner_error_codes.h" | #include "framework/common/ge_inner_error_codes.h" | ||||
| #include "graph/manager/host_mem_allocator.h" | |||||
| #include "graph/node.h" | #include "graph/node.h" | ||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| @@ -71,9 +70,9 @@ class MemoryAllocator { | |||||
| /// @ingroup ge_graph | /// @ingroup ge_graph | ||||
| /// @brief memory allocator init | /// @brief memory allocator init | ||||
| /// @param [in] options user config params | /// @param [in] options user config params | ||||
| /// @return void | |||||
| /// @return Status of init | |||||
| /// | /// | ||||
| void Initialize(uint32_t device_id = 0); | |||||
| Status Initialize(uint32_t device_id = 0); | |||||
| /// | /// | ||||
| /// @ingroup ge_graph | /// @ingroup ge_graph | ||||
| @@ -136,109 +135,6 @@ class MemoryAllocator { | |||||
| bool mem_malloced_; | bool mem_malloced_; | ||||
| map<string, MemoryInfo> memory_base_map_; | map<string, MemoryInfo> memory_base_map_; | ||||
| }; | }; | ||||
| using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | |||||
| class CachingAllocator; | |||||
| class RdmaPoolAllocator; | |||||
| class MemManager { | |||||
| public: | |||||
| MemManager(); | |||||
| virtual ~MemManager(); | |||||
| static MemManager &Instance(); | |||||
| static MemoryAllocator *Instance(rtMemType_t memory_type); | |||||
| CachingAllocator &CachingInstance(rtMemType_t memory_type); | |||||
| RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); | |||||
| HostMemAllocator &HostMemInstance(rtMemType_t memory_type); | |||||
| MemManager(const MemManager &) = delete; | |||||
| MemManager &operator=(const MemManager &) = delete; | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief memory allocator manager init | |||||
| /// @param [in] options user config params | |||||
| /// @return Status result of function | |||||
| /// | |||||
| Status Initialize(const std::vector<rtMemType_t> &memory_type); | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief memory allocator finalize | |||||
| /// @return void | |||||
| /// | |||||
| void Finalize() noexcept; | |||||
| private: | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief ge memory allocator | |||||
| /// @param [in] memory_type memory type | |||||
| /// @return MemoryAllocator ptr | |||||
| /// | |||||
| MemoryAllocator *GetMemoryAllocator(rtMemType_t memory_type); | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @param [in] memory_type memory type | |||||
| /// @param [in] allocate_map memory allocator map | |||||
| /// @return Status result of function | |||||
| /// | |||||
| template <typename T> | |||||
| Status InitAllocator(const std::vector<rtMemType_t> &memory_type, std::map<rtMemType_t, T *> &allocate_map) { | |||||
| T *allocator = nullptr; | |||||
| for (unsigned int index : memory_type) { | |||||
| auto it = allocate_map.find(index); | |||||
| if (it == allocate_map.end()) { | |||||
| allocator = new (std::nothrow) T(index); | |||||
| if (allocator != nullptr) { | |||||
| allocate_map[index] = allocator; | |||||
| GELOGI("Create Allocator memory type[%u] success.", index); | |||||
| } else { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc Allocator failed."); | |||||
| } | |||||
| } else { | |||||
| allocator = it->second; | |||||
| } | |||||
| if (allocator == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create Allocator failed."); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } else { | |||||
| if (allocator->Initialize() != SUCCESS) { | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @param [in] memory_type memory type | |||||
| /// @param [in] allocate_map memory allocator map | |||||
| /// @return Allocator ptr | |||||
| /// | |||||
| template <typename T> | |||||
| T &GetAllocator(rtMemType_t memory_type, std::map<rtMemType_t, T *> allocate_map) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| T *allocator = nullptr; | |||||
| auto it = allocate_map.find(memory_type); | |||||
| if (it != allocate_map.end()) { | |||||
| allocator = it->second; | |||||
| } | |||||
| // Usually impossible | |||||
| if (allocator == nullptr) { | |||||
| GELOGW("Get allocator failed, memory type is %u.", memory_type); | |||||
| static T default_allocator(RT_MEMORY_RESERVED); | |||||
| return default_allocator; | |||||
| } | |||||
| return *allocator; | |||||
| } | |||||
| std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | |||||
| std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | |||||
| std::map<rtMemType_t, RdmaPoolAllocator *> rdma_allocator_map_; | |||||
| std::map<rtMemType_t, HostMemAllocator *> host_allocator_map_; | |||||
| std::recursive_mutex allocator_mutex_; | |||||
| }; | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_MANAGER_GRAPH_MEM_ALLOCATOR_H_ | #endif // GE_GRAPH_MANAGER_GRAPH_MEM_ALLOCATOR_H_ | ||||
| @@ -0,0 +1,116 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| #include <string> | |||||
| namespace ge { | |||||
| MemManager::MemManager() {} | |||||
| MemManager::~MemManager() { Finalize(); } | |||||
| MemManager &MemManager::Instance() { | |||||
| static MemManager mem_manager; | |||||
| return mem_manager; | |||||
| } | |||||
| Status MemManager::Initialize(const std::vector<rtMemType_t> &memory_type) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| if (init_) { | |||||
| GELOGW("MemManager has been inited."); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto ret = InitAllocator(memory_type, memory_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create MemoryAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = InitAllocator(memory_type, caching_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create CachingAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = InitAllocator(memory_type, rdma_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create RdmaAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = InitAllocator(memory_type, host_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create HostMemAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| ret = InitAllocator(memory_type, session_scope_allocator_map_); | |||||
| if (ret != SUCCESS) { | |||||
| GELOGE(ret, "Create SessionScopeMemAllocator failed."); | |||||
| return ret; | |||||
| } | |||||
| init_ = true; | |||||
| memory_type_ = memory_type; | |||||
| return SUCCESS; | |||||
| } | |||||
| template <typename T> | |||||
| void FinalizeAllocatorMap(std::map<rtMemType_t, T *> &allocate_map) { | |||||
| for (auto &allocator : allocate_map) { | |||||
| if (allocator.second != nullptr) { | |||||
| allocator.second->Finalize(); | |||||
| delete allocator.second; | |||||
| allocator.second = nullptr; | |||||
| } | |||||
| } | |||||
| allocate_map.clear(); | |||||
| } | |||||
| void MemManager::Finalize() noexcept { | |||||
| GELOGI("Finalize."); | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| // caching, rdma and session-scope allocators use memory allocator, so finalize them first | |||||
| FinalizeAllocatorMap(session_scope_allocator_map_); | |||||
| FinalizeAllocatorMap(caching_allocator_map_); | |||||
| FinalizeAllocatorMap(rdma_allocator_map_); | |||||
| FinalizeAllocatorMap(host_allocator_map_); | |||||
| FinalizeAllocatorMap(memory_allocator_map_); | |||||
| init_ = false; | |||||
| memory_type_.clear(); | |||||
| } | |||||
| MemoryAllocator &MemManager::MemInstance(rtMemType_t memory_type) { | |||||
| return GetAllocator(memory_type, memory_allocator_map_); | |||||
| } | |||||
| CachingAllocator &MemManager::CachingInstance(rtMemType_t memory_type) { | |||||
| return GetAllocator(memory_type, caching_allocator_map_); | |||||
| } | |||||
| RdmaPoolAllocator &MemManager::RdmaPoolInstance(rtMemType_t memory_type) { | |||||
| return GetAllocator(memory_type, rdma_allocator_map_); | |||||
| } | |||||
| HostMemAllocator &MemManager::HostMemInstance(rtMemType_t memory_type) { | |||||
| return GetAllocator(memory_type, host_allocator_map_); | |||||
| } | |||||
| SessionScopeMemAllocator &MemManager::SessionScopeMemInstance(rtMemType_t memory_type) { | |||||
| return GetAllocator(memory_type, session_scope_allocator_map_); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,141 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_MANAGER_GRAPH_MEM_MANAGER_H_ | |||||
| #define GE_GRAPH_MANAGER_GRAPH_MEM_MANAGER_H_ | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/host_mem_allocator.h" | |||||
| #include "graph/manager/rdma_pool_allocator.h" | |||||
| // NOTE(review): removed duplicate #include "graph/manager/host_mem_allocator.h" (already included above) | |||||
| #include "graph/manager/session_scope_mem_allocator.h" | |||||
| #include "graph/node.h" | |||||
| #include "runtime/mem.h" | |||||
| namespace ge { | |||||
| using MemoryAllocatorPtr = std::shared_ptr<MemoryAllocator>; | |||||
| class MemManager { | |||||
| public: | |||||
| MemManager(); | |||||
| virtual ~MemManager(); | |||||
| static MemManager &Instance(); | |||||
| MemoryAllocator &MemInstance(rtMemType_t memory_type); | |||||
| CachingAllocator &CachingInstance(rtMemType_t memory_type); | |||||
| RdmaPoolAllocator &RdmaPoolInstance(rtMemType_t memory_type); | |||||
| HostMemAllocator &HostMemInstance(rtMemType_t memory_type); | |||||
| SessionScopeMemAllocator &SessionScopeMemInstance(rtMemType_t memory_type); | |||||
| MemManager(const MemManager &) = delete; | |||||
| MemManager &operator=(const MemManager &) = delete; | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief memory allocator manager init | |||||
| /// @param [in] options user config params | |||||
| /// @return Status result of function | |||||
| /// | |||||
| Status Initialize(const std::vector<rtMemType_t> &memory_type); | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief memory allocator finalize | |||||
| /// @return void | |||||
| /// | |||||
| void Finalize() noexcept; | |||||
| const std::vector<rtMemType_t> &GetAllMemoryType() const { return memory_type_; } | |||||
| private: | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @param [in] memory_type memory type | |||||
| /// @param [in] allocate_map memory allocator map | |||||
| /// @return Status result of function | |||||
| /// | |||||
| template <typename T> | |||||
| Status InitAllocator(const std::vector<rtMemType_t> &memory_type, std::map<rtMemType_t, T *> &allocate_map) { | |||||
| T *allocator = nullptr; | |||||
| for (unsigned int index : memory_type) { | |||||
| auto it = allocate_map.find(index); | |||||
| if (it == allocate_map.end()) { | |||||
| allocator = new (std::nothrow) T(index); | |||||
| if (allocator != nullptr) { | |||||
| allocate_map[index] = allocator; | |||||
| GELOGI("Create Allocator memory type[%u] success.", index); | |||||
| } else { | |||||
| REPORT_CALL_ERROR("E19999", "New Allocator fail, index:%u", index); | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Alloc Allocator failed."); | |||||
| } | |||||
| } else { | |||||
| allocator = it->second; | |||||
| } | |||||
| if (allocator == nullptr) { | |||||
| GELOGE(ACL_ERROR_GE_MEMORY_ALLOCATION, "Create Allocator failed."); | |||||
| return ACL_ERROR_GE_MEMORY_ALLOCATION; | |||||
| } else { | |||||
| if (allocator->Initialize() != SUCCESS) { | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @param [in] memory_type memory type | |||||
| /// @param [in] allocate_map memory allocator map | |||||
| /// @return Allocator ptr | |||||
| /// | |||||
| template <typename T> | |||||
| T &GetAllocator(rtMemType_t memory_type, std::map<rtMemType_t, T *> &allocate_map) { | |||||
| std::lock_guard<std::recursive_mutex> lock(allocator_mutex_); | |||||
| T *allocator = nullptr; | |||||
| auto it = allocate_map.find(memory_type); | |||||
| if (it != allocate_map.end()) { | |||||
| allocator = it->second; | |||||
| } | |||||
| // Usually impossible | |||||
| if (allocator == nullptr) { | |||||
| GELOGW("Get allocator failed, memory type is %u.", memory_type); | |||||
| static T default_allocator(RT_MEMORY_RESERVED); | |||||
| return default_allocator; | |||||
| } | |||||
| return *allocator; | |||||
| } | |||||
| std::map<rtMemType_t, MemoryAllocator *> memory_allocator_map_; | |||||
| std::map<rtMemType_t, CachingAllocator *> caching_allocator_map_; | |||||
| std::map<rtMemType_t, RdmaPoolAllocator *> rdma_allocator_map_; | |||||
| std::map<rtMemType_t, HostMemAllocator *> host_allocator_map_; | |||||
| std::map<rtMemType_t, SessionScopeMemAllocator *> session_scope_allocator_map_; | |||||
| std::recursive_mutex allocator_mutex_; | |||||
| std::vector<rtMemType_t> memory_type_; | |||||
| bool init_ = false; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_MANAGER_GRAPH_MEM_MANAGER_H_ | |||||
| @@ -17,8 +17,7 @@ | |||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/rdma_pool_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| #include "graph/manager/trans_var_data_utils.h" | #include "graph/manager/trans_var_data_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -41,7 +40,8 @@ ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTens | |||||
| if (dev_ptr == nullptr) { | if (dev_ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, " | REPORT_INNER_ERROR("E19999", "Param dev_ptr is nullptr, var_name:%s, session_id:%lu, " | ||||
| "check invalid", var_name.c_str(), session_id_); | "check invalid", var_name.c_str(), session_id_); | ||||
| GELOGE(FAILED, "[GetVarAddr] dev_ptr is null!"); | |||||
| GELOGE(FAILED, "[Check][Param] Param dev_ptr is nullptr, var_name:%s, session_id:%lu", | |||||
| var_name.c_str(), session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::string var_key = VarKey(var_name, tensor_desc); | std::string var_key = VarKey(var_name, tensor_desc); | ||||
| @@ -52,7 +52,8 @@ ge::Status VarResource::GetVarAddr(const std::string &var_name, const ge::GeTens | |||||
| REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, " | REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, " | ||||
| "check invalid", var_key.c_str(), var_name.c_str(), | "check invalid", var_key.c_str(), var_name.c_str(), | ||||
| session_id_); | session_id_); | ||||
| GELOGE(FAILED, "VarResource::GetVarAddr failed, var_key %s", var_key.c_str()); | |||||
| GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu", | |||||
| var_key.c_str(), var_name.c_str(), session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -110,7 +111,8 @@ ge::Status VarResource::SaveVarAddr(const std::string &var_name, const ge::GeTen | |||||
| REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, " | REPORT_INNER_ERROR("E19999", "var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu, " | ||||
| "check invalid", var_key.c_str(), var_name.c_str(), | "check invalid", var_key.c_str(), var_name.c_str(), | ||||
| session_id_); | session_id_); | ||||
| GELOGE(FAILED, "VarResource::SaveVarAddr, var_key %s save addr conflict", var_key.c_str()); | |||||
| GELOGE(FAILED, "[Check][Param] var_key:%s conflict in var_addr_mgr_map_, var_name:%s, session_id:%lu", | |||||
| var_key.c_str(), var_name.c_str(), session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -146,14 +148,15 @@ ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::O | |||||
| if (op_desc == nullptr) { | if (op_desc == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "Param op_desc is nullptr, var_name:%s, session_id:%lu, check invalid", | ||||
| var_name.c_str(), session_id_); | var_name.c_str(), session_id_); | ||||
| GELOGE(FAILED, "[RenewCurVarDesc] renew var desc fail! input opdesc is null!"); | |||||
| GELOGE(FAILED, "[Check][Param] input opdesc is nullptr, var_name:%s, session_id:%lu", | |||||
| var_name.c_str(), session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| ge::GeTensorDesc curr_desc; | ge::GeTensorDesc curr_desc; | ||||
| ge::Status ret = GetCurVarDesc(var_name, curr_desc); | ge::Status ret = GetCurVarDesc(var_name, curr_desc); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(FAILED, "[RenewCurVarDesc] Get var desc fail!"); | |||||
| GELOGE(FAILED, "[Get][CurVarDesc] fail, var_name:%s, session_id:%lu", var_name.c_str(), session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::string key = VarKey(var_name, curr_desc); | std::string key = VarKey(var_name, curr_desc); | ||||
| @@ -165,7 +168,8 @@ ge::Status VarResource::RenewCurVarDesc(const std::string &var_name, const ge::O | |||||
| REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), " | REPORT_INNER_ERROR("E19999", "var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s), " | ||||
| "check invalid", key.c_str(), var_name.c_str(), | "check invalid", key.c_str(), var_name.c_str(), | ||||
| session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| GELOGE(FAILED, "[RenewCurVarDesc] can't find ele with key [%s]", key.c_str()); | |||||
| GELOGE(FAILED, "[Check][Param] var_key:%s can't find in var_addr_mgr_map_, var_name:%s, session_id:%lu, op:%s(%s)", | |||||
| key.c_str(), var_name.c_str(), session_id_, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| auto val = iter->second; | auto val = iter->second; | ||||
| @@ -286,14 +290,15 @@ Status HbmMemResource::AssignVarMem(const std::string &var_name, uint64_t size, | |||||
| if (total_size_ < var_mem_size_) { | if (total_size_ < var_mem_size_) { | ||||
| REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid" | REPORT_INNER_ERROR("E19999", "VarMemMaxSize:%lu < var_mem_size_:%lu, var_size:%lu, var_name:%s, check invalid" | ||||
| "", total_size_, var_mem_size_, size, var_name.c_str()); | "", total_size_, var_mem_size_, size, var_name.c_str()); | ||||
| GELOGE(PARAM_INVALID, "total_size_: %lu is smaller than var_mem_size_: %lu", total_size_, var_mem_size_); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] total_size_:%lu is smaller than var_mem_size_:%lu, var_name:%s", | |||||
| total_size_, var_mem_size_, var_name.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| uint64_t free_size = total_size_ - var_mem_size_; | uint64_t free_size = total_size_ - var_mem_size_; | ||||
| if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) { | if (free_size < (size + kSessionMemAlignSize * kSessionMemAlignUnit)) { | ||||
| REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid", | REPORT_INNER_ERROR("E19999", "free_size:%lu not enough, var_align_size:%lu, var_name:%s, check invalid", | ||||
| free_size, size, var_name.c_str()); | free_size, size, var_name.c_str()); | ||||
| GELOGE(PARAM_INVALID, "Out of memory : current var size[%lu] exceeds total var size[%lu]", | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Out of memory: current var size[%lu] exceeds total var size[%lu]", | |||||
| size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_); | size + kSessionMemAlignSize * kSessionMemAlignUnit + var_mem_size_, total_size_); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -318,7 +323,7 @@ Status RdmaMemResource::AssignVarMem(const std::string &var_name, uint64_t size, | |||||
| if (buffer == nullptr) { | if (buffer == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s", | REPORT_CALL_ERROR("E19999", "malloc rdma memory fail, var_size:%lu, var_name:%s", | ||||
| size, var_name.c_str()); | size, var_name.c_str()); | ||||
| GELOGE(MEMALLOC_FAILED, "Failed to malloc rdma memory for node %s, size = %lu", var_name.c_str(), size); | |||||
| GELOGE(MEMALLOC_FAILED, "[Malloc][RdmaMemory] for node %s failed, size = %lu", var_name.c_str(), size); | |||||
| return MEMALLOC_FAILED; | return MEMALLOC_FAILED; | ||||
| } | } | ||||
| address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer)); | address = static_cast<size_t>(reinterpret_cast<uintptr_t>(buffer)); | ||||
| @@ -469,7 +474,8 @@ int64_t VarManager::GetVarMemSize(rtMemType_t memory_type) { | |||||
| if (mem_resource == nullptr) { | if (mem_resource == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu", | REPORT_INNER_ERROR("E19999", "Find no mem_resource in map, memory_type:%d, session_id:%lu", | ||||
| memory_type, session_id_); | memory_type, session_id_); | ||||
| GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid."); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%d, session_id:%lu", | |||||
| memory_type, session_id_); | |||||
| return 0; | return 0; | ||||
| } | } | ||||
| return mem_resource->GetVarMemSize(); | return mem_resource->GetVarMemSize(); | ||||
| @@ -484,7 +490,8 @@ Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { | |||||
| if (mem_resource == nullptr) { | if (mem_resource == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", | REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", | ||||
| memory_type, session_id_); | memory_type, session_id_); | ||||
| GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu", | |||||
| memory_type, session_id_); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } else { | } else { | ||||
| mem_resource_map_[memory_type] = mem_resource; | mem_resource_map_[memory_type] = mem_resource; | ||||
| @@ -496,7 +503,8 @@ Status VarManager::UpdateVarMemSize(rtMemType_t memory_type, int64_t mem_size) { | |||||
| if (mem_resource == nullptr) { | if (mem_resource == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", | REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", | ||||
| memory_type, session_id_); | memory_type, session_id_); | ||||
| GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid."); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu", | |||||
| memory_type, session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| mem_resource->UpdateVarMemSize(mem_size); | mem_resource->UpdateVarMemSize(mem_size); | ||||
| @@ -516,7 +524,8 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen | |||||
| if (result != ge::SUCCESS) { | if (result != ge::SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu", | REPORT_CALL_ERROR("E19999", "Get size from tensor fail, var_name:%s, memory_type:%d, session_id:%lu", | ||||
| var_name.c_str(), memory_type, session_id_); | var_name.c_str(), memory_type, session_id_); | ||||
| GELOGE(result, "get size from TensorDesc failed"); | |||||
| GELOGE(result, "[Get][Size] from tensor fail, var_name:%s, memory_type:%u, session_id:%lu", | |||||
| var_name.c_str(), memory_type, session_id_); | |||||
| return result; | return result; | ||||
| } | } | ||||
| @@ -527,7 +536,8 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen | |||||
| if (mem_resource == nullptr) { | if (mem_resource == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", | REPORT_CALL_ERROR("E19999", "memory_type:%d invalid or New MemResource fail, session_id:%lu", | ||||
| memory_type, session_id_); | memory_type, session_id_); | ||||
| GELOGE(ge::INTERNAL_ERROR, "Alloc MemResource failed, memory_type = %u.", memory_type); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Alloc][MemResource] failed, memory_type:%u, session_id:%lu.", | |||||
| memory_type, session_id_); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } else { | } else { | ||||
| mem_resource_map_[memory_type] = mem_resource; | mem_resource_map_[memory_type] = mem_resource; | ||||
| @@ -539,7 +549,8 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen | |||||
| if (mem_resource == nullptr) { | if (mem_resource == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", | REPORT_INNER_ERROR("E19999", "MemResource is invalid, memory_type:%d, session_id:%lu", | ||||
| memory_type, session_id_); | memory_type, session_id_); | ||||
| GELOGE(ge::INTERNAL_ERROR, "MemResource is invalid, memory_type = %u.", memory_type); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Check][Param] MemResource is invalid, memory_type:%u, session_id:%lu.", | |||||
| memory_type, session_id_); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -568,14 +579,15 @@ ge::Status VarManager::AssignVarMem(const std::string &var_name, const ge::GeTen | |||||
| if (can_not_reuse_old_memory) { | if (can_not_reuse_old_memory) { | ||||
| result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset); | result = mem_resource->AssignVarMem(var_name, tensor_desc_size, session_id_, mem_offset); | ||||
| if (result != SUCCESS) { | if (result != SUCCESS) { | ||||
| GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed."); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Assign][VarMem] by offset failed, session_id:%lu.", session_id_); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| result = var_resource_->SaveVarAddr( | result = var_resource_->SaveVarAddr( | ||||
| var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type); | var_name, tensor_desc, reinterpret_cast<uint8_t *>(static_cast<uintptr_t>(mem_offset)), memory_type); | ||||
| if (result != SUCCESS) { | if (result != SUCCESS) { | ||||
| GELOGE(ge::INTERNAL_ERROR, "AssignVarMem by offset failed."); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Save][VarAddr] by offset failed, memory type:%u, session_id:%lu.", | |||||
| memory_type, session_id_); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| } | } | ||||
| @@ -682,7 +694,8 @@ ge::Status VarManager::RenewCurVarDesc(const std::string &var_name, ge::OpDescPt | |||||
| REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "VarManager has not been init, op:%s(%s), session_id:%lu, check invalid", | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), | op_desc->GetName().c_str(), op_desc->GetType().c_str(), | ||||
| session_id_); | session_id_); | ||||
| GELOGE(ge::INTERNAL_ERROR, "VarManager has not been init."); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Check][Param] VarManager has not been init, op:%s(%s), session_id:%lu", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), session_id_); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | return var_resource_->RenewCurVarDesc(var_name, std::move(op_desc)); | ||||
| @@ -728,12 +741,10 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) { | |||||
| var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; | var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; | ||||
| const string purpose("variables and constant op memory in training network."); | const string purpose("variables and constant op memory in training network."); | ||||
| var_mem_base = MemManager::Instance(RT_MEMORY_HBM)->MallocMemory(purpose, memory_key, var_memory_size); | |||||
| var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size); | |||||
| if (var_mem_base == nullptr) { | if (var_mem_base == nullptr) { | ||||
| GELOGE(ge::INTERNAL_ERROR, | |||||
| "VarManager::MallocVarMemory failed " | |||||
| "session_id = %s", | |||||
| memory_key.c_str()); | |||||
| GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s", | |||||
| var_memory_size, memory_key.c_str()); | |||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -745,7 +756,7 @@ uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) { | |||||
| return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr(); | return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr(); | ||||
| } | } | ||||
| string memory_key = std::to_string(session_id_); | string memory_key = std::to_string(session_id_); | ||||
| return MemManager::Instance(memory_type)->GetMemoryAddr(memory_key); | |||||
| return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key); | |||||
| } | } | ||||
| uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) { | uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) { | ||||
| @@ -754,7 +765,7 @@ uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_ty | |||||
| return logic_addr; | return logic_addr; | ||||
| } | } | ||||
| string mem_key = std::to_string(session_id_); | string mem_key = std::to_string(session_id_); | ||||
| uint8_t *mem_base = MemManager::Instance(memory_type)->GetMemoryAddr(mem_key); | |||||
| uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key); | |||||
| if (mem_base == nullptr) { | if (mem_base == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -766,7 +777,7 @@ uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_ty | |||||
| ge::Status VarManager::FreeVarMemory() { | ge::Status VarManager::FreeVarMemory() { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| string memory_key = std::to_string(SessionId()); | string memory_key = std::to_string(SessionId()); | ||||
| return MemManager::Instance(RT_MEMORY_HBM)->FreeMemory(memory_key); | |||||
| return MemManager::Instance().MemInstance(RT_MEMORY_HBM).FreeMemory(memory_key); | |||||
| } | } | ||||
| ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | ge::Status VarManager::SetTransRoad(const std::string &var_name, const VarTransRoad &trans_road) { | ||||
| @@ -813,7 +824,7 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
| string graph_memory_manager_malloc_max_size = it->second; | string graph_memory_manager_malloc_max_size = it->second; | ||||
| ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); | ge::Status ret = ParseMemoryMallocSize(graph_memory_manager_malloc_max_size, graph_mem_max_size_); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse graph memory manager malloc max size failed."); | |||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | |||||
| return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); | GELOGI("The max size for graph mem is set to %zu", graph_mem_max_size_); | ||||
| @@ -826,7 +837,7 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
| string memory_var_manager_malloc_size = it->second; | string memory_var_manager_malloc_size = it->second; | ||||
| ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); | ge::Status ret = ParseMemoryMallocSize(memory_var_manager_malloc_size, var_mem_max_size_); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "Parse memory var manager malloc size failed."); | |||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Call][ParseMemoryMallocSize] failed, session id:%lu.", session_id_); | |||||
| return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| @@ -835,8 +846,8 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
| if (var_mem_logic_base_ > kMaxMemorySize) { | if (var_mem_logic_base_ > kMaxMemorySize) { | ||||
| REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "var_login_base:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | ||||
| var_mem_logic_base_, kMaxMemorySize, session_id_); | var_mem_logic_base_, kMaxMemorySize, session_id_); | ||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kMemoryVarLogicBase : %zu can not exceed max memory size : %zu.", | |||||
| var_mem_logic_base_, kMaxMemorySize); | |||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kMemoryVarLogicBase:%zu can not exceed " | |||||
| "max memory size:%zu, session_id:%lu.", var_mem_logic_base_, kMaxMemorySize, session_id_); | |||||
| return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| @@ -844,8 +855,8 @@ Status VarManager::SetMemoryMallocSize(const map<string, string> &options) { | |||||
| if (use_max_mem_size_ > kMaxMemorySize) { | if (use_max_mem_size_ > kMaxMemorySize) { | ||||
| REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "all mem_use size:%zu can not exeed limit:%zu, session_id:%lu, check invalid", | ||||
| use_max_mem_size_, kMaxMemorySize, session_id_); | use_max_mem_size_, kMaxMemorySize, session_id_); | ||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "kUseMaxMemorySize : %zu can not exceed max memory size : %zu.", | |||||
| use_max_mem_size_, kMaxMemorySize); | |||||
| GELOGE(ge::GE_GRAPH_OPTIONS_INVALID, "[Check][Param] kUseMaxMemorySize:%zu can not exceed " | |||||
| "max memory size:%zu, session_id:%lu.", use_max_mem_size_, kMaxMemorySize, session_id_); | |||||
| return ge::GE_GRAPH_OPTIONS_INVALID; | return ge::GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| GELOGI("Set memory malloc size successfully"); | GELOGI("Set memory malloc size successfully"); | ||||
| @@ -856,7 +867,7 @@ Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) { | |||||
| if (memory_size.empty()) { | if (memory_size.empty()) { | ||||
| REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "Param memory_size is empty, session_id:%lu, check invalid", | ||||
| session_id_); | session_id_); | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input is empty."); | |||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "[Check][Param] Memory malloc size input is empty, session_id:%lu.", session_id_); | |||||
| return GE_GRAPH_OPTIONS_INVALID; | return GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| // split string by '*' | // split string by '*' | ||||
| @@ -883,7 +894,9 @@ Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) { | |||||
| if (!isdigit(c)) { | if (!isdigit(c)) { | ||||
| REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid", | REPORT_INNER_ERROR("E19999", "Param memory_size:%s contains non digit, session_id:%lu, check invalid", | ||||
| memory_size.c_str(), session_id_); | memory_size.c_str(), session_id_); | ||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, "Memory malloc size input contains non digit."); | |||||
| GELOGE(GE_GRAPH_OPTIONS_INVALID, | |||||
| "[Check][Param] Memory malloc size:%s input contains non digit, session_id:%lu.", | |||||
| memory_size.c_str(), session_id_); | |||||
| return GE_GRAPH_OPTIONS_INVALID; | return GE_GRAPH_OPTIONS_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| @@ -892,13 +905,15 @@ Status VarManager::ParseMemoryMallocSize(string &memory_size, size_t &result) { | |||||
| REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, " | REPORT_INNER_ERROR("E19999", "Param memory_size:%s will overflow after multi all, session_id:%lu, " | ||||
| "check invalid", memory_size.c_str(), | "check invalid", memory_size.c_str(), | ||||
| session_id_); | session_id_); | ||||
| GELOGE(FAILED, "Input memory size is out of range."); | |||||
| GELOGE(FAILED, "[Check][Param] Param memory_size:%s will overflow after multi all, session_id:%lu", | |||||
| memory_size.c_str(), session_id_); | |||||
| return FAILED); | return FAILED); | ||||
| if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) { | if ((num > kMaxMemorySize) || (result * static_cast<size_t>(num) > kMaxMemorySize)) { | ||||
| REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, " | REPORT_INNER_ERROR("E19999", "Param memory_size:%s after multi will exceed limit:%lu, session_id:%lu, " | ||||
| "check invalid", memory_size.c_str(), kMaxMemorySize, | "check invalid", memory_size.c_str(), kMaxMemorySize, | ||||
| session_id_); | session_id_); | ||||
| GELOGE(FAILED, "Input memory size can not exceed max memory size : %zu.", kMaxMemorySize); | |||||
| GELOGE(FAILED, "[Check][Param] Input memory size can not exceed max memory size:%zu, session_id:%lu.", | |||||
| kMaxMemorySize, session_id_); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| result *= static_cast<size_t>(num); | result *= static_cast<size_t>(num); | ||||
| @@ -1002,10 +1017,7 @@ VarManager *VarManagerPool::GetVarManager(uint64_t session_id) { | |||||
| VarManager *var_manager = new (std::nothrow) VarManager(session_id); | VarManager *var_manager = new (std::nothrow) VarManager(session_id); | ||||
| if (var_manager == nullptr) { | if (var_manager == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id); | REPORT_INNER_ERROR("E19999", "New VarManager fail, session_id:%lu", session_id); | ||||
| GELOGE(INTERNAL_ERROR, | |||||
| "VarManager::Instance find session by " | |||||
| "session_id[%lu] failed.", | |||||
| session_id); | |||||
| GELOGE(INTERNAL_ERROR, "[New][VarManager] fail, session_id:%lu", session_id); | |||||
| static VarManager new_var_manager(0); | static VarManager new_var_manager(0); | ||||
| return &new_var_manager; | return &new_var_manager; | ||||
| } | } | ||||
| @@ -21,7 +21,10 @@ | |||||
| namespace ge { | namespace ge { | ||||
| const void *HostMemAllocator::Malloc(const std::shared_ptr<AlignedPtr> &aligned_ptr, size_t size) { | const void *HostMemAllocator::Malloc(const std::shared_ptr<AlignedPtr> &aligned_ptr, size_t size) { | ||||
| if (aligned_ptr == nullptr) { | if (aligned_ptr == nullptr) { | ||||
| GELOGW("Insert a null aligned_ptr"); | |||||
| GELOGW("Insert a null aligned_ptr, size=%zu", size); | |||||
| if (size == 0) { | |||||
| allocated_blocks_[nullptr] = { size, nullptr }; | |||||
| } | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| GELOGD("allocate existed host memory succ, size=%zu", size); | GELOGD("allocate existed host memory succ, size=%zu", size); | ||||
| @@ -34,8 +37,8 @@ uint8_t *HostMemAllocator::Malloc(size_t size) { | |||||
| std::lock_guard<std::mutex> lock(mutex_); | std::lock_guard<std::mutex> lock(mutex_); | ||||
| std::shared_ptr<AlignedPtr> aligned_ptr = MakeShared<AlignedPtr>(size); | std::shared_ptr<AlignedPtr> aligned_ptr = MakeShared<AlignedPtr>(size); | ||||
| if (aligned_ptr == nullptr) { | if (aligned_ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "New AlignedPtr fail"); | |||||
| GELOGE(INTERNAL_ERROR, "make shared_ptr for AlignedPtr failed"); | |||||
| REPORT_INNER_ERROR("E19999", "New AlignedPtr fail, size:%zu", size); | |||||
| GELOGE(INTERNAL_ERROR, "[Call][MakeShared] for AlignedPtr failed, size:%zu", size); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; | allocated_blocks_[aligned_ptr->Get()] = { size, aligned_ptr }; | ||||
| @@ -46,7 +49,7 @@ uint8_t *HostMemAllocator::Malloc(size_t size) { | |||||
| Status HostMemAllocator::Free(const void *memory_addr) { | Status HostMemAllocator::Free(const void *memory_addr) { | ||||
| if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param memory_addr is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param memory_addr is nullptr, check invalid"); | ||||
| GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); | |||||
| GELOGE(GE_GRAPH_FREE_FAILED, "[Check][Param] Invalid memory pointer"); | |||||
| return GE_GRAPH_FREE_FAILED; | return GE_GRAPH_FREE_FAILED; | ||||
| } | } | ||||
| @@ -54,7 +57,7 @@ Status HostMemAllocator::Free(const void *memory_addr) { | |||||
| auto it = allocated_blocks_.find(memory_addr); | auto it = allocated_blocks_.find(memory_addr); | ||||
| if (it == allocated_blocks_.end()) { | if (it == allocated_blocks_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Memory_addr is not alloc before, check invalid"); | REPORT_INNER_ERROR("E19999", "Memory_addr is not alloc before, check invalid"); | ||||
| GELOGE(PARAM_INVALID, "Invalid memory pointer"); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer:%p", memory_addr); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| it->second.second.reset(); | it->second.second.reset(); | ||||
| @@ -39,9 +39,8 @@ Status SharedMemAllocator::Allocate(SharedMemInfo &mem_info) { | |||||
| rtMallocHostSharedMemoryOut output_para; | rtMallocHostSharedMemoryOut output_para; | ||||
| rtError_t rt_ret = rtMallocHostSharedMemory(&input_para, &output_para); | rtError_t rt_ret = rtMallocHostSharedMemory(&input_para, &output_para); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMallocHostSharedMemory fail, ret:0x%X", | |||||
| rt_ret); | |||||
| GELOGE(RT_FAILED, "Call rt api(rtMallocHostSharedMemory) failed, devid:[%u].", device_id); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMallocHostSharedMemory fail, ret:0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtMallocHostSharedMemory] failed, devid:[%u].", device_id); | |||||
| return GE_GRAPH_MEMORY_ALLOC_FAILED; | return GE_GRAPH_MEMORY_ALLOC_FAILED; | ||||
| } | } | ||||
| mem_info.fd = output_para.fd; | mem_info.fd = output_para.fd; | ||||
| @@ -60,9 +59,8 @@ Status SharedMemAllocator::DeAllocate(SharedMemInfo &mem_info) { | |||||
| mem_info.host_aligned_ptr->MutableGet(), mem_info.device_address}; | mem_info.host_aligned_ptr->MutableGet(), mem_info.device_address}; | ||||
| rtError_t rt_ret = rtFreeHostSharedMemory(&free_para); | rtError_t rt_ret = rtFreeHostSharedMemory(&free_para); | ||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtFreeHostSharedMemory fail, ret:0x%X", | |||||
| rt_ret); | |||||
| GELOGE(RT_FAILED, "Call rt api(rtFreeHostSharedMemory) failed, ret: 0x%X.", rt_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtFreeHostSharedMemory fail, ret:0x%X", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtFreeHostSharedMemory] failed, ret:0x%X.", rt_ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| @@ -78,7 +76,7 @@ Status HostMemManager::Initialize() { | |||||
| allocator_ = std::unique_ptr<SharedMemAllocator>(new (std::nothrow) SharedMemAllocator()); | allocator_ = std::unique_ptr<SharedMemAllocator>(new (std::nothrow) SharedMemAllocator()); | ||||
| if (allocator_ == nullptr) { | if (allocator_ == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "New SharedMemAllocator fail"); | REPORT_CALL_ERROR("E19999", "New SharedMemAllocator fail"); | ||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "Shared memory allocator init failed!"); | |||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "[New][SharedMemAllocator] failed!"); | |||||
| return GE_GRAPH_MALLOC_FAILED; | return GE_GRAPH_MALLOC_FAILED; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -98,9 +96,8 @@ Status HostMemManager::MallocSharedMemory(SharedMemInfo &mem_info) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| auto iter = var_memory_base_map_.find(mem_info.op_name); | auto iter = var_memory_base_map_.find(mem_info.op_name); | ||||
| if (iter != var_memory_base_map_.end()) { | if (iter != var_memory_base_map_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "MemInfo.op_name:%s can't find in var_memory_base_map_", | |||||
| mem_info.op_name.c_str()); | |||||
| GELOGE(FAILED, "Host shared memory for op %s has been malloced", mem_info.op_name.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "Host shared memory for op %s has been malloced", mem_info.op_name.c_str()); | |||||
| GELOGE(FAILED, "[Check][Param] Host shared memory for op %s has been malloced", mem_info.op_name.c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| mem_info.shm_name = OpNameToShmName(mem_info.op_name); | mem_info.shm_name = OpNameToShmName(mem_info.op_name); | ||||
| @@ -113,9 +110,8 @@ Status HostMemManager::MallocSharedMemory(SharedMemInfo &mem_info) { | |||||
| Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size) { | Status HostMemManager::QueryVarMemInfo(const string &op_name, uint64_t &base_addr, uint64_t &data_size) { | ||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| if (var_memory_base_map_.find(op_name) == var_memory_base_map_.end()) { | if (var_memory_base_map_.find(op_name) == var_memory_base_map_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "MemInfo.op_name:%s can't find in var_memory_base_map_", | |||||
| op_name.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "Find host base base_addr failed,node name:%s!", op_name.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "MemInfo.op_name:%s can't find in var_memory_base_map_", op_name.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Find host base base_addr failed, node name:%s!", op_name.c_str()); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address)); | base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(var_memory_base_map_[op_name].device_address)); | ||||
| @@ -19,7 +19,7 @@ | |||||
| #include <memory> | #include <memory> | ||||
| #include "common/ge/plugin_manager.h" | #include "common/ge/plugin_manager.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| #include "graph/manager/host_mem_manager.h" | #include "graph/manager/host_mem_manager.h" | ||||
| #include "graph/manager/rdma_pool_allocator.h" | #include "graph/manager/rdma_pool_allocator.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| @@ -50,9 +50,8 @@ Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t | |||||
| path.append(file_name); | path.append(file_name); | ||||
| string canonical_path = RealPath(path.c_str()); | string canonical_path = RealPath(path.c_str()); | ||||
| if (canonical_path.empty()) { | if (canonical_path.empty()) { | ||||
| REPORT_INNER_ERROR("E19999", "canonical_path:%s is empty, check invalid", | |||||
| canonical_path.c_str()); | |||||
| GELOGE(FAILED, "Failed to get realpath of %s", path.c_str()); | |||||
| REPORT_INNER_ERROR("E19999", "canonical_path:%s is empty, check invalid", canonical_path.c_str()); | |||||
| GELOGE(FAILED, "[Call][RealPath] Failed to get realpath of %s", path.c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GELOGI("FileName:%s, Path:%s.", file_name.c_str(), canonical_path.c_str()); | GELOGI("FileName:%s, Path:%s.", file_name.c_str(), canonical_path.c_str()); | ||||
| @@ -69,15 +68,14 @@ Status RdmaRemoteRegister(const std::vector<HostVarInfo> &var_info, rtMemType_t | |||||
| if (hcom_remote_mem_register == nullptr) { | if (hcom_remote_mem_register == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "Symbol HcomRegRemoteAccessMem can't find in %s, check invalid", | REPORT_CALL_ERROR("E19999", "Symbol HcomRegRemoteAccessMem can't find in %s, check invalid", | ||||
| canonical_path.c_str()); | canonical_path.c_str()); | ||||
| GELOGE(FAILED, "Failed to invoke hcom_remote_mem_register function."); | |||||
| GELOGE(FAILED, "[Check][Param] Symbol HcomRegRemoteAccessMem can't find in %s", canonical_path.c_str()); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| HcclResult hccl_ret = hcom_remote_mem_register(reg_addrs.get(), table_len); | HcclResult hccl_ret = hcom_remote_mem_register(reg_addrs.get(), table_len); | ||||
| if (hccl_ret != HCCL_SUCCESS) { | if (hccl_ret != HCCL_SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Call hcom_remote_mem_register failed, ret:%d,", | |||||
| hccl_ret); | |||||
| GELOGE(HCCL_E_INTERNAL, "Rdma mem register failed, ret: 0x%X", hccl_ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call hcom_remote_mem_register failed, ret:%d,", hccl_ret); | |||||
| GELOGE(HCCL_E_INTERNAL, "[Call][HcomRemoteMemRegister] Rdma mem register failed, ret:0x%X", hccl_ret); | |||||
| return HCCL_E_INTERNAL; | return HCCL_E_INTERNAL; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -88,14 +86,14 @@ Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uin | |||||
| uint32_t type_size = 0; | uint32_t type_size = 0; | ||||
| bool result = TypeUtils::GetDataTypeLength(tensor_info.data_type, type_size); | bool result = TypeUtils::GetDataTypeLength(tensor_info.data_type, type_size); | ||||
| if (!result) { | if (!result) { | ||||
| GELOGE(GRAPH_FAILED, "GetDataTypeLength failed, data_type=(%s).", | |||||
| GELOGE(GRAPH_FAILED, "[Get][DataTypeLength] failed, data_type=(%s).", | |||||
| TypeUtils::DataTypeToSerialString(tensor_info.data_type).c_str()); | TypeUtils::DataTypeToSerialString(tensor_info.data_type).c_str()); | ||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| memory_size = type_size; | memory_size = type_size; | ||||
| for (auto dim : tensor_info.dims) { | for (auto dim : tensor_info.dims) { | ||||
| if (dim <= 0) { | if (dim <= 0) { | ||||
| GELOGE(GRAPH_FAILED, "Tensor dims should be positive"); | |||||
| GELOGE(GRAPH_FAILED, "[Check][Param] Tensor dims should be positive"); | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| memory_size *= dim; | memory_size *= dim; | ||||
| @@ -103,7 +101,7 @@ Status MallocSharedMemory(const TensorInfo &tensor_info, uint64_t &dev_addr, uin | |||||
| SharedMemInfo mem_info(tensor_info.var_name, memory_size); | SharedMemInfo mem_info(tensor_info.var_name, memory_size); | ||||
| Status ret = HostMemManager::Instance().MallocSharedMemory(mem_info); | Status ret = HostMemManager::Instance().MallocSharedMemory(mem_info); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GRAPH_FAILED, "MallocSharedMemory failed op name [%s]", tensor_info.var_name.c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[Malloc][SharedMemory] failed, op name [%s]", tensor_info.var_name.c_str()); | |||||
| return GRAPH_FAILED; | return GRAPH_FAILED; | ||||
| } | } | ||||
| dev_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(mem_info.device_address)); | dev_addr = reinterpret_cast<uint64_t>(reinterpret_cast<uintptr_t>(mem_info.device_address)); | ||||
| @@ -45,7 +45,7 @@ Status EventManager::Init(size_t event_num) { | |||||
| void EventManager::Release() noexcept { | void EventManager::Release() noexcept { | ||||
| for (size_t i = 0; i < this->event_list_.size(); ++i) { | for (size_t i = 0; i < this->event_list_.size(); ++i) { | ||||
| rtError_t rt_ret = rtEventDestroy(this->event_list_[i]); | rtError_t rt_ret = rtEventDestroy(this->event_list_[i]); | ||||
| RETURN_IF_COND_NOT_MET(rt_ret == RT_ERROR_NONE, "Destroy event failed, idx is %zu, ret is 0x%x.", i, rt_ret); | |||||
| RETURN_IF_COND_NOT_MET(rt_ret == RT_ERROR_NONE, "[Destroy][Event] failed, idx is %zu, ret is 0x%x.", i, rt_ret); | |||||
| } | } | ||||
| this->event_list_.clear(); | this->event_list_.clear(); | ||||
| @@ -20,6 +20,7 @@ | |||||
| #include "framework/common/debug/ge_log.h" | #include "framework/common/debug/ge_log.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "runtime/dev.h" | #include "runtime/dev.h" | ||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| namespace { | namespace { | ||||
| const size_t kAlignedSize = 512; | const size_t kAlignedSize = 512; | ||||
| @@ -49,7 +50,7 @@ RdmaPoolAllocator::RdmaPoolAllocator(rtMemType_t memory_type) | |||||
| })) {} | })) {} | ||||
| Status RdmaPoolAllocator::Initialize() { | Status RdmaPoolAllocator::Initialize() { | ||||
| memory_allocator_ = MemManager::Instance(memory_type_); | |||||
| memory_allocator_ = &MemManager::Instance().MemInstance(memory_type_); | |||||
| if (memory_allocator_ == nullptr) { | if (memory_allocator_ == nullptr) { | ||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | return ACL_ERROR_GE_INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -81,8 +82,8 @@ Status RdmaPoolAllocator::InitMemory(size_t mem_size) { | |||||
| auto device_id = GetContext().DeviceId(); | auto device_id = GetContext().DeviceId(); | ||||
| GELOGD("Init Rdma Memory with size [%zu] for devid:[%u]", mem_size, device_id); | GELOGD("Init Rdma Memory with size [%zu] for devid:[%u]", mem_size, device_id); | ||||
| if (rdma_base_addr_ != nullptr) { | if (rdma_base_addr_ != nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param rdma_base_addr_ is nullptr, check invalid"); | |||||
| GELOGE(GE_MULTI_INIT, "Rdma pool has been malloced"); | |||||
| REPORT_INNER_ERROR("E19999", "Param rdma_base_addr_ is not nullptr, devid:%u, check invalid", device_id); | |||||
| GELOGE(GE_MULTI_INIT, "[Check][Param] Rdma pool has been malloced, devid:%u", device_id); | |||||
| return GE_MULTI_INIT; | return GE_MULTI_INIT; | ||||
| } | } | ||||
| const std::string purpose = "Memory for rdma pool."; | const std::string purpose = "Memory for rdma pool."; | ||||
| @@ -94,15 +95,15 @@ Status RdmaPoolAllocator::InitMemory(size_t mem_size) { | |||||
| rdma_base_addr_ = memory_allocator_->MallocMemory(purpose, mem_size, device_id); | rdma_base_addr_ = memory_allocator_->MallocMemory(purpose, mem_size, device_id); | ||||
| if (rdma_base_addr_ == nullptr) { | if (rdma_base_addr_ == nullptr) { | ||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "Rdma pool memory malloc failed"); | |||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "[Malloc][Memory] failed, size:%zu, device_id:%u", mem_size, device_id); | |||||
| return GE_GRAPH_MALLOC_FAILED; | return GE_GRAPH_MALLOC_FAILED; | ||||
| } | } | ||||
| rdma_mem_size_ = mem_size; | rdma_mem_size_ = mem_size; | ||||
| // Init with a base block. | // Init with a base block. | ||||
| auto *base_block = new (std::nothrow) Block(device_id, mem_size, rdma_base_addr_); | auto *base_block = new (std::nothrow) Block(device_id, mem_size, rdma_base_addr_); | ||||
| if (base_block == nullptr) { | if (base_block == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "New Block failed, device_id:%u", device_id); | |||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "Block malloc failed"); | |||||
| REPORT_CALL_ERROR("E19999", "New Block failed, size:%zu, device_id:%u", mem_size, device_id); | |||||
| GELOGE(GE_GRAPH_MALLOC_FAILED, "[New][Block] failed, size:%zu, device_id:%u", mem_size, device_id); | |||||
| return GE_GRAPH_MALLOC_FAILED; | return GE_GRAPH_MALLOC_FAILED; | ||||
| } | } | ||||
| block_bin_.insert(base_block); | block_bin_.insert(base_block); | ||||
| @@ -122,7 +123,7 @@ uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { | |||||
| if (block->ptr == nullptr) { | if (block->ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Rdmapool memory address is nullptr, device_id:%u, check invalid", | REPORT_INNER_ERROR("E19999", "Rdmapool memory address is nullptr, device_id:%u, check invalid", | ||||
| device_id); | device_id); | ||||
| GELOGE(INTERNAL_ERROR, "Rdmapool memory address is nullptr."); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Rdmapool memory address is nullptr, device_id:%u", device_id); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| allocated_blocks_.emplace(block->ptr, block); | allocated_blocks_.emplace(block->ptr, block); | ||||
| @@ -154,9 +155,8 @@ uint8_t *RdmaPoolAllocator::Malloc(size_t size, uint32_t device_id) { | |||||
| Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { | Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { | ||||
| GELOGI("Free rdma memory, device id = %u", device_id); | GELOGI("Free rdma memory, device id = %u", device_id); | ||||
| if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param memory_addr is nullptr, device_id:%u, check invalid", | |||||
| device_id); | |||||
| GELOGE(GE_GRAPH_FREE_FAILED, "Invalid memory pointer"); | |||||
| REPORT_INNER_ERROR("E19999", "Param memory_addr is nullptr, device_id:%u, check invalid", device_id); | |||||
| GELOGE(GE_GRAPH_FREE_FAILED, "[Check][Param] Invalid memory pointer, device id:%u", device_id); | |||||
| return GE_GRAPH_FREE_FAILED; | return GE_GRAPH_FREE_FAILED; | ||||
| } | } | ||||
| @@ -165,7 +165,7 @@ Status RdmaPoolAllocator::Free(uint8_t *memory_addr, uint32_t device_id) { | |||||
| if (it == allocated_blocks_.end()) { | if (it == allocated_blocks_.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Param memory_addr is not allocated before, device_id:%u, " | REPORT_INNER_ERROR("E19999", "Param memory_addr is not allocated before, device_id:%u, " | ||||
| "check invalid", device_id); | "check invalid", device_id); | ||||
| GELOGE(PARAM_INVALID, "Invalid memory pointer"); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Invalid memory pointer, device id:%u", device_id); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -208,7 +208,7 @@ void RdmaPoolAllocator::MergeBlocks(Block *dst, Block *src) { | |||||
| Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | Status RdmaPoolAllocator::GetBaseAddr(uint64_t &base_addr, uint64_t &mem_size) { | ||||
| if (rdma_base_addr_ == nullptr) { | if (rdma_base_addr_ == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param rdma_base_addr_ is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param rdma_base_addr_ is nullptr, check invalid"); | ||||
| GELOGE(INTERNAL_ERROR, "Rdma base addr is nullptr."); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] Rdma base addr is nullptr."); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | base_addr = static_cast<uint64_t>(reinterpret_cast<uintptr_t>(rdma_base_addr_)); | ||||
| @@ -0,0 +1,83 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/manager/session_scope_mem_allocator.h" | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <utility> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| namespace ge { | |||||
| SessionScopeMemAllocator::SessionScopeMemAllocator(rtMemType_t memory_type) | |||||
| : memory_type_(memory_type), memory_allocator_(nullptr) {} | |||||
| Status SessionScopeMemAllocator::Initialize(uint32_t device_id) { | |||||
| GELOGI("Device id %u", device_id); | |||||
| // when redo Initialize free old memory | |||||
| FreeAllMemory(); | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| memory_allocator_ = &MemManager::Instance().MemInstance(memory_type_); | |||||
| if (memory_allocator_ == nullptr) { | |||||
| return ACL_ERROR_GE_INTERNAL_ERROR; | |||||
| } | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| void SessionScopeMemAllocator::Finalize(uint32_t device_id) { | |||||
| GELOGI("Device id %u", device_id); | |||||
| FreeAllMemory(); | |||||
| } | |||||
| uint8_t *SessionScopeMemAllocator::Malloc(size_t size, uint64_t session_id, uint32_t device_id) { | |||||
| GELOGI("Start malloc memory, size:%zu, session id:%lu device id:%u", size, session_id, device_id); | |||||
| const std::string purpose = "Memory for session scope."; | |||||
| auto ptr = memory_allocator_->MallocMemory(purpose, size, device_id); | |||||
| if (ptr == nullptr) { | |||||
| GELOGE(ge::FAILED, "Malloc failed, no enough memory for size:%zu, session_id:%lu device_id:%u", size, | |||||
| session_id, device_id); | |||||
| return nullptr; | |||||
| } | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| std::shared_ptr<uint8_t> mem_ptr(ptr, [&](uint8_t *p) { (void)memory_allocator_->FreeMemory(p); }); | |||||
| allocated_memory_[session_id].emplace_back(size, mem_ptr); | |||||
| return ptr; | |||||
| } | |||||
| Status SessionScopeMemAllocator::Free(uint64_t session_id, uint32_t device_id) { | |||||
| GELOGI("Free session:%lu memory, device id:%u.", session_id, device_id); | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| auto it = allocated_memory_.find(session_id); | |||||
| if (it == allocated_memory_.end()) { | |||||
| GELOGW("Invalid session_id"); | |||||
| return ge::PARAM_INVALID; | |||||
| } | |||||
| allocated_memory_.erase(it); | |||||
| return ge::SUCCESS; | |||||
| } | |||||
| void SessionScopeMemAllocator::FreeAllMemory() { | |||||
| GELOGI("Free all memory"); | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | |||||
| for (auto &session_mem : allocated_memory_) { | |||||
| session_mem.second.clear(); | |||||
| } | |||||
| allocated_memory_.clear(); | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,124 @@ | |||||
| /** | |||||
| * Copyright 2019-2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_MANAGER_SESSION_SCOPE_MEM_ALLOCATOR_H_ | |||||
| #define GE_GRAPH_MANAGER_SESSION_SCOPE_MEM_ALLOCATOR_H_ | |||||
| #include <iostream> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <mutex> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include <unordered_map> | |||||
| #include <functional> | |||||
| #include "framework/common/ge_inner_error_codes.h" | |||||
| #include "graph/node.h" | |||||
| #include "graph/manager/block_memory.h" | |||||
| #include "runtime/mem.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| namespace ge { | |||||
| class SessionScopeMemoryInfo { | |||||
| public: | |||||
| SessionScopeMemoryInfo(size_t size, const std::shared_ptr<uint8_t> &ptr) : size(size), ptr(ptr) {} | |||||
| SessionScopeMemoryInfo() = delete; | |||||
| virtual ~SessionScopeMemoryInfo() = default; | |||||
| SessionScopeMemoryInfo(const SessionScopeMemoryInfo &other) { | |||||
| if (&other == this) { | |||||
| return; | |||||
| } | |||||
| size = other.size; | |||||
| ptr = other.ptr; | |||||
| }; | |||||
| SessionScopeMemoryInfo &operator=(const SessionScopeMemoryInfo &other) { | |||||
| if (&other == this) { | |||||
| return *this; | |||||
| } | |||||
| size = other.size; | |||||
| ptr = other.ptr; | |||||
| return *this; | |||||
| }; | |||||
| private: | |||||
| size_t size = 0; | |||||
| std::shared_ptr<uint8_t> ptr = nullptr; | |||||
| }; | |||||
| class SessionScopeMemAllocator { | |||||
| public: | |||||
| explicit SessionScopeMemAllocator(rtMemType_t memory_type); | |||||
| SessionScopeMemAllocator(const SessionScopeMemAllocator &) = delete; | |||||
| SessionScopeMemAllocator &operator=(const SessionScopeMemAllocator &) = delete; | |||||
| virtual ~SessionScopeMemAllocator() = default; | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief caching allocator init | |||||
| /// @param [in] device id | |||||
| /// @return Status of init | |||||
| /// | |||||
| Status Initialize(uint32_t device_id = 0); | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief memory allocator finalize, release all memory | |||||
| /// @return void | |||||
| /// | |||||
| void Finalize(uint32_t device_id = 0); | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief malloc memory | |||||
| /// @param [in] size memory size | |||||
| /// @param [in] session_id session id | |||||
| /// @param [in] device id | |||||
| /// @return memory address | |||||
| /// | |||||
| uint8_t *Malloc(size_t size, uint64_t session_id, uint32_t device_id = 0); | |||||
| /// | |||||
| /// @ingroup ge_graph | |||||
| /// @brief free memory | |||||
| /// @param [in] session_id session id | |||||
| /// @param [in] device_id device id | |||||
| /// @return Status result of function | |||||
| /// | |||||
| Status Free(uint64_t session_id, uint32_t device_id = 0); | |||||
| private: | |||||
| void FreeAllMemory(); | |||||
| private: | |||||
| rtMemType_t memory_type_; | |||||
| // device memory allocator | |||||
| MemoryAllocator *memory_allocator_; | |||||
| // lock around all operations | |||||
| mutable std::recursive_mutex mutex_; | |||||
| // allocated blocks by memory pointer | |||||
| std::unordered_map<uint64_t, std::vector<SessionScopeMemoryInfo>> allocated_memory_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_MANAGER_SESSION_SCOPE_MEM_ALLOCATOR_H_ | |||||
| @@ -37,7 +37,8 @@ class RtContextSwitchGuard { | |||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxGetCurrent failed, device_id:%u, ret:0x%X,", | REPORT_CALL_ERROR("E19999", "Call rtCtxGetCurrent failed, device_id:%u, ret:0x%X,", | ||||
| device_id, ret); | device_id, ret); | ||||
| GELOGE(RT_FAILED, "Failed to get current context from rt, error-code %d", ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtCtxGetCurrent] Failed to get current context, device_id:%u, ret:0x%X", | |||||
| device_id, ret); | |||||
| return; | return; | ||||
| } | } | ||||
| @@ -45,15 +46,14 @@ class RtContextSwitchGuard { | |||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, device_id:%u, ret:0x%X,", | REPORT_CALL_ERROR("E19999", "Call rtCtxCreate failed, device_id:%u, ret:0x%X,", | ||||
| device_id, ret); | device_id, ret); | ||||
| GELOGE(RT_FAILED, "Failed to create new context for device %u, error-code %d", device_id, ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtCtxCreate] Failed to create new context for device:%u, ret:%d", device_id, ret); | |||||
| return; | return; | ||||
| } | } | ||||
| ret = rtCtxSetCurrent(current_); | ret = rtCtxSetCurrent(current_); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, device_id:%u, ret:0x%X,", | |||||
| device_id, ret); | |||||
| GELOGE(RT_FAILED, "Failed to switch context to normal, context %p, device %u", current_, device_id); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, device_id:%u, ret:0x%X", device_id, ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtCtxSetCurrent] failed, device_id:%u, ret:0x%X", device_id, ret); | |||||
| return; | return; | ||||
| } | } | ||||
| GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_); | GELOGD("Create and switch rt context %p type %d for device %u, backup last %p.", current_, mode, device_id, last_); | ||||
| @@ -80,7 +80,7 @@ int64_t CalcVarSizeInBytes(const GeTensorDesc &desc) { | |||||
| if (var_size <= 0) { | if (var_size <= 0) { | ||||
| REPORT_INNER_ERROR("E19999", "Data type:%s in desc, it's size:%ld < 0, check invalid", | REPORT_INNER_ERROR("E19999", "Data type:%s in desc, it's size:%ld < 0, check invalid", | ||||
| TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str(), var_size); | TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str(), var_size); | ||||
| GELOGE(PARAM_INVALID, "Failed to calc var data size from data type %s", | |||||
| GELOGE(PARAM_INVALID, "[Calc][VarDataSize] by data type %s failed.", | |||||
| TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str()); | TypeUtils::DataTypeToSerialString(desc.GetDataType()).c_str()); | ||||
| return -1; | return -1; | ||||
| } | } | ||||
| @@ -99,7 +99,8 @@ Status CopyVarToDevice(const NodePtr &var, const formats::TransResult &trans_res | |||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, op:%s(%s), size:%lu, ret:0x%X,", var->GetName().c_str(), | REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, op:%s(%s), size:%lu, ret:0x%X,", var->GetName().c_str(), | ||||
| var->GetType().c_str(), trans_result.length, ret); | var->GetType().c_str(), trans_result.length, ret); | ||||
| GELOGE(RT_FAILED, "Failed to copy memory to device, size %zu", trans_result.length); | |||||
| GELOGE(RT_FAILED, "[Call][RtMemcpy] failed, op:%s(%s), size:%lu, ret:0x%X,", var->GetName().c_str(), | |||||
| var->GetType().c_str(), trans_result.length, ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -111,21 +112,17 @@ Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_pt | |||||
| GE_CHECK_NOTNULL(var); | GE_CHECK_NOTNULL(var); | ||||
| auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic); | auto ret = VarManager::Instance(session_id)->GetVarAddr(var->GetName(), input_desc, &var_logic); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Failed to copy var %s from device, can not find it" | |||||
| " from var manager %u", | |||||
| var->GetName().c_str(), ret); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][VarAddr] failed, node:%s, session_id:%lu, ret:%d", | |||||
| var->GetName().c_str(), session_id, ret); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM); | uint8_t *var_addr = VarManager::Instance(session_id)->GetVarMemoryAddr(var_logic, RT_MEMORY_HBM); | ||||
| if (var_addr == nullptr) { | if (var_addr == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, op:%s(%s), session_id:%lu,", | |||||
| REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, op:%s(%s), session_id:%lu", | |||||
| RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id); | RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id); | ||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Failed to copy var %s from device, cant not get " | |||||
| "var addr from logic addr %p", | |||||
| var->GetName().c_str(), var_logic); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][VarMemoryAddr] failed, mem_type:%d, op:%s(%s), session_id:%lu", | |||||
| RT_MEMORY_HBM, var->GetName().c_str(), var->GetType().c_str(), session_id); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -136,9 +133,10 @@ Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_pt | |||||
| std::unique_ptr<uint8_t[]> var_host(new(std::nothrow) uint8_t[var_size_bytes]); | std::unique_ptr<uint8_t[]> var_host(new(std::nothrow) uint8_t[var_size_bytes]); | ||||
| if (var_host == nullptr) { | if (var_host == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "New host memory failed, size:%ld, op:%s(%s), session_id:%lu,", | |||||
| REPORT_CALL_ERROR("E19999", "New host memory failed, size:%ld, op:%s(%s), session_id:%lu", | |||||
| var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id); | var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id); | ||||
| GELOGE(OUT_OF_MEMORY, "Failed to malloc rt-host memory, size %ld", var_size_bytes); | |||||
| GELOGE(OUT_OF_MEMORY, "[New][Memory] for rt-host failed, size:%ld, op:%s(%s), session_id:%lu", | |||||
| var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id); | |||||
| return OUT_OF_MEMORY; | return OUT_OF_MEMORY; | ||||
| } | } | ||||
| @@ -147,10 +145,8 @@ Status CopyVarFromDevice(uint64_t session_id, const NodePtr &var, std::unique_pt | |||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X", | REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X", | ||||
| var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret); | var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret); | ||||
| GELOGE(RT_FAILED, | |||||
| "Failed to copy var memory from device, var %s, size %ld," | |||||
| " rt-error-code %u", | |||||
| var->GetName().c_str(), var_size_bytes, ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtMemcpy] failed, size:%ld, op:%s(%s), session_id:%lu, ret:0x%X", | |||||
| var_size_bytes, var->GetName().c_str(), var->GetType().c_str(), session_id, ret); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| @@ -197,9 +193,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats | |||||
| formats::ShapeToString(src_shape).c_str(), | formats::ShapeToString(src_shape).c_str(), | ||||
| formats::ShapeToString(dst_shape).c_str(), | formats::ShapeToString(dst_shape).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); | TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); | ||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Failed to trans format from %s to %s, shape %s to %s, " | |||||
| "data type %s error code %u", | |||||
| GELOGE(INTERNAL_ERROR, "[Trans][Format] from %s to %s, shape %s to %s failed, data type %s error code %u", | |||||
| TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | TypeUtils::FormatToSerialString(src_format).c_str(), TypeUtils::FormatToSerialString(dst_format).c_str(), | ||||
| formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(), | formats::ShapeToString(src_shape).c_str(), formats::ShapeToString(dst_shape).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); | TypeUtils::DataTypeToSerialString(data_type).c_str(), ret); | ||||
| @@ -221,7 +215,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats | |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str(), | TypeUtils::DataTypeToSerialString(src_data_type).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), | TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), | ||||
| formats::ShapeToString(input_shape).c_str(), src_data_size, ret); | formats::ShapeToString(input_shape).c_str(), src_data_size, ret); | ||||
| GELOGE(INTERNAL_ERROR, "Failed to trans data type from %s to %s, input shape %s, data size %ld, error code %u", | |||||
| GELOGE(INTERNAL_ERROR, "[Trans][DataType] from %s to %s failed, input shape %s, data size %ld, error code %u", | |||||
| TypeUtils::DataTypeToSerialString(src_data_type).c_str(), | TypeUtils::DataTypeToSerialString(src_data_type).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), | TypeUtils::DataTypeToSerialString(dst_data_type).c_str(), formats::ShapeToString(input_shape).c_str(), | ||||
| src_data_size, ret); | src_data_size, ret); | ||||
| @@ -230,7 +224,7 @@ Status TransVarOnHost(uint8_t *var_data, const VarTransRoad &trans_road, formats | |||||
| } else { | } else { | ||||
| REPORT_INNER_ERROR("E19999", "Trans var data failed, the trans type %s does not supported, check invalid", | REPORT_INNER_ERROR("E19999", "Trans var data failed, the trans type %s does not supported, check invalid", | ||||
| trans_info.node_type.c_str()); | trans_info.node_type.c_str()); | ||||
| GELOGE(UNSUPPORTED, "Failed to trans var data, the trans type %s does not supported", | |||||
| GELOGE(UNSUPPORTED, "[Trans][VarData] failed, the trans type %s does not supported", | |||||
| trans_info.node_type.c_str()); | trans_info.node_type.c_str()); | ||||
| return UNSUPPORTED; | return UNSUPPORTED; | ||||
| } | } | ||||
| @@ -255,10 +249,8 @@ Status ReAssignVarAddr(uint64_t session_id, | |||||
| uint8_t *var_logic = nullptr; | uint8_t *var_logic = nullptr; | ||||
| Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic); | Status ret = VarManager::Instance(session_id)->GetVarAddr(var_name, tensor_desc, &var_logic); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, | |||||
| "Failed to get var %s device addr, can not find it" | |||||
| " from var manager %u", | |||||
| var_name.c_str(), ret); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][VarAddr] failed, var name:%s, session_id:%lu, ret:%u", | |||||
| var_name.c_str(), session_id, ret); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| @@ -266,7 +258,8 @@ Status ReAssignVarAddr(uint64_t session_id, | |||||
| if (var_addr == nullptr) { | if (var_addr == nullptr) { | ||||
| REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, var_name:%s, session_id:%lu,", | REPORT_CALL_ERROR("E19999", "Get variable memory addr failed, mem_type:%d, var_name:%s, session_id:%lu,", | ||||
| RT_MEMORY_HBM, var_name.c_str(), session_id); | RT_MEMORY_HBM, var_name.c_str(), session_id); | ||||
| GELOGE(INTERNAL_ERROR, "Failed to convert var %s logic addr to real addr", var_name.c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][VarMemoryAddr] failed, mem_type:%d, var_name:%s, session_id:%lu", | |||||
| RT_MEMORY_HBM, var_name.c_str(), session_id); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| *var_device = var_addr; | *var_device = var_addr; | ||||
| @@ -293,9 +286,8 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t | |||||
| // Sync var data from device | // Sync var data from device | ||||
| std::unique_ptr<uint8_t[]> var_data; | std::unique_ptr<uint8_t[]> var_data; | ||||
| if (trans_road.empty()) { | if (trans_road.empty()) { | ||||
| REPORT_INNER_ERROR("E19999", "Param trans_road is empty, session_id:%lu, check invalid", | |||||
| session_id); | |||||
| GELOGE(INTERNAL_ERROR, "Failed to get trans_road, trans_road is empty."); | |||||
| REPORT_INNER_ERROR("E19999", "Param trans_road is empty, session_id:%lu, check invalid", session_id); | |||||
| GELOGE(INTERNAL_ERROR, "[Check][Param] trans_road is empty, session_id:%lu", session_id); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| const GeTensorDesc &input_desc = trans_road.begin()->input; | const GeTensorDesc &input_desc = trans_road.begin()->input; | ||||
| @@ -307,7 +299,7 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t | |||||
| formats::TransResult trans_result{}; | formats::TransResult trans_result{}; | ||||
| ret = TransVarOnHost(var_data.get(), trans_road, trans_result); | ret = TransVarOnHost(var_data.get(), trans_road, trans_result); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Failed to trans var data on host, error code %u", ret); | |||||
| GELOGE(ret, "[Call][TransVarOnHost] failed, session_id:%lu, ret:%u", session_id, ret); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -319,14 +311,15 @@ Status TransVarData(const NodePtr &var, const VarTransRoad &trans_road, uint64_t | |||||
| /// TensorDesc needs to be removed. This change is large and needs to be performed step by step. | /// TensorDesc needs to be removed. This change is large and needs to be performed step by step. | ||||
| ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device); | ret = ReAssignVarAddr(session_id, var->GetName(), trans_road.rbegin()->output, &var_device); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Failed to re-assign memory on device, size %zu", trans_result.length); | |||||
| GELOGE(ret, "[Call][ReAssignVarAddr] failed, session id:%lu, op:%s, ret:%u", | |||||
| session_id, var->GetName().c_str(), ret); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| // sync new data to device | // sync new data to device | ||||
| ret = CopyVarToDevice(var, trans_result, var_device); | ret = CopyVarToDevice(var, trans_result, var_device); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "Failed to send var data to device"); | |||||
| GELOGE(ret, "[Call][CopyVarToDevice] failed, var:%s, ret:%u", var->GetName().c_str(), ret); | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| @@ -350,7 +343,10 @@ Status TransTensor(uint8_t *var_data, const NodePtr &var_src, const NodePtr &var | |||||
| TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(), | TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(), | ||||
| TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(), | TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(), | ||||
| src_data_shape_size, ret); | src_data_shape_size, ret); | ||||
| GELOGE(INTERNAL_ERROR, "trans var data on host failed"); | |||||
| GELOGE(INTERNAL_ERROR, "[Trans][DataType] from %s to %s failed, data size %ld, ret:%u", | |||||
| TypeUtils::DataTypeToSerialString(src_data_datatype).c_str(), | |||||
| TypeUtils::DataTypeToSerialString(dst_data_datatype).c_str(), | |||||
| src_data_shape_size, ret); | |||||
| return ret; | return ret; | ||||
| }); | }); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -366,9 +362,11 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, | |||||
| /// need copy value from var_fp32 to var_fp16. | /// need copy value from var_fp32 to var_fp16. | ||||
| /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr] | /// [opdesc of var_src and var_dst are checked before passed in, no need to check if they are nullptr] | ||||
| GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, | GE_IF_BOOL_EXEC(var_src == nullptr || var_dst == nullptr, | ||||
| REPORT_INNER_ERROR("E19999", "Param var_src or var_dst is empty, session_id:%lu, device_id:%u, " | |||||
| REPORT_INNER_ERROR("E19999", "Param var_src or var_dst is nullptr, session_id:%lu, device_id:%u, " | |||||
| "check invalid", session_id, device_id); | "check invalid", session_id, device_id); | ||||
| GELOGE(FAILED, "node var is nullptr"); return FAILED); | |||||
| GELOGE(FAILED, "[Check][Param] Param var_src or var_dst is nullptr, session_id:%lu, device_id:%u", | |||||
| session_id, device_id); | |||||
| return FAILED); | |||||
| // src_node output_desc (fp32) | // src_node output_desc (fp32) | ||||
| GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0); | GeTensorDesc output_desc = var_src->GetOpDesc()->GetOutputDesc(0); | ||||
| auto src_data_type = output_desc.GetDataType(); | auto src_data_type = output_desc.GetDataType(); | ||||
| @@ -390,31 +388,45 @@ Status CopyTensorFromSrcVarNode(const NodePtr &var_src, | |||||
| RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id); | RtContextSwitchGuard switch_context(RT_CTX_NORMAL_MODE, device_id); | ||||
| // copy from src_node | // copy from src_node | ||||
| auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc); | auto ret = CopyVarFromDevice(session_id, var_src, var_src_data, output_desc); | ||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "Copy Var From Device failed"); return ret); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
| GELOGE(FAILED, "[Call][CopyVarFromDevice] failed, session id:%lu, var_src:%s", | |||||
| session_id, var_src->GetName().c_str()); | |||||
| return ret); | |||||
| // trans dtype | // trans dtype | ||||
| formats::TransResult trans_result{}; | formats::TransResult trans_result{}; | ||||
| ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result); | ret = TransTensor(var_src_data.get(), var_src, var_dst, trans_result); | ||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "trans var data on host failed"); return ret); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
| GELOGE(INTERNAL_ERROR, "[Trans][Tensor] failed, var_src:%s, var_dst:%s", | |||||
| var_src->GetName().c_str(), var_dst->GetName().c_str()); | |||||
| return ret); | |||||
| // reset src value. | // reset src value. | ||||
| void *var_device = nullptr; | void *var_device = nullptr; | ||||
| ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device); | ret = ReAssignVarAddr(session_id, var_dst->GetName(), dst_tensor_desc, &var_device); | ||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(INTERNAL_ERROR, "assign mem failed"); return ret); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
| GELOGE(INTERNAL_ERROR, "[Call][ReAssignVarAddr] failed, session id:%lu, var_dst:%s", | |||||
| session_id, var_dst->GetName().c_str()); | |||||
| return ret); | |||||
| // copy to device | // copy to device | ||||
| ret = CopyVarToDevice(var_dst, trans_result, var_device); | ret = CopyVarToDevice(var_dst, trans_result, var_device); | ||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(ret, "Failed to send var data to device"); return ret); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
| GELOGE(ret, "[Call][CopyVarToDevice] failed, var_dst:%s, ret:%u", | |||||
| var_dst->GetName().c_str(), ret); | |||||
| return ret); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace | } // namespace | ||||
| Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | ||||
| uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | uint8_t *dst_addr, int64_t dst_addr_size, uint64_t session_id) { | ||||
| GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "dst addr is null. "); | |||||
| GE_CHK_BOOL_RET_STATUS(dst_addr != nullptr, FAILED, "[Check][Param] dst addr is nullptr."); | |||||
| uint8_t *src_host_addr = nullptr; | uint8_t *src_host_addr = nullptr; | ||||
| int64_t src_addr_size = 0; | int64_t src_addr_size = 0; | ||||
| GE_MAKE_GUARD_RTMEM(src_host_addr); | GE_MAKE_GUARD_RTMEM(src_host_addr); | ||||
| GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); | GE_CHK_STATUS_RET(SyncTensorToHost(var_name, src_tensor_desc, &src_host_addr, src_addr_size, session_id)); | ||||
| GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); | GELOGI("src_addr_size: %ld, dst_addr_size: %ld", src_addr_size, dst_addr_size); | ||||
| GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, "var data size is not equal broadcast "); | |||||
| GE_CHK_BOOL_RET_STATUS(src_addr_size == dst_addr_size, FAILED, | |||||
| "[Check][Param] src_addr_size:%ld not equal to dst_addr_size:%ld", | |||||
| src_addr_size, dst_addr_size); | |||||
| GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | GE_CHK_RT_RET(rtMemcpy(dst_addr, dst_addr_size, src_host_addr, src_addr_size, RT_MEMCPY_HOST_TO_DEVICE)); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -422,7 +434,7 @@ Status TransVarDataUtils::SyncVarData2BroadCast(const string &var_name, const ge | |||||
| Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_addr_size, const string &var_name, | ||||
| const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | const ge::GeTensorDesc &dst_tensor_desc, uint64_t session_id) { | ||||
| GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "src addr is null. "); | |||||
| GE_CHK_BOOL_RET_STATUS(src_addr != nullptr, FAILED, "[Check][Param] src addr is nullptr. "); | |||||
| uint8_t *host_addr = nullptr; | uint8_t *host_addr = nullptr; | ||||
| GE_MAKE_GUARD_RTMEM(host_addr); | GE_MAKE_GUARD_RTMEM(host_addr); | ||||
| GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size)); | GE_CHK_RT_RET(rtMallocHost(reinterpret_cast<void **>(&host_addr), src_addr_size)); | ||||
| @@ -436,7 +448,7 @@ Status TransVarDataUtils::SyncBroadCastData2Var(uint8_t *src_addr, int64_t src_a | |||||
| Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | Status TransVarDataUtils::SyncTensorToHost(const string &var_name, const ge::GeTensorDesc &src_tensor_desc, | ||||
| uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { | uint8_t **host_addr, int64_t &src_tensor_size, uint64_t session_id) { | ||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "get size from TensorDesc failed"); | |||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(src_tensor_desc, src_tensor_size), "[Get][Size] from TensorDesc failed"); | |||||
| uint8_t *src_addr = nullptr; | uint8_t *src_addr = nullptr; | ||||
| GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); | GE_CHK_STATUS_RET(VarManager::Instance(session_id)->GetVarAddr(var_name, src_tensor_desc, &src_addr)); | ||||
| @@ -493,7 +505,8 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||||
| if (rt_ret != RT_ERROR_NONE) { | if (rt_ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, session_id:%lu, graph_id:%u, ret:0x%X,", | REPORT_CALL_ERROR("E19999", "Call rtCtxSetCurrent failed, session_id:%lu, graph_id:%u, ret:0x%X,", | ||||
| session_id, graph_id, rt_ret); | session_id, graph_id, rt_ret); | ||||
| GELOGE(RT_FAILED, "Failed to set context, error_code is: 0x%X.", rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtCtxSetCurrent] failed, session_id:%lu, graph_id:%u, ret:0x%X,", | |||||
| session_id, graph_id, rt_ret); | |||||
| return RT_ERROR_TO_GE_STATUS(rt_ret); | return RT_ERROR_TO_GE_STATUS(rt_ret); | ||||
| } | } | ||||
| uint32_t allocated_graph_id = 0; | uint32_t allocated_graph_id = 0; | ||||
| @@ -501,8 +514,8 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "Get allocated GraphId failed, session_id:%lu, graph_id:%u, ret:0x%X,", | REPORT_CALL_ERROR("E19999", "Get allocated GraphId failed, session_id:%lu, graph_id:%u, ret:0x%X,", | ||||
| session_id, graph_id, ret); | session_id, graph_id, ret); | ||||
| GELOGE(INTERNAL_ERROR, "var has not been allocated, node:%s, graph_id:%u.", node->GetName().c_str(), | |||||
| graph_id); | |||||
| GELOGE(INTERNAL_ERROR, "[Get][AllocatedGraphId] failed, node:%s, graph_id:%u.", | |||||
| node->GetName().c_str(), graph_id); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| uint32_t changed_graph_id = 0; | uint32_t changed_graph_id = 0; | ||||
| @@ -518,7 +531,8 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||||
| } | } | ||||
| ret = TransVarData(node, *trans_road, session_id); | ret = TransVarData(node, *trans_road, session_id); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(INTERNAL_ERROR, "TransVarData failed, node:%s, graph_id:%u.", node->GetName().c_str(), graph_id); | |||||
| GELOGE(INTERNAL_ERROR, "[Trans][VarData] failed, node:%s, graph_id:%u, session_id:%lu.", | |||||
| node->GetName().c_str(), graph_id, session_id); | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName()); | VarManager::Instance(session_id)->RemoveChangedGraphId(node->GetName()); | ||||
| @@ -527,7 +541,7 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||||
| }, | }, | ||||
| node, session_id, context, graph_id, ErrorManager::GetInstance().GetErrorManagerContext()); | node, session_id, context, graph_id, ErrorManager::GetInstance().GetErrorManagerContext()); | ||||
| if (!f.valid()) { | if (!f.valid()) { | ||||
| GELOGE(FAILED, "Future is invalid"); | |||||
| GELOGE(FAILED, "[Check][Param] Future is invalid, session id:%lu, graph id:%u", session_id, graph_id); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| vector_future.push_back(std::move(f)); | vector_future.push_back(std::move(f)); | ||||
| @@ -537,7 +551,7 @@ Status TransVarDataUtils::TransAllVarData(const vector<NodePtr> &variable_nodes, | |||||
| for (size_t i = 0; i < vector_future.size(); ++i) { | for (size_t i = 0; i < vector_future.size(); ++i) { | ||||
| ret_status = vector_future[i].get(); | ret_status = vector_future[i].get(); | ||||
| if (ret_status != SUCCESS) { | if (ret_status != SUCCESS) { | ||||
| GELOGE(ret_status, "TransAllVarData:: trans %zu vardata failed", i); | |||||
| GELOGE(ret_status, "[Check][Param] trans %zu vardata failed", i); | |||||
| return ret_status; | return ret_status; | ||||
| } | } | ||||
| } | } | ||||
| @@ -550,7 +564,8 @@ Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint | |||||
| if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, session_id:%lu, device_id:%u, check invalid", | REPORT_INNER_ERROR("E19999", "Param compute_graph is nullptr, session_id:%lu, device_id:%u, check invalid", | ||||
| session_id, device_id); | session_id, device_id); | ||||
| GELOGE(FAILED, "compute_graph is nullptr"); | |||||
| GELOGE(FAILED, "[Check][Param] compute_graph is nullptr, session_id:%lu, device_id:%u", | |||||
| session_id, device_id); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -568,7 +583,10 @@ Status TransVarDataUtils::CopyVarData(const ComputeGraphPtr &compute_graph, uint | |||||
| GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(), | GELOGI("current_var_node__: [%s] copy_from_var_node__: [%s].", node->GetName().c_str(), | ||||
| src_node->GetName().c_str()); | src_node->GetName().c_str()); | ||||
| auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id); | auto ret = CopyTensorFromSrcVarNode(src_node, node, session_id, device_id); | ||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, GELOGE(FAILED, "copy tensor failed!"); return FAILED); | |||||
| GE_IF_BOOL_EXEC(ret != SUCCESS, | |||||
| GELOGE(FAILED, "[Copy][Tensor] failed, src_node:%s, node:%s, session_id:%lu, device_id:%u", | |||||
| src_node->GetName().c_str(), node->GetName().c_str(), session_id, device_id); | |||||
| return FAILED); | |||||
| // only copy once | // only copy once | ||||
| (void) ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value | (void) ge::AttrUtils::SetBool(node->GetOpDesc(), "_copy_value", true); // no need to check value | ||||
| } | } | ||||
| @@ -63,17 +63,15 @@ Status Debug::DumpDevMem(const char *file, const void *addr, int64_t size) { | |||||
| uint8_t *host_addr = nullptr; | uint8_t *host_addr = nullptr; | ||||
| rtError_t ret = rtMallocHost(reinterpret_cast<void **>(&host_addr), size); | rtError_t ret = rtMallocHost(reinterpret_cast<void **>(&host_addr), size); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMallocHost failed, size:%zu, ret: 0x%X", | |||||
| size, ret); | |||||
| GELOGE(FAILED, "Call rt api rtMallocHost failed, ret: 0x%X", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMallocHost failed, size:%zu, ret:0x%X", size, ret); | |||||
| GELOGE(FAILED, "[Call][RtMallocHost] failed, size:%zu, ret:0x%X", size, ret); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| GE_MAKE_GUARD_RTMEM(host_addr); | GE_MAKE_GUARD_RTMEM(host_addr); | ||||
| ret = rtMemcpy(host_addr, size, addr, size, RT_MEMCPY_DEVICE_TO_HOST); | ret = rtMemcpy(host_addr, size, addr, size, RT_MEMCPY_DEVICE_TO_HOST); | ||||
| if (ret != RT_ERROR_NONE) { | if (ret != RT_ERROR_NONE) { | ||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%zu, ret: 0x%X", | |||||
| size, ret); | |||||
| GELOGE(FAILED, "Call rt api rtMemcpy failed, ret: 0x%X", ret); | |||||
| REPORT_CALL_ERROR("E19999", "Call rtMemcpy failed, size:%zu, ret:0x%X", size, ret); | |||||
| GELOGE(FAILED, "[Call][RtMemcpy] failed, size:%zu, ret:0x%X", size, ret); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -28,7 +28,8 @@ Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, | |||||
| std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | std::vector<GETaskKernelHcclInfo> &kernel_hccl_infos) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); | |||||
| GELOGE(PARAM_INVALID, "[Check][KernelHcclInfo] failed, op:%s(%s).", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| GELOGI("GetHcclDataType start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | GELOGI("GetHcclDataType start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| @@ -40,10 +41,10 @@ Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, | |||||
| if (op_desc->GetType() == HCOMRECEIVE) { | if (op_desc->GetType() == HCOMRECEIVE) { | ||||
| bool ret = ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, src_data_type); | bool ret = ge::AttrUtils::GetDataType(op_desc, HCOM_ATTR_DATA_TYPE, src_data_type); | ||||
| if (ret == false) { | if (ret == false) { | ||||
| REPORT_INNER_ERROR("E19999", "Get Attr:%s in op:%s(%s) fail", | |||||
| HCOM_ATTR_DATA_TYPE.c_str(), | |||||
| REPORT_INNER_ERROR("E19999", "Get Attr:%s in op:%s(%s) fail", HCOM_ATTR_DATA_TYPE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| GELOGE(PARAM_INVALID, "op:HcomReceive, op desc no attr: dtype."); | |||||
| GELOGE(PARAM_INVALID, "[Get][Attr] %s in op:%s(%s) fail", HCOM_ATTR_DATA_TYPE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| } else { | } else { | ||||
| @@ -55,13 +56,11 @@ Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, | |||||
| auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type)); | auto iter = kConstOpHcclDataType.find(static_cast<int64_t>(src_data_type)); | ||||
| if (iter == kConstOpHcclDataType.end()) { | if (iter == kConstOpHcclDataType.end()) { | ||||
| REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value data_type:%s, not support in kConstOpHcclDataType now, " | REPORT_INNER_ERROR("E19999", "Attr:%s in op:%s(%s), value data_type:%s, not support in kConstOpHcclDataType now, " | ||||
| "check invalid", HCOM_ATTR_DATA_TYPE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| GELOGE(PARAM_INVALID, | |||||
| "HcomOmeUtil:: Node: %s Optype: %s HcomDataType cann't support! Current Davinci Data Type : %s", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
| ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| "check invalid", HCOM_ATTR_DATA_TYPE.c_str(), op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in op:%s(%s), value data_type:%s, " | |||||
| "not support in kConstOpHcclDataType now", HCOM_ATTR_DATA_TYPE.c_str(), op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), ge::TypeUtils::DataTypeToSerialString(src_data_type).c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -73,7 +72,7 @@ Status HcomOmeUtil::GetHcclDataType(const ge::ConstOpDescPtr &op_desc, | |||||
| Status HcomOmeUtil::GetHcclTypeSize(HcclDataType data_type, int32_t &size) { | Status HcomOmeUtil::GetHcclTypeSize(HcclDataType data_type, int32_t &size) { | ||||
| auto iter = kConstOpHcclDataTypeSize.find(data_type); | auto iter = kConstOpHcclDataTypeSize.find(data_type); | ||||
| GE_CHK_BOOL_EXEC(iter != kConstOpHcclDataTypeSize.end(), return PARAM_INVALID, | GE_CHK_BOOL_EXEC(iter != kConstOpHcclDataTypeSize.end(), return PARAM_INVALID, | ||||
| "HcomOmeUtil::HcomDataTypeSize , No DataTypeSize!"); | |||||
| "[Check][Param] param data_type:%d not find", data_type); | |||||
| size = iter->second; | size = iter->second; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -83,21 +82,22 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType | |||||
| int &count) { | int &count) { | ||||
| GE_CHECK_NOTNULL(op_desc); | GE_CHECK_NOTNULL(op_desc); | ||||
| if (!IsHCOMOp(op_desc->GetType())) { | if (!IsHCOMOp(op_desc->GetType())) { | ||||
| REPORT_INNER_ERROR("E19999", "Op:%s(%s) is not hcom op, check invalid", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: operator is not Hcom operator."); | |||||
| REPORT_INNER_ERROR("E19999", "Op:%s(%s) is not hcom op, check invalid", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Op:%s(%s) is not hcom op", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| int64_t total_size = 0; | int64_t total_size = 0; | ||||
| int64_t align_size = 512; | int64_t align_size = 512; | ||||
| int32_t size = 0; | int32_t size = 0; | ||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "GetHcomCount: GetHcclTypeSize fail!"); | |||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(data_type, size), "[Get][HcclTypeSize] fail, datatype:%d", data_type); | |||||
| if (op_desc->GetType() == HCOMRECEIVE) { | if (op_desc->GetType() == HCOMRECEIVE) { | ||||
| for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | for (size_t i = 0; i < op_desc->GetOutputsSize(); ++i) { | ||||
| int64_t output_size = 0; | int64_t output_size = 0; | ||||
| GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); | GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); | ||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), output_size), | GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), output_size), | ||||
| "Get size from TensorDesc failed, op: %s, output index: %zu.", op_desc->GetName().c_str(), i); | |||||
| "[Get][Size] from TensorDesc failed, op:%s, output index:%zu.", op_desc->GetName().c_str(), i); | |||||
| output_size = (output_size + align_size - 1) / align_size * align_size; | output_size = (output_size + align_size - 1) / align_size * align_size; | ||||
| total_size += output_size; | total_size += output_size; | ||||
| } | } | ||||
| @@ -107,42 +107,48 @@ Status HcomOmeUtil::GetHcomCount(const ge::ConstOpDescPtr &op_desc, HcclDataType | |||||
| int64_t block_size = 0; | int64_t block_size = 0; | ||||
| GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); | GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); | ||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), | GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), | ||||
| "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | |||||
| "[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); | |||||
| // dynamic shape hccl op get size from output tensor desc | // dynamic shape hccl op get size from output tensor desc | ||||
| if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { | if (op_desc->HasAttr(ATTR_NAME_IS_UNKNOWN_SHAPE)) { | ||||
| GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); | GE_CHECK_NOTNULL(op_desc->GetOutputDescPtr(i)); | ||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size), | GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetOutputDescPtr(i), input_size), | ||||
| "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | |||||
| "[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); | |||||
| } | } | ||||
| GE_IF_BOOL_EXEC( | GE_IF_BOOL_EXEC( | ||||
| op_desc->GetType() == HCOMREDUCESCATTER, int32_t rank_size = 0; | op_desc->GetType() == HCOMREDUCESCATTER, int32_t rank_size = 0; | ||||
| GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(op_desc, HCOM_ATTR_RANK_SIZE, rank_size), PARAM_INVALID, | GE_CHK_BOOL_RET_STATUS(ge::AttrUtils::GetInt(op_desc, HCOM_ATTR_RANK_SIZE, rank_size), PARAM_INVALID, | ||||
| "get HCOM_ATTR_RANK_SIZE failed"); | |||||
| GE_CHK_BOOL_RET_STATUS(rank_size != 0, PARAM_INVALID, "rank size is zero"); | |||||
| int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); GE_CHK_STATUS_RET( | |||||
| ge::CheckInt64Uint32MulOverflow(shape_size, size), "Product of shape size and size beyond INT64_MAX"); | |||||
| "[Get][Attr] %s in op:%s(%s) failed", HCOM_ATTR_RANK_SIZE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| GE_CHK_BOOL_RET_STATUS(rank_size != 0, PARAM_INVALID, "[Check][Param] rank size is zero"); | |||||
| int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); | |||||
| GE_CHK_STATUS_RET(ge::CheckInt64Uint32MulOverflow(shape_size, size), | |||||
| "[Check][Param] Product of shape size:%ld and size:%d beyond INT64_MAX, op:%s(%s)", | |||||
| shape_size, size, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| block_size = (shape_size * size) / rank_size; | block_size = (shape_size * size) / rank_size; | ||||
| GE_CHK_STATUS_RET(ge::CheckInt64AddOverflow(total_size, block_size), "Total size is beyond the INT64_MAX"); | |||||
| GE_CHK_STATUS_RET(ge::CheckInt64AddOverflow(total_size, block_size), | |||||
| "[Check][Param] Total size:%ld is beyond the INT64_MAX, op:%s(%s)", | |||||
| total_size, op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| total_size = total_size + block_size; continue;); | total_size = total_size + block_size; continue;); | ||||
| int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); | int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); | ||||
| GELOGD("hcom util node %s inputsize %ld, shapesize %ld, datasize %d.", | GELOGD("hcom util node %s inputsize %ld, shapesize %ld, datasize %d.", | ||||
| op_desc->GetName().c_str(), input_size, shape_size, size); | op_desc->GetName().c_str(), input_size, shape_size, size); | ||||
| GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), | GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), | ||||
| "Product of shape size and size beyond INT64_MAX"); | |||||
| "[Check][Param] Product of shape size:%ld and size:%d beyond INT64_MAX", shape_size, size); | |||||
| GE_IF_BOOL_EXEC(is_allgather, block_size = shape_size * size;); | GE_IF_BOOL_EXEC(is_allgather, block_size = shape_size * size;); | ||||
| GE_IF_BOOL_EXEC(!is_allgather, block_size = (input_size + align_size - 1) / align_size * align_size;); | GE_IF_BOOL_EXEC(!is_allgather, block_size = (input_size + align_size - 1) / align_size * align_size;); | ||||
| GE_CHK_STATUS_RET(ge::CheckInt64AddOverflow(total_size, block_size), "Total size is beyond the INT64_MAX"); | |||||
| GE_CHK_STATUS_RET(ge::CheckInt64AddOverflow(total_size, block_size), | |||||
| "[Check][Param] Total size:%ld is beyond the INT64_MAX", total_size); | |||||
| total_size = total_size + block_size; | total_size = total_size + block_size; | ||||
| } | } | ||||
| } | } | ||||
| GE_CHK_BOOL_RET_STATUS(size != 0, PARAM_INVALID, "Size is zero"); | |||||
| GE_CHK_BOOL_RET_STATUS(size != 0, PARAM_INVALID, "[Check][Param] Size is zero"); | |||||
| count = static_cast<int>(total_size / size); | count = static_cast<int>(total_size / size); | ||||
| GE_CHK_BOOL_EXEC(total_size % size == 0, return PARAM_INVALID, "total_size:%ld is not divisiable by size:%d.", | |||||
| total_size, size); | |||||
| GE_CHK_BOOL_EXEC(total_size % size == 0, return PARAM_INVALID, | |||||
| "[Check][Param] total_size:%ld is not divisiable by size:%d.", total_size, size); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -153,32 +159,34 @@ Status HcomOmeUtil::GetHorovodCount(const ge::ConstOpDescPtr &op_desc, | |||||
| if (!IsHorovodOp(op_desc->GetType())) { | if (!IsHorovodOp(op_desc->GetType())) { | ||||
| REPORT_INNER_ERROR("E19999", "Op:%s(%s) is not horovod op, check invalid", | REPORT_INNER_ERROR("E19999", "Op:%s(%s) is not horovod op, check invalid", | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: operator is not Horovod operator."); | |||||
| GELOGE(PARAM_INVALID, "[Call][IsHorovodOp] failed, Op:%s(%s) is not horovod op", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| int64_t align_size = 512; | int64_t align_size = 512; | ||||
| int32_t size = 0; | int32_t size = 0; | ||||
| for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | for (size_t i = 0; i < op_desc->GetInputsSize(); i++) { | ||||
| GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(static_cast<HcclDataType>(kernel_hccl_infos[i].dataType), size), | GE_CHK_STATUS_RET(HcomOmeUtil::GetHcclTypeSize(static_cast<HcclDataType>(kernel_hccl_infos[i].dataType), size), | ||||
| "GetHorovodCount: GetHcclTypeSize fail!"); | |||||
| "[Call][GetHcclTypeSize] fail, op:%s(%s)", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| int64_t input_size = 0; | int64_t input_size = 0; | ||||
| int64_t block_size = 0; | int64_t block_size = 0; | ||||
| GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); | GE_CHECK_NOTNULL(op_desc->GetInputDescPtr(i)); | ||||
| GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), | GE_CHK_STATUS_RET(ge::TensorUtils::GetSize(*op_desc->GetInputDescPtr(i), input_size), | ||||
| "get size from TensorDesc failed, op : %s, input index : %zu", op_desc->GetName().c_str(), i); | |||||
| "[Get][Size] from TensorDesc failed, op:%s, input index:%zu", op_desc->GetName().c_str(), i); | |||||
| int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); | int64_t shape_size = op_desc->GetInputDescPtr(i)->GetShape().GetShapeSize(); | ||||
| GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), | GE_CHK_STATUS_RET(ge::CheckInt64Int32MulOverflow(shape_size, size), | ||||
| "Product of shape size and size beyond INT64_MAX"); | |||||
| "[Check][Param] Product of shape size:%ld and size:%d beyond INT64_MAX", shape_size, size); | |||||
| if (kernel_hccl_infos[0].hccl_type == HVDCALLBACKALLGATHER) { | if (kernel_hccl_infos[0].hccl_type == HVDCALLBACKALLGATHER) { | ||||
| block_size = shape_size * size; | block_size = shape_size * size; | ||||
| } else { | } else { | ||||
| block_size = (input_size + align_size - 1) / align_size * align_size; | block_size = (input_size + align_size - 1) / align_size * align_size; | ||||
| } | } | ||||
| GE_CHK_BOOL_RET_STATUS(size != 0, PARAM_INVALID, "Size is zero"); | |||||
| GE_CHK_BOOL_EXEC(block_size % size == 0, return PARAM_INVALID, "block_size:%ld is not divisiable by size:%d.", | |||||
| block_size, size); | |||||
| GE_CHK_BOOL_RET_STATUS(size != 0, PARAM_INVALID, "[Check][Param] Size is zero"); | |||||
| GE_CHK_BOOL_EXEC(block_size % size == 0, return PARAM_INVALID, | |||||
| "[Check][Param] block_size:%ld is not divisiable by size:%d.", block_size, size); | |||||
| kernel_hccl_infos[i].count = static_cast<int>(block_size / size); | kernel_hccl_infos[i].count = static_cast<int>(block_size / size); | ||||
| } | } | ||||
| @@ -191,7 +199,8 @@ Status HcomOmeUtil::GetHcclCount(const ge::ConstOpDescPtr &op_desc, | |||||
| Status ret; | Status ret; | ||||
| ret = CheckKernelHcclInfo(op_desc, kernel_hccl_infos); | ret = CheckKernelHcclInfo(op_desc, kernel_hccl_infos); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: the number of GETaskKernelHcclInfo is invalid."); | |||||
| GELOGE(PARAM_INVALID, "[Check][KernelHcclInfo] failed, the number of GETaskKernelHcclInfo is invalid, op:%s(%s).", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| GELOGI("GetHcclCount start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | GELOGI("GetHcclCount start, node[%s], opType[%s].", op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| @@ -200,7 +209,7 @@ Status HcomOmeUtil::GetHcclCount(const ge::ConstOpDescPtr &op_desc, | |||||
| ret = GetHcomCount(op_desc, static_cast<HcclDataType>(kernel_hccl_infos[0].dataType), | ret = GetHcomCount(op_desc, static_cast<HcclDataType>(kernel_hccl_infos[0].dataType), | ||||
| kernel_hccl_infos[0].hccl_type == HCOMALLGATHER, count); | kernel_hccl_infos[0].hccl_type == HCOMALLGATHER, count); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "HcomOmeUtil:: Node: %s Optype: %s get the Hcom operator hccl count fail.", | |||||
| GELOGE(ret, "[Call][GetHcomCount] Node:%s Optype:%s get the Hcom operator hccl count fail.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -210,7 +219,7 @@ Status HcomOmeUtil::GetHcclCount(const ge::ConstOpDescPtr &op_desc, | |||||
| if (IsHorovodOp(op_desc->GetType())) { | if (IsHorovodOp(op_desc->GetType())) { | ||||
| ret = GetHorovodCount(op_desc, kernel_hccl_infos); | ret = GetHorovodCount(op_desc, kernel_hccl_infos); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s get the Horovod hccl operator count fail.", | |||||
| GELOGE(PARAM_INVALID, "[Call][GetHorovodCount] Node:%s Optype:%s get the Horovod hccl operator count fail.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -225,11 +234,10 @@ Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, Hccl | |||||
| if (IsHCOMOp(op_desc->GetType())) { | if (IsHCOMOp(op_desc->GetType())) { | ||||
| std::string hcom_op_type; | std::string hcom_op_type; | ||||
| GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_REDUCE_TYPE, hcom_op_type), | GE_CHK_BOOL_EXEC(ge::AttrUtils::GetStr(op_desc, HCOM_ATTR_REDUCE_TYPE, hcom_op_type), | ||||
| REPORT_INNER_ERROR("E19999", "Get Attr:%s in op:%s(%s) fail", | |||||
| HCOM_ATTR_REDUCE_TYPE.c_str(), | |||||
| REPORT_INNER_ERROR("E19999", "Get Attr:%s in op:%s(%s) fail", HCOM_ATTR_REDUCE_TYPE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return PARAM_INVALID, | return PARAM_INVALID, | ||||
| "HcomOmeUtil:: Node: %s Optype: %s Get HCOM_ATTR_REDUCE_TYPE fail, not support!", | |||||
| "[Get][Attr] %s in op:%s(%s) fail", HCOM_ATTR_REDUCE_TYPE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| if (hcom_op_type == "min") { | if (hcom_op_type == "min") { | ||||
| @@ -244,7 +252,9 @@ Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, Hccl | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s in Op:%s(%s), hcom_op_type value:%s is not support now, " | REPORT_INNER_ERROR("E19999", "Attr:%s in Op:%s(%s), hcom_op_type value:%s is not support now, " | ||||
| "check invalid", HCOM_ATTR_REDUCE_TYPE.c_str(), | "check invalid", HCOM_ATTR_REDUCE_TYPE.c_str(), | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), hcom_op_type.c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), hcom_op_type.c_str()); | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil::Get HCOM_ATTR_REDUCE_TYPE fail, [%s] not support!", hcom_op_type.c_str()); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in Op:%s(%s), hcom_op_type value:%s is not support now", | |||||
| HCOM_ATTR_REDUCE_TYPE.c_str(), op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), hcom_op_type.c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| @@ -256,7 +266,7 @@ Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, Hccl | |||||
| ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), | ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return PARAM_INVALID, | return PARAM_INVALID, | ||||
| "HcomOmeUtil:: Node: %s Optype: %s Get ATTR_HOROVOD_ATTR_REDUCE_TYPE fail, not support!", | |||||
| "[Get][Attr] %s in op:%s(%s) fail", ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| auto iter = kHorovodRedOpToHcclRedOp.find(static_cast<HorovodReduceOp>(horovod_op_type)); | auto iter = kHorovodRedOpToHcclRedOp.find(static_cast<HorovodReduceOp>(horovod_op_type)); | ||||
| @@ -264,8 +274,8 @@ Status HcomOmeUtil::GetHcclOperationType(const ge::ConstOpDescPtr &op_desc, Hccl | |||||
| REPORT_INNER_ERROR("E19999", "Attr:%s in Op:%s(%s), horovod_op_type value:%ld is not support now, " | REPORT_INNER_ERROR("E19999", "Attr:%s in Op:%s(%s), horovod_op_type value:%ld is not support now, " | ||||
| "check invalid", ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), | "check invalid", ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s HcomOpType cann't support! Current HcomOpType : %ld", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Attr:%s in Op:%s(%s), horovod_op_type value:%ld is not support now", | |||||
| ATTR_HOROVOD_ATTR_REDUCE_TYPE.c_str(), op_desc->GetName().c_str(), op_desc->GetType().c_str(), horovod_op_type); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| op_type = iter->second; | op_type = iter->second; | ||||
| @@ -281,7 +291,7 @@ Status HcomOmeUtil::GetHcclRootId(const ge::ConstOpDescPtr &op_desc, int64_t &ro | |||||
| HCOM_ATTR_ROOT_RANK.c_str(), | HCOM_ATTR_ROOT_RANK.c_str(), | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return PARAM_INVALID, | return PARAM_INVALID, | ||||
| "HcomOmeUtil::Node %s Optype: %s Get HCOM_ATTR_ROOT_INDEX fail, not support!", | |||||
| "[Get][Attr] %s in op:%s(%s) fail", HCOM_ATTR_ROOT_RANK.c_str(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -296,7 +306,7 @@ Status HcomOmeUtil::GetAllRootId(const ge::ConstOpDescPtr &op_desc, | |||||
| int64_t root_id = 0; | int64_t root_id = 0; | ||||
| Status dmrt = GetHcclRootId(op_desc, root_id); | Status dmrt = GetHcclRootId(op_desc, root_id); | ||||
| if (dmrt != SUCCESS) { | if (dmrt != SUCCESS) { | ||||
| GELOGE(FAILED, "davinci_model: GetHcomRootId fail! domi error: %u", dmrt); | |||||
| GELOGE(FAILED, "[Get][HcclRootId] fail! domi error: %u", dmrt); | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| @@ -324,7 +334,8 @@ Status HcomOmeUtil::CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, | |||||
| REPORT_INNER_ERROR("E19999", "Op:%s(%s) is not hcom op or param kernel_hccl_infos.size:%zu != 1, " | REPORT_INNER_ERROR("E19999", "Op:%s(%s) is not hcom op or param kernel_hccl_infos.size:%zu != 1, " | ||||
| "check invalid", | "check invalid", | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); | op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: in Hcom scenario, the number of GETaskKernelHcclInfo is invalid."); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Op:%s(%s) is not hcom op or param kernel_hccl_infos.size:%zu != 1", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str(), kernel_hccl_infos.size()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -337,7 +348,9 @@ Status HcomOmeUtil::CheckKernelHcclInfo(const ge::ConstOpDescPtr &op_desc, | |||||
| "in op:%s(%s), check invalid", | "in op:%s(%s), check invalid", | ||||
| kernel_hccl_infos.size(), op_desc->GetInputsSize(), | kernel_hccl_infos.size(), op_desc->GetInputsSize(), | ||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: in Horovod scenario, the number of GETaskKernelHcclInfo is invalid."); | |||||
| GELOGE(PARAM_INVALID, "Param kernel_hccl_infos.size:%zu is empty or not equal to " | |||||
| "input_desc size:%zu in op:%s(%s)", kernel_hccl_infos.size(), op_desc->GetInputsSize(), | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | |||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| } | } | ||||
| @@ -360,7 +373,7 @@ Status HcomOmeUtil::GetHorovodInputs(const ge::ConstOpDescPtr &op_desc, | |||||
| } | } | ||||
| if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | if (CheckKernelHcclInfo(op_desc, kernel_hccl_infos) != SUCCESS) { | ||||
| GELOGE(PARAM_INVALID, "HcomOmeUtil:: Node: %s Optype: %s the number of GETaskKernelHcclInfo is invalid.", | |||||
| GELOGE(PARAM_INVALID, "[Check][KernelHcclInfo] Node:%s Optype:%s the number of GETaskKernelHcclInfo is invalid.", | |||||
| op_desc->GetName().c_str(), op_desc->GetType().c_str()); | op_desc->GetName().c_str(), op_desc->GetType().c_str()); | ||||
| return PARAM_INVALID; | return PARAM_INVALID; | ||||
| } | } | ||||
| @@ -54,7 +54,7 @@ void VarAccelerateCtrl::SetVarChanged(const std::string &var_name) { | |||||
| void VarAccelerateCtrl::AddGraph(uint32_t graph_id, const ComputeGraphPtr &compute_graph) { | void VarAccelerateCtrl::AddGraph(uint32_t graph_id, const ComputeGraphPtr &compute_graph) { | ||||
| std::lock_guard<std::mutex> lock(mutex_); | std::lock_guard<std::mutex> lock(mutex_); | ||||
| if (compute_graph == nullptr) { | if (compute_graph == nullptr) { | ||||
| GELOGE(PARAM_INVALID, "Failed to add graph %u, the compute graph is null", graph_id); | |||||
| GELOGE(PARAM_INVALID, "[Check][Param] Failed to add graph %u, the compute graph is null", graph_id); | |||||
| return; | return; | ||||
| } | } | ||||
| auto &var_names = graph_ids_to_var_names_[graph_id]; | auto &var_names = graph_ids_to_var_names_[graph_id]; | ||||
| @@ -46,11 +46,6 @@ | |||||
| #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | #define REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const std::set<std::string> kControlFlowOps{ | |||||
| STREAMACTIVE, STREAMSWITCH, STREAMMERGE, ENTER, REFENTER, LOOPCOND, NEXTITERATION, REFNEXTITERATION, EXIT, REFEXIT | |||||
| }; | |||||
| } | |||||
| using Cluster = DynamicShapePartitioner::Cluster; | using Cluster = DynamicShapePartitioner::Cluster; | ||||
| using ClusterPtr = std::shared_ptr<Cluster>; | using ClusterPtr = std::shared_ptr<Cluster>; | ||||
| @@ -279,9 +274,17 @@ Status DynamicShapePartitioner::InitClusters() { | |||||
| auto cluster = MakeShared<Cluster>(rank++, type, node, this); | auto cluster = MakeShared<Cluster>(rank++, type, node, this); | ||||
| REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | REQUIRE_NOT_NULL(cluster, "Failed new memory for cluster."); | ||||
| node_2_cluster_[node] = cluster; | node_2_cluster_[node] = cluster; | ||||
| if (cluster->IsUnknownShape() && !cluster->IsControlFlow()) { | |||||
| if (cluster->IsUnknownShape()) { | |||||
| ordered_cluster_.push_back(cluster); | ordered_cluster_.push_back(cluster); | ||||
| } | } | ||||
| int64_t group_index = -1; | |||||
| if (AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index)) { | |||||
| GELOGD("[%s] is rts control flow Op, group index: %ld", node->GetName().c_str(), group_index); | |||||
| auto &control_cluster = control_clusters_[group_index]; | |||||
| control_cluster.emplace_back(cluster); | |||||
| } | |||||
| // Already sorted topologically, so access to the parent cluster is safe | // Already sorted topologically, so access to the parent cluster is safe | ||||
| for (const auto &parent : node->GetInAllNodes()) { | for (const auto &parent : node->GetInAllNodes()) { | ||||
| cluster->AddInput(node_2_cluster_[parent]); | cluster->AddInput(node_2_cluster_[parent]); | ||||
| @@ -350,14 +353,38 @@ static std::string ToString(const std::vector<ClusterPtr> &clusters) { | |||||
| } | } | ||||
| } | } | ||||
| void DynamicShapePartitioner::MergeClustersControlFlow() { | |||||
| for (const auto &item : control_clusters_) { | |||||
| const auto &control_cluster = item.second; | |||||
| auto rit = control_cluster.rbegin(); | |||||
| if (rit == control_cluster.rend()) { | |||||
| GELOGW("Invalid empty control flow cluster."); | |||||
| continue; | |||||
| } | |||||
| const auto &cluster = *rit; | |||||
| for (++rit; rit != control_cluster.rend(); ++rit) { | |||||
| const auto &cluster_from = *rit; | |||||
| auto merged_clusters = cluster->MergeAllPathFrom(cluster_from); | |||||
| GELOGD("Merge all path cluster from %lu to %lu %s.", cluster_from->Id(), cluster->Id(), | |||||
| ToString(merged_clusters).c_str()); | |||||
| for (const auto &merged_cluster : merged_clusters) { | |||||
| for (const auto &node : merged_cluster->Nodes()) { | |||||
| node_2_cluster_[node] = cluster; | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| } | |||||
| void DynamicShapePartitioner::MergeClustersUnknownShape() { | void DynamicShapePartitioner::MergeClustersUnknownShape() { | ||||
| // Merge unknown shape clusters | // Merge unknown shape clusters | ||||
| for (const auto &cluster : ordered_cluster_) { | for (const auto &cluster : ordered_cluster_) { | ||||
| if (cluster->IsIndependent() || cluster->IsControlFlow()) { | |||||
| if (cluster->IsIndependent()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| for (const auto &in_cluster : cluster->Inputs()) { | for (const auto &in_cluster : cluster->Inputs()) { | ||||
| if (!in_cluster->IsUnknownShape() || in_cluster->IsControlFlow()) { | |||||
| if (!in_cluster->IsUnknownShape()) { | |||||
| continue; | continue; | ||||
| } | } | ||||
| auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | auto merged_clusters = cluster->MergeAllPathFrom(in_cluster); | ||||
| @@ -419,6 +446,7 @@ void DynamicShapePartitioner::MergeClustersInputData() { | |||||
| } | } | ||||
| Status DynamicShapePartitioner::MergeClusters() { | Status DynamicShapePartitioner::MergeClusters() { | ||||
| MergeClustersControlFlow(); | |||||
| MergeClustersUnknownShape(); | MergeClustersUnknownShape(); | ||||
| REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); | REQUIRE_SUCCESS(TopologicalSortClusters(), "Failed topological sort clusters after merge unknown shape clusters."); | ||||
| MergeClustersKnownShape(); | MergeClustersKnownShape(); | ||||
| @@ -608,13 +636,6 @@ bool Cluster::IsRefVariable() const { | |||||
| return false; | return false; | ||||
| } | } | ||||
| bool Cluster::IsControlFlow() const { | |||||
| const auto &op_desc = nodes_[0]->GetOpDesc(); | |||||
| bool is_ctrl_flow = kControlFlowOps.count(op_desc->GetType()) > 0 && op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
| GELOGD("[%s] %s rts control flow Op ", op_desc->GetName().c_str(), is_ctrl_flow ? "Is" : "Not"); | |||||
| return is_ctrl_flow; | |||||
| } | |||||
| void Cluster::AddInput(ClusterPtr in) { | void Cluster::AddInput(ClusterPtr in) { | ||||
| if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; | if (std::find(in_clusters_.begin(), in_clusters_.end(), in) != in_clusters_.end()) return; | ||||
| in_clusters_.insert(in_clusters_.end(), in); | in_clusters_.insert(in_clusters_.end(), in); | ||||
| @@ -694,10 +715,7 @@ std::vector<ClusterPtr> Cluster::MergeAllPathFrom(ClusterPtr other) { | |||||
| if (other->IsIndependent()) { | if (other->IsIndependent()) { | ||||
| return path_clusters; | return path_clusters; | ||||
| } | } | ||||
| if (std::find(other->out_clusters_.begin(), other->out_clusters_.end(), shared_from_this()) == | |||||
| other->out_clusters_.end()) { | |||||
| return path_clusters; | |||||
| } | |||||
| path_clusters.push_back(other); | path_clusters.push_back(other); | ||||
| forward_reached_queue.push(other); | forward_reached_queue.push(other); | ||||
| backward_reached_queue.push(shared_from_this()); | backward_reached_queue.push(shared_from_this()); | ||||
| @@ -761,7 +779,7 @@ InControlAnchorPtr Cluster::GetFrameInControlAnchor() { return partition_node_-> | |||||
| OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | OutControlAnchorPtr Cluster::GetFrameOutControlAnchor() { return partition_node_->GetOutControlAnchor(); }; | ||||
| Status Cluster::BuildFrame() { | Status Cluster::BuildFrame() { | ||||
| if ((IsUnknownShape() || IsKnownShape() || IsInputNode()) && !IsControlFlow()) { | |||||
| if (IsUnknownShape() || IsKnownShape() || IsInputNode()) { | |||||
| return BuildPartitionFrame(); | return BuildPartitionFrame(); | ||||
| } else { | } else { | ||||
| auto node = nodes_.front(); | auto node = nodes_.front(); | ||||
| @@ -896,7 +914,7 @@ Status Cluster::CombinePartitionFrame() { | |||||
| } | } | ||||
| Status Cluster::BuildPartitionSubgraph() { | Status Cluster::BuildPartitionSubgraph() { | ||||
| if (IsData() || IsNetOutput() || IsIndependent() || IsControlFlow()) { | |||||
| if (IsData() || IsNetOutput() || IsIndependent()) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| int64_t parent_node_index = 0; | int64_t parent_node_index = 0; | ||||
| @@ -47,7 +47,6 @@ class DynamicShapePartitioner { | |||||
| bool IsUnknownShape() const; | bool IsUnknownShape() const; | ||||
| bool IsIndependent() const; | bool IsIndependent() const; | ||||
| bool IsNetOutput() const; | bool IsNetOutput() const; | ||||
| bool IsControlFlow() const; | |||||
| std::vector<std::shared_ptr<Cluster>> Inputs() const; | std::vector<std::shared_ptr<Cluster>> Inputs() const; | ||||
| std::vector<std::shared_ptr<Cluster>> Outputs() const; | std::vector<std::shared_ptr<Cluster>> Outputs() const; | ||||
| bool IsInputNode() const; | bool IsInputNode() const; | ||||
| @@ -126,13 +125,15 @@ class DynamicShapePartitioner { | |||||
| // and there's only one path between the two clusters , merge the two clusters | // and there's only one path between the two clusters , merge the two clusters | ||||
| // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA | // 3) Iterate through the INPUT_DATA clusters, merge all INPUT_DATA | ||||
| Status MergeClusters(); | Status MergeClusters(); | ||||
| // Merge clusters step0 | |||||
| void MergeClustersControlFlow(); | |||||
| // Merge clusters step1 | // Merge clusters step1 | ||||
| void MergeClustersUnknownShape(); | void MergeClustersUnknownShape(); | ||||
| // Merge clusters step2 | // Merge clusters step2 | ||||
| void MergeClustersKnownShape(); | void MergeClustersKnownShape(); | ||||
| // Merge clusters step3 | // Merge clusters step3 | ||||
| void MergeClustersInputData(); | void MergeClustersInputData(); | ||||
| // Topological sort clusters after merge unknow shape clusters. | |||||
| // Topological sort clusters after merge unknown shape clusters. | |||||
| Status TopologicalSortClusters(); | Status TopologicalSortClusters(); | ||||
| // Deduplicate merged clusters | // Deduplicate merged clusters | ||||
| void PruneUniqueClusters(); | void PruneUniqueClusters(); | ||||
| @@ -140,7 +141,7 @@ class DynamicShapePartitioner { | |||||
| Status BuildPartitionFrame(); | Status BuildPartitionFrame(); | ||||
| // Establish connection between corresponding partitioned of clusters | // Establish connection between corresponding partitioned of clusters | ||||
| Status CombinePartitionFrame(); | Status CombinePartitionFrame(); | ||||
| // Convert the nodes in cluster into a complete ComputeGraoh | |||||
| // Convert the nodes in cluster into a complete ComputeGraph | |||||
| Status BuildPartitionSubgraph(); | Status BuildPartitionSubgraph(); | ||||
| // Clear resource and break circular dependency | // Clear resource and break circular dependency | ||||
| void ClearResource(); | void ClearResource(); | ||||
| @@ -155,6 +156,8 @@ class DynamicShapePartitioner { | |||||
| Status CtrlEdgeTransfer(); | Status CtrlEdgeTransfer(); | ||||
| ge::ComputeGraphPtr root_graph_; // The original graph to partition | ge::ComputeGraphPtr root_graph_; // The original graph to partition | ||||
| std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | std::unordered_map<NodePtr, std::shared_ptr<Cluster>> node_2_cluster_; // Record nodes and the cluster it belongs to | ||||
| // V1 control flow cluster, need merge to one Graph. | |||||
| std::unordered_map<int64_t, std::vector<std::shared_ptr<Cluster>>> control_clusters_; | |||||
| // topological sorted clusters, this field will change with the splitting. | // topological sorted clusters, this field will change with the splitting. | ||||
| // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | // When partitioning UNKNOWN_SHAPE cluster, it is a collection of all topological sorted UNKNOWN_SHAPE clusters | ||||
| // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | // When partitioning KNOWN_SHAPE cluster, it is a collection of all topological sorted KNOWN_SHAPE clusters | ||||
| @@ -36,6 +36,8 @@ struct DuringPassNodeSets { | |||||
| std::unordered_set<NodePtr> nodes_re_pass; | std::unordered_set<NodePtr> nodes_re_pass; | ||||
| std::unordered_set<NodePtr> nodes_re_pass_immediately; | std::unordered_set<NodePtr> nodes_re_pass_immediately; | ||||
| std::unordered_set<NodePtr> nodes_last; | std::unordered_set<NodePtr> nodes_last; | ||||
| std::unordered_set<NodePtr> nodes_suspend; | |||||
| std::unordered_set<NodePtr> nodes_resume; | |||||
| }; | }; | ||||
| void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &input_edge_nodes, | ||||
| @@ -55,8 +57,15 @@ void GetAllNodesNoInputEdge(const ComputeGraphPtr &graph, std::deque<NodePtr> &i | |||||
| } | } | ||||
| } | } | ||||
| bool IsAllInNodesAlive(const Node::Vistor<NodePtr> &nodes, const std::unordered_set<NodePtr> &nodes_suspend) { | |||||
| return !std::any_of(nodes.begin(), nodes.end(), [&](const NodePtr &n) { return nodes_suspend.count(n) > 0; }); | |||||
| } | |||||
| void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &nodes_to_pass, | ||||
| std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_last) { | |||||
| DuringPassNodeSets &during_pass_node_set) { | |||||
| auto &nodes_seen = during_pass_node_set.nodes_seen; | |||||
| const auto &nodes_last = during_pass_node_set.nodes_last; | |||||
| const auto &nodes_suspend = during_pass_node_set.nodes_suspend; | |||||
| for (auto &node : nodes) { | for (auto &node : nodes) { | ||||
| if (node == nullptr) { | if (node == nullptr) { | ||||
| continue; | continue; | ||||
| @@ -64,16 +73,57 @@ void AddNextIterNodes(const Node::Vistor<NodePtr> &nodes, std::deque<NodePtr> &n | |||||
| if (nodes_last.count(node) != 0) { | if (nodes_last.count(node) != 0) { | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has suspend by pass, skip it.", node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| bool all_in_nodes_alive = IsAllInNodesAlive(node->GetInAllNodes(), nodes_suspend); | |||||
| bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); | bool all_in_nodes_seen = node->IsAllInNodesSeen(nodes_seen); | ||||
| if (all_in_nodes_seen && nodes_seen.insert(node.get()).second) { | |||||
| if (all_in_nodes_seen && all_in_nodes_alive && nodes_seen.insert(node.get()).second) { | |||||
| nodes_to_pass.push_back(node); | nodes_to_pass.push_back(node); | ||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| void AddRepassNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { | |||||
| for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { | |||||
| GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); | |||||
| nodes.push_front(node); | |||||
| } | |||||
| during_pass_node_set.nodes_re_pass_immediately.clear(); | |||||
| } | |||||
| void AddResumeNodes(DuringPassNodeSets &during_pass_node_set, std::deque<NodePtr> &nodes) { | |||||
| for (auto &node : during_pass_node_set.nodes_resume) { | |||||
| const auto &it = during_pass_node_set.nodes_suspend.find(node); | |||||
| if (it != during_pass_node_set.nodes_suspend.end()) { | |||||
| during_pass_node_set.nodes_suspend.erase(node); | |||||
| GELOGD("The node %s resumed by pass.", node->GetName().c_str()); | |||||
| nodes.push_back(node); | |||||
| } else { | |||||
| GELOGW("The node %s not suspend, drop from resumed", node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| during_pass_node_set.nodes_resume.clear(); | |||||
| } | |||||
| void PushToSuspendNodes(DuringPassNodeSets &during_pass_node_set, const std::string &pass_name, | |||||
| const std::unordered_set<NodePtr> &nodes_suspend, | |||||
| const std::unordered_set<NodePtr> &nodes_resume) { | |||||
| for (const auto &node : nodes_suspend) { | |||||
| GELOGD("The iteration suspend of node %s has been set by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
| during_pass_node_set.nodes_suspend.emplace(node); | |||||
| } | |||||
| for (const auto &node : nodes_resume) { | |||||
| GELOGD("The iteration suspend of node %s has been resumed by pass %s", node->GetName().c_str(), pass_name.c_str()); | |||||
| during_pass_node_set.nodes_resume.emplace(node); | |||||
| } | |||||
| } | |||||
| void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | void PushToRePassIfSeen(NodePtr &node, const std::pair<std::string, BaseNodePass *> &name_to_pass, | ||||
| std::unordered_set<Node *> &nodes_seen, std::unordered_set<NodePtr> &nodes_to_re_pass, | |||||
| std::unordered_set<Node *> &nodes_seen, const std::unordered_set<NodePtr> &nodes_to_re_pass, | |||||
| std::unordered_set<NodePtr> &nodes_re_pass) { | std::unordered_set<NodePtr> &nodes_re_pass) { | ||||
| for (const auto &node_to_re_pass : nodes_to_re_pass) { | for (const auto &node_to_re_pass : nodes_to_re_pass) { | ||||
| if (node_to_re_pass == nullptr) { | if (node_to_re_pass == nullptr) { | ||||
| @@ -113,15 +163,18 @@ Status RunPasses(NodePtr &node, const NamesToPass &names_to_passes, DuringPassNo | |||||
| return result; | return result; | ||||
| } | } | ||||
| auto nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
| const auto &nodes_to_re_pass = name_to_pass.second->GetNodesNeedRePass(); | |||||
| PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass, | ||||
| during_pass_node_set.nodes_re_pass); | during_pass_node_set.nodes_re_pass); | ||||
| auto nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
| const auto &nodes_to_re_pass_immediately = name_to_pass.second->GetNodesNeedRePassImmediately(); | |||||
| PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | PushToRePassIfSeen(node, name_to_pass, during_pass_node_set.nodes_seen, nodes_to_re_pass_immediately, | ||||
| during_pass_node_set.nodes_re_pass_immediately); | during_pass_node_set.nodes_re_pass_immediately); | ||||
| auto nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
| PushToSuspendNodes(during_pass_node_set, name_to_pass.first, | |||||
| name_to_pass.second->GetNodesSuspend(), name_to_pass.second->GetNodesResume()); | |||||
| const auto &nodes_deleted_by_pass = name_to_pass.second->GetNodesDeleted(); | |||||
| during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | during_pass_node_set.nodes_deleted.insert(nodes_deleted_by_pass.begin(), nodes_deleted_by_pass.end()); | ||||
| if (nodes_deleted_by_pass.count(node) > 0) { | if (nodes_deleted_by_pass.count(node) > 0) { | ||||
| GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | GELOGD("The node %s was deleted by pass %s, stop the remain passes", node->GetName().c_str(), | ||||
| @@ -221,8 +274,13 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | GELOGD("The node %s was deleted before, skip it.", node->GetName().c_str()); | ||||
| continue; | continue; | ||||
| } | } | ||||
| if (during_pass_node_set.nodes_suspend.count(node) > 0) { | |||||
| GELOGD("The node %s has been added to suspend-iteration nodes list, the iteration of it will be suspend.", | |||||
| node->GetName().c_str()); | |||||
| continue; | |||||
| } | |||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set.nodes_seen, during_pass_node_set.nodes_last); | |||||
| AddNextIterNodes(node->GetOutNodes(), nodes, during_pass_node_set); | |||||
| auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | auto ret = RunPasses(node, names_to_passes, during_pass_node_set); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| @@ -253,11 +311,9 @@ Status GEPass::RunPassesOneGraph(const NamesToPass &names_to_passes) { | |||||
| // should be called each time at the begin of the iteration | // should be called each time at the begin of the iteration | ||||
| ClearOption(names_to_passes); | ClearOption(names_to_passes); | ||||
| } | } | ||||
| for (const auto &node : during_pass_node_set.nodes_re_pass_immediately) { | |||||
| GELOGD("The node %s will be re-pass immediately.", node->GetName().c_str()); | |||||
| nodes.push_front(node); | |||||
| } | |||||
| during_pass_node_set.nodes_re_pass_immediately.clear(); | |||||
| AddRepassNodes(during_pass_node_set, nodes); | |||||
| AddResumeNodes(during_pass_node_set, nodes); | |||||
| } | } | ||||
| for (auto &node : during_pass_node_set.nodes_last) { | for (auto &node : during_pass_node_set.nodes_last) { | ||||
| @@ -51,11 +51,15 @@ class BaseNodePass { | |||||
| virtual ~BaseNodePass() = default; | virtual ~BaseNodePass() = default; | ||||
| std::unordered_set<NodePtr> GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesNeedRePass() { return nodes_need_re_pass_; } | |||||
| std::unordered_set<NodePtr> GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesNeedRePassImmediately() { return nodes_need_re_pass_immediately_; } | |||||
| std::unordered_set<NodePtr> GetNodesDeleted() { return nodes_deleted_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesDeleted() { return nodes_deleted_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesSuspend() { return nodes_suspend_; } | |||||
| const std::unordered_set<NodePtr> &GetNodesResume() { return nodes_resume_; } | |||||
| void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | void SetOption(NodePassOption option, const std::string &value) { options_[option] = value; } | ||||
| @@ -65,6 +69,8 @@ class BaseNodePass { | |||||
| nodes_need_re_pass_.clear(); | nodes_need_re_pass_.clear(); | ||||
| nodes_deleted_.clear(); | nodes_deleted_.clear(); | ||||
| nodes_need_re_pass_immediately_.clear(); | nodes_need_re_pass_immediately_.clear(); | ||||
| nodes_suspend_.clear(); | |||||
| nodes_resume_.clear(); | |||||
| } | } | ||||
| protected: | protected: | ||||
| @@ -80,7 +86,7 @@ class BaseNodePass { | |||||
| /// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddRePassNode(NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
| void AddRePassNode(const NodePtr &node) { nodes_need_re_pass_.insert(node); } | |||||
| /// | /// | ||||
| /// Add a node to be optimized immediately again. If you add a new node to the graph, or | /// Add a node to be optimized immediately again. If you add a new node to the graph, or | ||||
| @@ -88,13 +94,13 @@ class BaseNodePass { | |||||
| /// optimized by other passes, call this function. | /// optimized by other passes, call this function. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddImmediateRePassNode(NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||||
| void AddImmediateRePassNode(const NodePtr &node) { nodes_need_re_pass_immediately_.insert(node); } | |||||
| /// | /// | ||||
| /// Add a node and it's input/output data nodes to be optimized again. | /// Add a node and it's input/output data nodes to be optimized again. | ||||
| /// @param node | /// @param node | ||||
| /// | /// | ||||
| void AddRePassNodesWithInOut(NodePtr &node) { | |||||
| void AddRePassNodesWithInOut(const NodePtr &node) { | |||||
| AddRePassNode(node); | AddRePassNode(node); | ||||
| auto out_nodes = node->GetOutNodes(); | auto out_nodes = node->GetOutNodes(); | ||||
| for (auto &out_node : out_nodes) { | for (auto &out_node : out_nodes) { | ||||
| @@ -116,12 +122,34 @@ class BaseNodePass { | |||||
| /// | /// | ||||
| void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | void AddNodeDeleted(const NodePtr &node) { nodes_deleted_.insert(node); } | ||||
| /// | |||||
| /// If you suspend a node from the graph, especially following node. The remain | |||||
| /// iterate passes will stop process on the suspend node(if it can be | |||||
| /// reached by edge connections) till the last one. Obviously it is a waste of | |||||
| /// time. You can add the suspend nodes by calling this function, to stop the | |||||
| /// next iterations. | |||||
| /// @param node | |||||
| /// | |||||
| void AddNodeSuspend(const NodePtr &node) { nodes_suspend_.insert(node); } | |||||
| /// | |||||
| /// If you resume a node from the graph, especially following node. The remain | |||||
| /// iterate passes will continue process on the resume node(if it can be | |||||
| /// reached by edge connections) till the last one. | |||||
| /// You can add the resume nodes by calling this function, to resume the | |||||
| /// next iterations. | |||||
| /// @param node | |||||
| /// | |||||
| void AddNodeResume(const NodePtr &node) { nodes_resume_.insert(node); } | |||||
| bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | bool OptionExists(NodePassOption option) { return options_.count(option) > 0; } | ||||
| private: | private: | ||||
| std::unordered_set<NodePtr> nodes_need_re_pass_; | std::unordered_set<NodePtr> nodes_need_re_pass_; | ||||
| std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | std::unordered_set<NodePtr> nodes_need_re_pass_immediately_; | ||||
| std::unordered_set<NodePtr> nodes_deleted_; | std::unordered_set<NodePtr> nodes_deleted_; | ||||
| std::unordered_set<NodePtr> nodes_suspend_; | |||||
| std::unordered_set<NodePtr> nodes_resume_; | |||||
| std::map<NodePassOption, std::string> options_; | std::map<NodePassOption, std::string> options_; | ||||
| }; | }; | ||||
| @@ -21,6 +21,8 @@ | |||||
| #include "framework/common/util.h" | #include "framework/common/util.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/common/omg_util.h" | |||||
| #include "graph/debug/ge_attr_define.h" | #include "graph/debug/ge_attr_define.h" | ||||
| #include "utils/tensor_utils.h" | #include "utils/tensor_utils.h" | ||||
| #include "utils/type_utils.h" | #include "utils/type_utils.h" | ||||
| @@ -117,7 +119,9 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { | const auto RePassNode = [&](const std::set<std::string> &re_pass_types) { | ||||
| for (auto &n : node->GetOutDataNodes()) { | for (auto &n : node->GetOutDataNodes()) { | ||||
| GE_CHECK_NOTNULL(n); | GE_CHECK_NOTNULL(n); | ||||
| if (re_pass_types.count(n->GetType()) > 0) { | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed."); | |||||
| if (re_pass_types.count(node_type) > 0) { | |||||
| AddImmediateRePassNode(n); | AddImmediateRePassNode(n); | ||||
| (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); | (void)AttrUtils::SetBool(n->GetOpDesc(), ATTR_NAME_NEED_INFER_AGAIN, false); | ||||
| GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); | GELOGD("Node %s need repass immediately after %s.", n->GetName().c_str(), node->GetName().c_str()); | ||||
| @@ -126,17 +130,44 @@ Status InferShapePass::RePassLoopNode(const NodePtr &node) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| }; | }; | ||||
| if (node->GetType() == NEXTITERATION || node->GetType() == REFNEXTITERATION) { | |||||
| return RePassNode({MERGE, REFMERGE}); // Re-Pass Merge | |||||
| const auto ExProcNode = [&](const std::set<std::string> &proc_types, | |||||
| const std::function<void(InferShapePass *, NodePtr)> &proc_func, | |||||
| const std::string &info) { | |||||
| for (auto &n : node->GetOutDataNodes()) { | |||||
| GE_CHECK_NOTNULL(n); | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(n, node_type), "Get original node type failed."); | |||||
| if (proc_types.count(node_type) > 0) { | |||||
| proc_func(this, n); | |||||
| GELOGD("Node %s %s after %s.", n->GetName().c_str(), info.c_str(), node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| }; | |||||
| std::string node_type; | |||||
| GE_CHK_STATUS_RET(GetOriginalType(node, node_type), "Get original node type failed."); | |||||
| if (kNextIterationOpTypes.count(node_type) > 0) { | |||||
| return RePassNode(kMergeOpTypes); // Re-Pass Merge | |||||
| } | } | ||||
| if (node->GetType() == MERGE || node->GetType() == REFMERGE) { | |||||
| if (kMergeOpTypes.count(node_type) > 0) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | ||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | ||||
| return RePassNode(kSwitchOpTypes); // Re-Pass Switch | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (kSwitchOpTypes.count(node_type) > 0) { | |||||
| if (node->GetOpDesc()->HasAttr(ATTR_NAME_NEED_INFER_AGAIN)) { | |||||
| node->GetOpDesc()->DelAttr(ATTR_NAME_NEED_INFER_AGAIN); | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeResume, "need resume"); // Resume Exit | |||||
| } else { | |||||
| return ExProcNode(kExitOpTypes, &InferShapePass::AddNodeSuspend, "need suspend"); // Suspend Exit | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -1,37 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/isolated_op_remove_pass.h" | |||||
| #include "common/debug/log.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | |||||
| namespace ge { | |||||
| Status IsolatedOpRemovePass::Run(ge::ComputeGraphPtr graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| for (NodePtr &node_ptr : graph->GetDirectNode()) { | |||||
| GE_IF_BOOL_EXEC(node_ptr->GetOpDesc() == nullptr, continue); | |||||
| if (node_ptr->GetInDataNodes().size() == 0 && node_ptr->GetOutAllNodes().size() == 0 && | |||||
| !(node_ptr->GetOpDesc()->HasAttr(TO_BE_OUTPUT))) { | |||||
| GE_RETURN_WITH_LOG_IF_ERROR(graph->RemoveNode(node_ptr), "remove graph node [%s] fail", | |||||
| node_ptr->GetOpDesc()->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,28 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||||
| #include "inc/graph_pass.h" | |||||
| namespace ge { | |||||
| class IsolatedOpRemovePass : public GraphPass { | |||||
| public: | |||||
| Status Run(ge::ComputeGraphPtr graph); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_ISOLATED_OP_REMOVE_PASS_H_ | |||||
| @@ -18,20 +18,25 @@ | |||||
| #include <queue> | #include <queue> | ||||
| #include "graph/utils/node_utils.h" | |||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | namespace { | ||||
| const std::set<std::string> kMergeOpTypes{ MERGE, REFMERGE }; | |||||
| inline bool IsMergeInLoop(const NodePtr &node) { | |||||
| const static std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
| const std::set<std::string> kSwitchOpTypes{ SWITCH, REFSWITCH }; | |||||
| std::string node_type; | |||||
| (void)GetOriginalType(node, node_type); | |||||
| return kLoopMergeInputs.count(node_type) > 0; | |||||
| } | |||||
| const std::set<std::string> kLoopMergeInputs{ ENTER, REFENTER, NEXTITERATION, REFNEXTITERATION }; | |||||
| inline bool IsSwitchInLoop(const NodePtr &node) { | |||||
| const static std::set<std::string> kLoopSwitchInputs{ MERGE, REFMERGE, LOOPCOND }; | |||||
| inline bool IsMergeInLoop(const NodePtr &node) { | |||||
| std::string node_type; | std::string node_type; | ||||
| (void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
| return kLoopMergeInputs.count(node_type) > 0; | |||||
| return kLoopSwitchInputs.count(node_type) > 0; | |||||
| } | } | ||||
| } | } | ||||
| @@ -103,7 +108,13 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
| if (dst_span > 0) { | if (dst_span > 0) { | ||||
| search_queue.push({in_node, dst_span - 1}); | search_queue.push({in_node, dst_span - 1}); | ||||
| } else { | } else { | ||||
| switch_group.emplace_back(in_node); | |||||
| const auto &all_in_nodes = in_node->GetInDataNodes(); | |||||
| if (std::any_of(all_in_nodes.begin(), all_in_nodes.end(), IsSwitchInLoop)) { | |||||
| GELOGW("Travel node: %s, %s node: %s, Skip LoopCond switch", dst_node->GetName().c_str(), node_type.c_str(), | |||||
| in_node->GetName().c_str()); | |||||
| } else { | |||||
| switch_group.emplace_back(in_node); | |||||
| } | |||||
| } | } | ||||
| } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | } else if (kMergeOpTypes.count(node_type) > 0) { // Merge input node. | ||||
| search_queue.push({in_node, dst_span + 1}); | search_queue.push({in_node, dst_span + 1}); | ||||
| @@ -121,19 +132,37 @@ void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const NodePtr &node, std: | |||||
| /// | /// | ||||
| void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | void MarkForceUnknownForCondPass::MarkUnknownForSwitch(const std::map<NodePtr, std::vector<NodePtr>> &switch_groups) { | ||||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | ||||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP); | |||||
| }; | }; | ||||
| for (const auto &group : switch_groups) { | |||||
| const auto &node = group.first; | |||||
| const auto &switch_group = group.second; | |||||
| const auto &op_desc = node->GetOpDesc(); | |||||
| if (IsUnknownShapeTensor(op_desc->GetOutputDesc(0)) || op_desc->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE) || | |||||
| std::any_of(switch_group.begin(), switch_group.end(), callback)) { | |||||
| GELOGI("Mark [%s] as force unknown shape", node->GetName().c_str()); | |||||
| MarkForceUnknownShape(node, true); | |||||
| for (const auto &n : switch_group) { | |||||
| MarkForceUnknownShape(n, true); | |||||
| for (auto it1 = switch_groups.begin(); it1 != switch_groups.end(); ++it1) { | |||||
| const auto &op_node1 = it1->first; | |||||
| const auto &op_desc1 = op_node1->GetOpDesc(); | |||||
| if (op_desc1->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
| continue; | |||||
| } | |||||
| if (IsUnknownShapeTensor(op_desc1->GetOutputDesc(0))) { | |||||
| int64_t group_index = op_desc1->GetId(); | |||||
| GELOGI("Mark %s as unknown shape control flow, group index: %ld", op_desc1->GetName().c_str(), group_index); | |||||
| MarkForceUnknownShape(op_node1, true, group_index); | |||||
| for (const auto &n : it1->second) { | |||||
| MarkForceUnknownShape(n, true, group_index); | |||||
| } | |||||
| for (auto it2 = switch_groups.begin(); it2 != switch_groups.end(); ++it2) { | |||||
| const auto &op_node2 = it2->first; | |||||
| const auto &op_desc2 = op_node2->GetOpDesc(); | |||||
| if (op_desc2->HasAttr(ATTR_NAME_CONTROL_FLOW_GROUP)) { | |||||
| continue; | |||||
| } | |||||
| if (std::any_of(it2->second.begin(), it2->second.end(), callback)) { | |||||
| MarkForceUnknownShape(op_node2, true, group_index); | |||||
| for (const auto &n : it2->second) { | |||||
| MarkForceUnknownShape(n, true, group_index); | |||||
| } | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -25,15 +25,15 @@ | |||||
| namespace ge { | namespace ge { | ||||
| Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | Status MemcpyAddrAsyncPass::Run(ComputeGraphPtr graph) { | ||||
| GE_CHECK_NOTNULL(graph); | GE_CHECK_NOTNULL(graph); | ||||
| for (const auto &node : graph->GetAllNodes()) { | |||||
| if (node->GetType() == STREAMSWITCH) { | |||||
| auto sub_graph = node->GetOwnerComputeGraph(); | |||||
| if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { | |||||
| GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); | |||||
| if (graph->GetGraphUnknownFlag()) { | |||||
| for (const auto &node : graph->GetAllNodes()) { | |||||
| if (node->GetType() == STREAMSWITCH) { | |||||
| auto sub_graph = node->GetOwnerComputeGraph(); | |||||
| if (sub_graph != nullptr && !sub_graph->GetGraphUnknownFlag()) { | |||||
| GE_CHK_STATUS_RET(AddMemcpyAsyncNode(node), "Add memcpyasync node failed in known subgraph."); | |||||
| } | |||||
| } | } | ||||
| } | } | ||||
| } | |||||
| if (graph->GetGraphUnknownFlag()) { | |||||
| GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | GELOGD("Graph[%s] is unknown graph, skip.", graph->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -84,8 +84,9 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, | GE_CHK_BOOL_EXEC(node != nullptr, | ||||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | ||||
| return FAILED, "Param of pre node is null."); | return FAILED, "Param of pre node is null."); | ||||
| bool force_unknown = node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
| MarkForceUnknownShape(node, force_unknown); | |||||
| int64_t group_index = -1; | |||||
| bool force_unknown = AttrUtils::GetInt(node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| MarkForceUnknownShape(node, force_unknown, group_index); | |||||
| for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | for (const InDataAnchorPtr &in_data_anchor : node->GetAllInDataAnchors()) { | ||||
| OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | OutDataAnchorPtr peer_out_anchor = in_data_anchor->GetPeerOutAnchor(); | ||||
| GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_out_anchor == nullptr, continue); | ||||
| @@ -102,7 +103,7 @@ Status MergeToStreamMergePass::AddActiveNodes(const ComputeGraphPtr &graph, cons | |||||
| GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); | GELOGE(FAILED, "SetActiveLabelList for node %s failed.", active_node->GetName().c_str()); | ||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| MarkForceUnknownShape(active_node, force_unknown); | |||||
| MarkForceUnknownShape(active_node, force_unknown, group_index); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -18,6 +18,7 @@ | |||||
| #include "common/ge/ge_util.h" | #include "common/ge/ge_util.h" | ||||
| #include "graph/common/omg_util.h" | #include "graph/common/omg_util.h" | ||||
| #include "graph/utils/node_utils.h" | |||||
| using std::string; | using std::string; | ||||
| @@ -203,6 +204,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| for (const auto &loop_cond_iter : loop_group_map_) { | for (const auto &loop_cond_iter : loop_group_map_) { | ||||
| const LoopCondGroup &loop_group = *loop_cond_iter.second; | const LoopCondGroup &loop_group = *loop_cond_iter.second; | ||||
| const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | const std::string &cond_name = loop_cond_iter.second->loop_cond->GetName(); | ||||
| const int64_t group_index = loop_group.loop_cond->GetOpDesc()->GetId(); | |||||
| GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | GELOGI("Handle while group, LoopCond node: %s.", cond_name.c_str()); | ||||
| // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | // Create Active node, Enter->Active->Merge, NextIteration->Active->Merge | ||||
| @@ -223,7 +225,7 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| enter_active->GetName().c_str()); | enter_active->GetName().c_str()); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape); | |||||
| MarkForceUnknownShape(enter_node, loop_group.is_unknown_shape, group_index); | |||||
| } | } | ||||
| for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | for (const auto &pair : loop_cond_iter.second->merge_next_pairs) { | ||||
| @@ -253,8 +255,8 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| MarkForceUnknownShape(next_node, loop_group.is_unknown_shape); | |||||
| MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape); | |||||
| MarkForceUnknownShape(next_node, loop_group.is_unknown_shape, group_index); | |||||
| MarkForceUnknownShape(merge_node, loop_group.is_unknown_shape, group_index); | |||||
| } | } | ||||
| if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | if ((SetActiveLabelList(enter_active, {cond_name}) != SUCCESS) || | ||||
| @@ -263,10 +265,10 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| } | } | ||||
| MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape); | |||||
| MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape); | |||||
| MarkForceUnknownShape(next_active, loop_group.is_unknown_shape); | |||||
| HandleSwitchExitNodes(loop_group); | |||||
| MarkForceUnknownShape(loop_group.loop_cond, loop_group.is_unknown_shape, group_index); | |||||
| MarkForceUnknownShape(enter_active, loop_group.is_unknown_shape, group_index); | |||||
| MarkForceUnknownShape(next_active, loop_group.is_unknown_shape, group_index); | |||||
| HandleSwitchExitNodes(loop_group, group_index); | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -275,20 +277,21 @@ Status NextIterationPass::HandleWhileGroup(ComputeGraphPtr &graph) { | |||||
| /// | /// | ||||
| /// @brief Mark force unknown for Exit node | /// @brief Mark force unknown for Exit node | ||||
| /// @param [in] group of LoopCond | /// @param [in] group of LoopCond | ||||
| /// @param [in] index of LoopCond Node | |||||
| /// @return void | /// @return void | ||||
| /// | /// | ||||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group) { | |||||
| void NextIterationPass::HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index) { | |||||
| if (!loop_group.is_unknown_shape) { | if (!loop_group.is_unknown_shape) { | ||||
| return; | return; | ||||
| } | } | ||||
| for (const auto &switch_node : loop_group.switch_nodes) { | for (const auto &switch_node : loop_group.switch_nodes) { | ||||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape); | |||||
| MarkForceUnknownShape(switch_node, loop_group.is_unknown_shape, group_index); | |||||
| for (const auto &node : switch_node->GetOutDataNodes()) { | for (const auto &node : switch_node->GetOutDataNodes()) { | ||||
| std::string node_type; | std::string node_type; | ||||
| (void)GetOriginalType(node, node_type); | (void)GetOriginalType(node, node_type); | ||||
| if (node_type == EXIT || node_type == REFEXIT) { | |||||
| MarkForceUnknownShape(node, loop_group.is_unknown_shape); | |||||
| if (kExitOpTypes.count(node_type) > 0) { | |||||
| MarkForceUnknownShape(node, loop_group.is_unknown_shape, group_index); | |||||
| } | } | ||||
| } | } | ||||
| } | } | ||||
| @@ -96,9 +96,10 @@ class NextIterationPass : public GraphPass { | |||||
| /// | /// | ||||
| /// @brief Mark force unknown for Exit node | /// @brief Mark force unknown for Exit node | ||||
| /// @param [in] group of LoopCond | /// @param [in] group of LoopCond | ||||
| /// @param [in] index of LoopCond Node | |||||
| /// @return void | /// @return void | ||||
| /// | /// | ||||
| void HandleSwitchExitNodes(const LoopCondGroup &loop_group); | |||||
| void HandleSwitchExitNodes(const LoopCondGroup &loop_group, int64_t group_index); | |||||
| // map<frame_name, LoopCondGroup> | // map<frame_name, LoopCondGroup> | ||||
| std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | std::unordered_map<std::string, LoopCondGroupPtr> loop_group_map_; | ||||
| @@ -1,47 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "remove_nodes_pass.h" | |||||
| #include "debug/ge_log.h" | |||||
| #include "inc/framework/common/util.h" | |||||
| #include "inc/graph/utils/node_utils.h" | |||||
| namespace ge { | |||||
| Status RemoveNodesPass::Run(NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto node_type = NodeUtils::GetNodeType(*node); | |||||
| auto type_iter = remove_node_types_to_arg_.find(node_type); | |||||
| if (type_iter != remove_node_types_to_arg_.end()) { | |||||
| GELOGI("Remove node %s by type %s", node->GetName().c_str(), node_type.c_str()); | |||||
| return IsolateAndDeleteNode(node, type_iter->second); | |||||
| } | |||||
| for (const auto &attr_name_to_arg : remove_node_attr_names_to_arg_) { | |||||
| if (AttrUtils::HasAttr(node->GetOpDesc(), attr_name_to_arg.first)) { | |||||
| GELOGI("Remove node %s by attr name %s", node->GetName().c_str(), attr_name_to_arg.first.c_str()); | |||||
| return IsolateAndDeleteNode(node, attr_name_to_arg.second); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| RemoveNodesPass &RemoveNodesPass::AddNodeType(const string &node_type, std::initializer_list<int> arg) { | |||||
| remove_node_types_to_arg_[node_type] = std::move(arg); | |||||
| return *this; | |||||
| } | |||||
| RemoveNodesPass &RemoveNodesPass::AddAttrName(const string &attr_name, std::initializer_list<int> arg) { | |||||
| remove_node_attr_names_to_arg_[attr_name] = std::move(arg); | |||||
| return *this; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,32 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_REMOVE_NODES_PASS_H_ | |||||
| #define GE_REMOVE_NODES_PASS_H_ | |||||
| #include "graph/passes/base_pass.h" | |||||
| namespace ge { | |||||
| class RemoveNodesPass : public BaseNodePass { | |||||
| public: | |||||
| Status Run(NodePtr &node) override; | |||||
| RemoveNodesPass &AddNodeType(const std::string &node_type, std::initializer_list<int> arg = {0}); | |||||
| RemoveNodesPass &AddAttrName(const std::string &attr_name, std::initializer_list<int> arg = {0}); | |||||
| private: | |||||
| std::map<std::string, std::initializer_list<int>> remove_node_types_to_arg_; | |||||
| std::map<std::string, std::initializer_list<int>> remove_node_attr_names_to_arg_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif //GE_REMOVE_NODES_PASS_H_ | |||||
| @@ -464,8 +464,8 @@ Status SubgraphPass::InsertMemcpyNode(const ComputeGraphPtr &graph, const OutDat | |||||
| GE_CHECK_NOTNULL(out_anchor); | GE_CHECK_NOTNULL(out_anchor); | ||||
| NodePtr in_node = out_anchor->GetOwnerNode(); | NodePtr in_node = out_anchor->GetOwnerNode(); | ||||
| OpDescBuilder op_desc_builder(name, IDENTITY); | OpDescBuilder op_desc_builder(name, IDENTITY); | ||||
| OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(0)) | |||||
| .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(0)) | |||||
| OpDescPtr op_desc = op_desc_builder.AddInput("x", in_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx())) | |||||
| .AddOutput("y", in_node->GetOpDesc()->GetOutputDesc(out_anchor->GetIdx())) | |||||
| .Build(); | .Build(); | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | (void)AttrUtils::SetBool(op_desc, ATTR_NO_NEED_CONSTANT_FOLDING, false); | ||||
| (void)AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true); | (void)AttrUtils::SetBool(op_desc, ATTR_NAME_CANNOT_BE_DELETED, true); | ||||
| @@ -369,7 +369,9 @@ NodePtr SwitchToStreamSwitchPass::CreateStreamSwitchNode(const ComputeGraphPtr & | |||||
| GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::AddEdge(peer_cond_anchor, stream_switch->GetInDataAnchor(0)), | ||||
| "StreamSwitch node add cond edge failed."); | "StreamSwitch node add cond edge failed."); | ||||
| MarkForceUnknownShape(stream_switch, switch_node->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE)); | |||||
| int64_t group_index = -1; | |||||
| bool force_unknown = AttrUtils::GetInt(switch_node->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| MarkForceUnknownShape(stream_switch, force_unknown, group_index); | |||||
| return stream_switch; | return stream_switch; | ||||
| } | } | ||||
| @@ -488,11 +490,12 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
| return FAILED; | return FAILED; | ||||
| } | } | ||||
| std::function<bool(const NodePtr &)> callback = [](const NodePtr &n) { | |||||
| return n->GetOpDesc()->HasAttr(ATTR_NAME_FORCE_UNKNOWN_SHAPE); | |||||
| int64_t group_index = -1; | |||||
| std::function<bool(const NodePtr &)> callback = [&group_index](const NodePtr &n) { | |||||
| return AttrUtils::GetInt(n->GetOpDesc(), ATTR_NAME_CONTROL_FLOW_GROUP, group_index); | |||||
| }; | }; | ||||
| bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | bool is_unknown_shape = std::any_of(same_cond_switch.begin(), same_cond_switch.end(), callback); | ||||
| MarkForceUnknownShape(active_node, is_unknown_shape); | |||||
| MarkForceUnknownShape(active_node, is_unknown_shape, group_index); | |||||
| const std::string &cond_group = cond_node->GetName(); | const std::string &cond_group = cond_node->GetName(); | ||||
| for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | for (uint32_t i = 0; i < SWITCH_OUTPUT_NUM; ++i) { | ||||
| @@ -522,7 +525,7 @@ Status SwitchToStreamSwitchPass::CombineSwitchNode(const ComputeGraphPtr &graph) | |||||
| GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::AddEdge(cast_node->GetOutDataAnchor(0), stream_switch->GetInDataAnchor(0)), | ||||
| "Cast add data edge failed."); | "Cast add data edge failed."); | ||||
| MarkForceUnknownShape(stream_switch, is_unknown_shape); | |||||
| MarkForceUnknownShape(stream_switch, is_unknown_shape, group_index); | |||||
| for (const NodePtr &node : switch_list) { | for (const NodePtr &node : switch_list) { | ||||
| GE_IF_BOOL_EXEC(node != stream_switch, { | GE_IF_BOOL_EXEC(node != stream_switch, { | ||||
| GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | GE_CHK_STATUS(GraphUtils::RemoveEdge(peer_cond_anchor, node->GetInDataAnchor(0)), | ||||
| @@ -1,134 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/unused_op_remove_pass.h" | |||||
| #include <queue> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "common/debug/log.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "common/types.h" | |||||
| #include "common/util.h" | |||||
| #include "graph/utils/attr_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "inc/pass_manager.h" | |||||
| #include "graph/passes/isolated_op_remove_pass.h" | |||||
| using domi::SUCCESS; | |||||
| namespace ge { | |||||
| const std::set<std::string> kRemoveOpSet = {DROPOUT, PERMUTE, UNUSEDCONST, ASSERT}; | |||||
| const std::set<std::string> kOtherRemoveOpSet = {DROPOUT}; | |||||
| Status UnusedOpRemovePass::Run(ComputeGraphPtr graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| std::set<std::string> remove_op_set; | |||||
| vector<NodePtr> nodes_to_be_deleted; | |||||
| if (fmktype_ == TENSORFLOW) { | |||||
| remove_op_set = kRemoveOpSet; | |||||
| } else { | |||||
| remove_op_set = kOtherRemoveOpSet; | |||||
| } | |||||
| for (auto &node : graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node->GetOpDesc()); | |||||
| std::string op_type_str = node->GetOpDesc()->GetType(); | |||||
| if (remove_op_set.count(op_type_str)) { | |||||
| if (IsExceptions(node)) { | |||||
| continue; | |||||
| } | |||||
| for (auto &out_anchor : node->GetAllOutDataAnchors()) { | |||||
| for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| NodePtr dst_node = in_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||||
| int dst_index = in_anchor->GetIdx(); | |||||
| std::vector<bool> list_bool; | |||||
| GE_CHECK_NOTNULL(dst_node->GetOpDesc()); | |||||
| list_bool = dst_node->GetOpDesc()->GetIsInputConst(); | |||||
| GE_IF_BOOL_EXEC(list_bool.size() == 0, continue); | |||||
| list_bool.erase(list_bool.begin() + dst_index); | |||||
| dst_node->GetOpDesc()->SetIsInputConst(list_bool); | |||||
| } | |||||
| } | |||||
| if (op_type_str == ASSERT) { | |||||
| GE_CHK_STATUS_RET(CollectParentNode(graph, node, nodes_to_be_deleted), "remove node failed"); | |||||
| } else { | |||||
| GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node failed"); | |||||
| } | |||||
| } | |||||
| } | |||||
| for (auto &node : nodes_to_be_deleted) { | |||||
| for (InDataAnchorPtr &inAnchor : node->GetAllInDataAnchors()) { | |||||
| inAnchor->UnlinkAll(); | |||||
| } | |||||
| for (OutDataAnchorPtr &outAnchorPtr : node->GetAllOutDataAnchors()) { | |||||
| outAnchorPtr->UnlinkAll(); | |||||
| } | |||||
| if (node->GetOutControlAnchor() != nullptr) { | |||||
| node->GetOutControlAnchor()->UnlinkAll(); | |||||
| } | |||||
| GE_CHK_STATUS_RET(graph->RemoveNode(node), "remove node:%s failed", node->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status UnusedOpRemovePass::CollectParentNode(const ComputeGraphPtr &graph, const NodePtr &node, | |||||
| vector<NodePtr> &node_vec) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| GE_CHECK_NOTNULL(node); | |||||
| node_vec.push_back(node); | |||||
| std::queue<NodePtr> node_queue; | |||||
| for (auto &src_node : node->GetInDataNodes()) { | |||||
| if (src_node->GetOutDataNodesSize() == 1) { | |||||
| node_queue.push(src_node); | |||||
| } | |||||
| } | |||||
| while (!node_queue.empty()) { | |||||
| NodePtr temp = node_queue.front(); | |||||
| node_queue.pop(); | |||||
| for (auto &src_node : temp->GetInDataNodes()) { | |||||
| if (src_node->GetOutDataNodesSize() == 1) { | |||||
| node_queue.push(src_node); | |||||
| } | |||||
| } | |||||
| node_vec.push_back(temp); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool UnusedOpRemovePass::IsExceptions(const NodePtr &node) { | |||||
| GE_CHK_BOOL_EXEC(node != nullptr, return false, "node is nullptr"); | |||||
| auto op_def = node->GetOpDesc(); | |||||
| GE_CHK_BOOL_EXEC(op_def != nullptr, return false, "opdesc is nullptr"); | |||||
| // permute optimised in permute_pass.cpp | |||||
| if (op_def->GetType() == PERMUTE) { | |||||
| GE_IF_BOOL_EXEC( | |||||
| (node->GetInDataNodes().size() != 0 && | |||||
| (node->GetInDataNodes().at(0) != nullptr && node->GetInDataNodes().at(0)->GetOpDesc() != nullptr && | |||||
| node->GetInDataNodes().at(0)->GetOpDesc()->GetType() == ATTENTIONDECODER)), | |||||
| return false); | |||||
| return true; | |||||
| } | |||||
| return false; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,41 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||||
| #include <string> | |||||
| #include <vector> | |||||
| #include "framework/common/ge_types.h" | |||||
| #include "inc/graph_pass.h" | |||||
| namespace ge { | |||||
| class UnusedOpRemovePass : public GraphPass { | |||||
| public: | |||||
| explicit UnusedOpRemovePass(FrameworkType type) : fmktype_(type) {} | |||||
| ~UnusedOpRemovePass() {} | |||||
| Status Run(ge::ComputeGraphPtr graph) override; | |||||
| bool IsExceptions(const ge::NodePtr &node); | |||||
| private: | |||||
| Status CollectParentNode(const ge::ComputeGraphPtr &graph, const ge::NodePtr &node, | |||||
| std::vector<ge::NodePtr> &node_vec); | |||||
| std::vector<std::string> v_remove_ops; | |||||
| FrameworkType fmktype_; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_UNUSED_OP_REMOVE_PASS_H_ | |||||
| @@ -1,119 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "graph/passes/variable_format_pass.h" | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| namespace ge { | |||||
| Status VariableFormatPass::Run(ge::ComputeGraphPtr graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| for (auto &node : graph->GetDirectNode()) { | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc() == nullptr, continue); | |||||
| GE_IF_BOOL_EXEC(node->GetOpDesc()->GetType() != VARIABLE, continue); | |||||
| ge::NodePtr use_node = nullptr; | |||||
| if (GetApplyMomentumOpByVariableInput(node, use_node)) { | |||||
| GE_CHK_STATUS_RET(UpdateVariableOutFormat(node, use_node), "update variable out format failed"); | |||||
| GE_CHK_STATUS_RET(UpdateApplyMomentumInputFormat(use_node), "update apply momentum input format failed"); | |||||
| } | |||||
| } | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| bool VariableFormatPass::GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||||
| GE_IF_BOOL_EXEC(var_node == nullptr, return false); | |||||
| std::map<std::string, std::set<int>> confirm_ops = {{"ApplyMomentum", {1}}}; | |||||
| for (auto &out_anchor : var_node->GetAllOutDataAnchors()) { | |||||
| for (auto &in_anchor : out_anchor->GetPeerInDataAnchors()) { | |||||
| GE_IF_BOOL_EXEC(ConfirmUseOpAndIndexByAnchor(in_anchor, confirm_ops, use_node), return true); | |||||
| } | |||||
| } | |||||
| return false; | |||||
| } | |||||
| bool VariableFormatPass::ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||||
| const map<string, std::set<int>> &confirm_ops, | |||||
| ge::NodePtr &use_node) { | |||||
| GE_IF_BOOL_EXEC(in_anchor == nullptr, return false); | |||||
| ge::NodePtr dst_node = in_anchor->GetOwnerNode(); | |||||
| ge::OpDescPtr dst_op_desc = dst_node->GetOpDesc(); | |||||
| GE_IF_BOOL_EXEC(dst_op_desc == nullptr, return false); | |||||
| const string &dst_type = dst_op_desc->GetType(); | |||||
| int input_index = in_anchor->GetIdx(); | |||||
| GELOGD("ConfirmUseOpAndIndex, var name %s, dst_type = %s, input index %d", dst_node->GetName().c_str(), | |||||
| dst_type.c_str(), input_index); | |||||
| GE_IF_BOOL_EXEC(confirm_ops.count(dst_type) > 0, | |||||
| GE_IF_BOOL_EXEC(confirm_ops.at(dst_type).count(input_index) > 0, use_node = dst_node; return true);); | |||||
| return false; | |||||
| } | |||||
| Status VariableFormatPass::UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node) { | |||||
| GE_CHECK_NOTNULL(var_node); | |||||
| GE_CHECK_NOTNULL(use_node); | |||||
| ge::OpDescPtr op_desc_ptr = use_node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc_ptr); | |||||
| GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)); | |||||
| GE_CHECK_NOTNULL(use_node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||||
| NodePtr in_node = use_node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| if (in_node != nullptr) { | |||||
| string in_op_type = in_node->GetType(); | |||||
| if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr) && | |||||
| (in_node->GetOpDesc()->MutableOutputDesc(0) != nullptr)) { | |||||
| ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||||
| ge::OpDescPtr cur_op_desc_ptr = var_node->GetOpDesc(); | |||||
| if (cur_op_desc_ptr != nullptr) { | |||||
| cur_op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||||
| cur_op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||||
| } | |||||
| } | |||||
| } | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| Status VariableFormatPass::UpdateApplyMomentumInputFormat(const ge::NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| ge::OpDescPtr op_desc_ptr = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc_ptr); | |||||
| GE_CHECK_NOTNULL(node->GetInDataAnchor(0)); | |||||
| GE_CHECK_NOTNULL(node->GetInDataAnchor(0)->GetPeerOutAnchor()); | |||||
| GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(0)); | |||||
| GE_CHECK_NOTNULL(op_desc_ptr->MutableInputDesc(1)); | |||||
| GE_CHECK_NOTNULL(op_desc_ptr->MutableOutputDesc(0)); | |||||
| NodePtr in_node = node->GetInDataAnchor(0)->GetPeerOutAnchor()->GetOwnerNode(); | |||||
| if (in_node != nullptr) { | |||||
| string in_op_type = in_node->GetType(); | |||||
| if ((in_op_type == VARIABLE) && (in_node->GetOpDesc() != nullptr)) { | |||||
| ge::Format format = in_node->GetOpDesc()->MutableOutputDesc(0)->GetFormat(); | |||||
| op_desc_ptr->MutableInputDesc(0)->SetFormat(format); | |||||
| op_desc_ptr->MutableInputDesc(0)->SetOriginFormat(format); | |||||
| op_desc_ptr->MutableInputDesc(1)->SetFormat(format); | |||||
| op_desc_ptr->MutableInputDesc(1)->SetOriginFormat(format); | |||||
| op_desc_ptr->MutableOutputDesc(0)->SetFormat(format); | |||||
| op_desc_ptr->MutableOutputDesc(0)->SetOriginFormat(format); | |||||
| } | |||||
| } | |||||
| return domi::SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -1,44 +0,0 @@ | |||||
| /** | |||||
| * Copyright 2020 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||||
| #define GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||||
| #include <map> | |||||
| #include <set> | |||||
| #include <string> | |||||
| #include "graph/types.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "inc/graph_pass.h" | |||||
| namespace ge { | |||||
| class VariableFormatPass : public GraphPass { | |||||
| public: | |||||
| Status Run(ge::ComputeGraphPtr graph) override; | |||||
| private: | |||||
| bool GetApplyMomentumOpByVariableInput(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||||
| bool ConfirmUseOpAndIndexByAnchor(const ge::InDataAnchorPtr &in_anchor, | |||||
| const map<string, std::set<int> > &confirm_ops, ge::NodePtr &use_node); | |||||
| Status UpdateApplyMomentumInputFormat(const ge::NodePtr &node); | |||||
| Status UpdateVariableOutFormat(const ge::NodePtr &var_node, ge::NodePtr &use_node); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // GE_GRAPH_PASSES_VARIABLE_FORMAT_PASS_H_ | |||||
| @@ -74,6 +74,7 @@ | |||||
| #include "graph/passes/unused_const_pass.h" | #include "graph/passes/unused_const_pass.h" | ||||
| #include "graph/passes/var_is_initialized_op_pass.h" | #include "graph/passes/var_is_initialized_op_pass.h" | ||||
| #include "graph/passes/variable_prepare_op_pass.h" | #include "graph/passes/variable_prepare_op_pass.h" | ||||
| #include "graph/passes/mark_force_unknown_for_cond_pass.h" | |||||
| #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | #include "graph/preprocess/insert_op/util_insert_aipp_op.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| @@ -1675,6 +1676,7 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||||
| PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | PP_RUN_AND_DUMP("InsertAipp", TryDoAipp); | ||||
| PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | PP_RUN_AND_DUMP("ProcessBeforeInfershape", ProcessBeforeInfershape); | ||||
| PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); | PP_RUN_AND_DUMP("InferFormatAndShape", FormatAndShapeProcess); | ||||
| PP_RUN_AND_DUMP("CtrlFlowPreProcess", CtrlFlowPreProcess); | |||||
| PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); | PP_RUN_AND_DUMP("GetDynamicOutputShape", multibatch::GetDynamicOutputShape, compute_graph_); | ||||
| PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); | PP_RUN_AND_DUMP("ProcessAippStage2", InsertNewOpUtil::Instance().UpdateDataNodeByAipp, compute_graph_); | ||||
| PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); | PP_RUN("SaveOriginalGraphToOmModel", SaveOriginalGraphToOmModel); | ||||
| @@ -1683,6 +1685,17 @@ Status GraphPrepare::PrepareDynShape(const GraphNodePtr &graph_node, const std:: | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphPrepare::CtrlFlowPreProcess() { | |||||
| PassManager graph_pass; | |||||
| // After InferShape Mark v1 control flow for unknown shape. | |||||
| auto mark_force_unknown_pass = new (std::nothrow) MarkForceUnknownForCondPass; | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("PreRun::MarkForceUnknownForCondPass", mark_force_unknown_pass)); | |||||
| GE_CHK_STATUS_RET(graph_pass.Run(compute_graph_)); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { | Status GraphPrepare::RecordAIPPInfo(ge::ComputeGraphPtr &compute_graph) { | ||||
| PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); | PP_RUN("RecordAIPPInfo", InsertNewOpUtil::Instance().RecordAIPPInfoToData, compute_graph_); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -79,6 +79,7 @@ class GraphPrepare { | |||||
| Status ProcessNetOutput(); | Status ProcessNetOutput(); | ||||
| Status ProcessBeforeInfershape(); | Status ProcessBeforeInfershape(); | ||||
| Status UpdateInputOutputByOptions(); | Status UpdateInputOutputByOptions(); | ||||
| Status CtrlFlowPreProcess(); | |||||
| bool IsTansDataOpData(const ge::NodePtr &var_node); | bool IsTansDataOpData(const ge::NodePtr &var_node); | ||||
| @@ -335,9 +335,9 @@ Status DeleteIdentityInsertByAdapter(ComputeGraphPtr &graph) { | |||||
| GE_IF_BOOL_EXEC(peer_in_anchor == nullptr, continue); | GE_IF_BOOL_EXEC(peer_in_anchor == nullptr, continue); | ||||
| auto dst_node = peer_in_anchor->GetOwnerNode(); | auto dst_node = peer_in_anchor->GetOwnerNode(); | ||||
| GE_IF_BOOL_EXEC(dst_node == nullptr, continue); | GE_IF_BOOL_EXEC(dst_node == nullptr, continue); | ||||
| if (dst_node->GetType() == IDENTITY) { | |||||
| if (dst_node->GetType() == IDENTITY && dst_node->GetAllOutDataAnchors().empty()) { | |||||
| GELOGI("Need to remove %s.", dst_node->GetName().c_str()); | GELOGI("Need to remove %s.", dst_node->GetName().c_str()); | ||||
| if (ge::GraphUtils::RemoveNodeWithoutRelink(graph, dst_node) != GRAPH_SUCCESS) { | |||||
| if (GraphUtils::RemoveNodeWithoutRelink(graph, dst_node) != GRAPH_SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) from graph:%s failed", | REPORT_CALL_ERROR("E19999", "Remove node:%s(%s) from graph:%s failed", | ||||
| dst_node->GetName().c_str(), dst_node->GetType().c_str(), graph->GetName().c_str()); | dst_node->GetName().c_str(), dst_node->GetType().c_str(), graph->GetName().c_str()); | ||||
| GELOGE(FAILED, "Remove Identity node %s failed.", dst_node->GetName().c_str()); | GELOGE(FAILED, "Remove Identity node %s failed.", dst_node->GetName().c_str()); | ||||
| @@ -17,10 +17,7 @@ | |||||
| #include "npu_memory_allocator.h" | #include "npu_memory_allocator.h" | ||||
| #include <mutex> | #include <mutex> | ||||
| #include "framework/common/debug/log.h" | #include "framework/common/debug/log.h" | ||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/rdma_pool_allocator.h" | |||||
| #include "graph/manager/host_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -40,6 +40,12 @@ class TensorBuffer { | |||||
| TensorBuffer &operator = (const TensorBuffer &) = delete; | TensorBuffer &operator = (const TensorBuffer &) = delete; | ||||
| ~TensorBuffer(); | ~TensorBuffer(); | ||||
| void* Release() { | |||||
| auto ret = buffer_; | |||||
| buffer_ = nullptr; | |||||
| return ret; | |||||
| } | |||||
| void *GetData() { | void *GetData() { | ||||
| return buffer_; | return buffer_; | ||||
| } | } | ||||
| @@ -48,6 +54,10 @@ class TensorBuffer { | |||||
| return size_; | return size_; | ||||
| } | } | ||||
| MemStorageType GetMemType() const { | |||||
| return mem_type_; | |||||
| } | |||||
| private: | private: | ||||
| TensorBuffer(NpuMemoryAllocator *allocator, void *buffer, size_t size, MemStorageType mem_type = HBM); | TensorBuffer(NpuMemoryAllocator *allocator, void *buffer, size_t size, MemStorageType mem_type = HBM); | ||||
| @@ -69,6 +79,10 @@ class TensorValue { | |||||
| void Destroy(); | void Destroy(); | ||||
| void *Release() { | |||||
| return buffer_->Release(); | |||||
| } | |||||
| bool IsEmpty() { | bool IsEmpty() { | ||||
| return ref_buffer_ == nullptr && buffer_ == nullptr; | return ref_buffer_ == nullptr && buffer_ == nullptr; | ||||
| } | } | ||||
| @@ -80,6 +94,10 @@ class TensorValue { | |||||
| void SetName(const std::string &name) { | void SetName(const std::string &name) { | ||||
| name_ = name; | name_ = name; | ||||
| } | } | ||||
| MemStorageType GetMemType() const { | |||||
| return buffer_->GetMemType(); | |||||
| } | |||||
| void *MutableData(); | void *MutableData(); | ||||
| @@ -62,6 +62,7 @@ struct GraphExecutionContext { | |||||
| const HybridModel *model = nullptr; | const HybridModel *model = nullptr; | ||||
| const GEThreadLocalContext *ge_context = nullptr; | const GEThreadLocalContext *ge_context = nullptr; | ||||
| rtStream_t stream = nullptr; | rtStream_t stream = nullptr; | ||||
| rtStream_t hccl_stream = nullptr; | |||||
| rtContext_t rt_context = nullptr; | rtContext_t rt_context = nullptr; | ||||
| rtContext_t rt_gen_context = nullptr; | rtContext_t rt_gen_context = nullptr; | ||||
| std::unique_ptr<CallbackManager> callback_manager = nullptr; | std::unique_ptr<CallbackManager> callback_manager = nullptr; | ||||
| @@ -19,6 +19,13 @@ | |||||
| #include "graph/utils/tensor_utils.h" | #include "graph/utils/tensor_utils.h" | ||||
| #include "graph/utils/type_utils.h" | #include "graph/utils/type_utils.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/types.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "graph/manager/graph_caching_allocator.h" | |||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/rdma_pool_allocator.h" | |||||
| #include "graph/manager/host_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace hybrid { | namespace hybrid { | ||||
| @@ -440,22 +447,31 @@ Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &a | |||||
| GeShape ge_shape(tensor_desc->GetShape().GetDims()); | GeShape ge_shape(tensor_desc->GetShape().GetDims()); | ||||
| GeTensorDesc ge_tensor_desc; | GeTensorDesc ge_tensor_desc; | ||||
| ge_tensor_desc.SetShape(ge_shape); | ge_tensor_desc.SetShape(ge_shape); | ||||
| GeTensor ge_tensor(ge_tensor_desc); | |||||
| if (output_size > 0) { | if (output_size > 0) { | ||||
| auto aligned_ptr = MakeShared<AlignedPtr>(output_size, kAlignment); | |||||
| GE_CHECK_NOTNULL(aligned_ptr); | |||||
| auto data_buf = aligned_ptr->MutableGet(); | |||||
| GE_CHECK_NOTNULL(data_buf); | |||||
| GE_CHK_RT_RET(rtMemcpy(data_buf, output_size, output_tensor.GetData(), output_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| ge_tensor.SetData(aligned_ptr, output_size); | |||||
| output_data->blobs.emplace_back(data_buf, static_cast<uint32_t>(output_size), false); | |||||
| if (execute_mode != kLazyRecompile) { | |||||
| auto aligned_ptr = MakeShared<AlignedPtr>(output_size, kAlignment); | |||||
| GE_CHECK_NOTNULL(aligned_ptr); | |||||
| auto data_buf = aligned_ptr->MutableGet(); | |||||
| GE_CHECK_NOTNULL(data_buf); | |||||
| GE_CHK_RT_RET(rtMemcpy(data_buf, output_size, output_tensor.GetData(), output_size, RT_MEMCPY_DEVICE_TO_HOST)); | |||||
| GeTensor ge_tensor(ge_tensor_desc); | |||||
| ge_tensor.SetData(aligned_ptr, output_size); | |||||
| output_data->blobs.emplace_back(data_buf, static_cast<uint32_t>(output_size), false); | |||||
| auto tensor = TensorAdapter::AsTensor(ge_tensor); | |||||
| outputs.emplace_back(std::move(tensor)); | |||||
| } else { | |||||
| BuildDeviceTensor(output_tensor, ge_tensor_desc, output_size, outputs); | |||||
| output_data->blobs.emplace_back(output_tensor.Release(), static_cast<uint32_t>(output_size), false, | |||||
| static_cast<uint32_t>(kPlacementDevice)); | |||||
| } | |||||
| } else { | } else { | ||||
| GELOGW("Output[%zu] is empty. shape = [%s]", i, tensor_desc->GetShape().ToString().c_str()); | |||||
| GELOGW("Output [%zu] is empty. shape = [%s]", i, tensor_desc->GetShape().ToString().c_str()); | |||||
| GeTensor ge_tensor(ge_tensor_desc); | |||||
| ge_tensor.SetData(nullptr, 0U); | ge_tensor.SetData(nullptr, 0U); | ||||
| output_data->blobs.emplace_back(nullptr, 0U, false); | output_data->blobs.emplace_back(nullptr, 0U, false); | ||||
| auto tensor = TensorAdapter::AsTensor(ge_tensor); | |||||
| outputs.emplace_back(std::move(tensor)); | |||||
| } | } | ||||
| auto tensor = TensorAdapter::AsTensor(ge_tensor); | |||||
| outputs.emplace_back(std::move(tensor)); | |||||
| GELOGD("Output[%zu] added, type = %s, shape = [%s], size = %ld", i, | GELOGD("Output[%zu] added, type = %s, shape = [%s], size = %ld", i, | ||||
| TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), | TypeUtils::DataTypeToSerialString(tensor_desc->GetDataType()).c_str(), | ||||
| tensor_desc->GetShape().ToString().c_str(), output_size); | tensor_desc->GetShape().ToString().c_str(), output_size); | ||||
| @@ -464,6 +480,29 @@ Status HybridModelAsyncExecutor::CopyOutputs(HybridModelExecutor::ExecuteArgs &a | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| void HybridModelAsyncExecutor::BuildDeviceTensor(TensorValue &output_tensor, GeTensorDesc &ge_tensor_desc, | |||||
| int64_t output_size, std::vector<ge::Tensor> &outputs) { | |||||
| GELOGD("Start to build device tensor"); | |||||
| auto mem_type = output_tensor.GetMemType(); | |||||
| GELOGD("Mem type is %d", static_cast<uint32_t>(mem_type)); | |||||
| auto deleter = [=](uint8_t *device_data) { | |||||
| if (device_data != nullptr) { | |||||
| if (mem_type == RDMA_HBM) { | |||||
| MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).Free(device_data, device_id_); | |||||
| } else if (mem_type == HOST_DDR) { | |||||
| MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Free(device_data); | |||||
| } else { | |||||
| MemManager::Instance().CachingInstance(RT_MEMORY_HBM).Free(device_data, device_id_); | |||||
| } | |||||
| } | |||||
| }; | |||||
| ge_tensor_desc.SetPlacement(kPlacementDevice); | |||||
| GeTensor ge_tensor(ge_tensor_desc); | |||||
| auto tensor = TensorAdapter::AsTensor(ge_tensor); | |||||
| tensor.SetData(reinterpret_cast<uint8_t *>(output_tensor.Release()), static_cast<size_t>(output_size), deleter); | |||||
| outputs.emplace_back(std::move(tensor)); | |||||
| } | |||||
| Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs, | Status HybridModelAsyncExecutor::Execute(const std::vector<DataBuffer> &inputs, | ||||
| const std::vector<GeTensorDesc> &input_desc, | const std::vector<GeTensorDesc> &input_desc, | ||||
| std::vector<DataBuffer> &outputs, | std::vector<DataBuffer> &outputs, | ||||
| @@ -75,9 +75,9 @@ class HybridModelAsyncExecutor { | |||||
| HybridModelExecutor::ExecuteArgs &args, | HybridModelExecutor::ExecuteArgs &args, | ||||
| OutputData *output_data); | OutputData *output_data); | ||||
| Status CopyOutputs(HybridModelExecutor::ExecuteArgs &args, | |||||
| OutputData *output_data, | |||||
| std::vector<ge::Tensor> &outputs); | |||||
| Status CopyOutputs(HybridModelExecutor::ExecuteArgs &args, OutputData *output_data, std::vector<ge::Tensor> &outputs); | |||||
| void BuildDeviceTensor(TensorValue &output_tensor, GeTensorDesc &ge_tensor_desc, int64_t output_size, | |||||
| std::vector<ge::Tensor> &outputs); | |||||
| Status OnComputeDone(uint32_t data_index, uint32_t result_code, std::vector<ge::Tensor> &outputs); | Status OnComputeDone(uint32_t data_index, uint32_t result_code, std::vector<ge::Tensor> &outputs); | ||||
| @@ -50,7 +50,7 @@ Status HybridModelExecutor::Execute(HybridModelExecutor::ExecuteArgs &args) { | |||||
| auto root_graph_item = model_->GetRootGraphItem(); | auto root_graph_item = model_->GetRootGraphItem(); | ||||
| GE_CHECK_NOTNULL(root_graph_item); | GE_CHECK_NOTNULL(root_graph_item); | ||||
| if (root_graph_item->IsDynamic()) { | |||||
| if (root_graph_item->IsDynamic() && !model_->IsSingleOp()) { | |||||
| GE_CHK_STATUS_RET(CheckInputShapeByShapeRange(root_graph_item, args), | GE_CHK_STATUS_RET(CheckInputShapeByShapeRange(root_graph_item, args), | ||||
| "[%s] check input node shape by shape range failed.", | "[%s] check input node shape by shape range failed.", | ||||
| root_graph_item->GetName().c_str()); | root_graph_item->GetName().c_str()); | ||||
| @@ -18,14 +18,26 @@ const char *const kEnvProfilingLevel = "HYBRID_PROFILING_LEVEL"; | |||||
| StageExecutor::StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config) | StageExecutor::StageExecutor(int id, HybridModel *model, PipeExecutionConfig *config) | ||||
| : id_(id), model_(model), pipe_config_(config) {} | : id_(id), model_(model), pipe_config_(config) {} | ||||
| StageExecutor::~StageExecutor() { GELOGD("~StageExecutor(), id = %d", id_); } | |||||
| StageExecutor::~StageExecutor() { | |||||
| GELOGD("~StageExecutor(), id = %d", id_); | |||||
| if (stream_ != nullptr) { | |||||
| GE_CHK_RT(rtStreamDestroy(stream_)); | |||||
| stream_ = nullptr; | |||||
| } | |||||
| if (hccl_stream_ != nullptr) { | |||||
| GE_CHK_RT(rtStreamDestroy(hccl_stream_)); | |||||
| hccl_stream_ = nullptr; | |||||
| } | |||||
| } | |||||
| Status StageExecutor::Init() { | Status StageExecutor::Init() { | ||||
| GELOGD("[Executor: %d] Start to init StateExecutor", id_); | GELOGD("[Executor: %d] Start to init StateExecutor", id_); | ||||
| context_.rt_context = pipe_config_->rt_context; | context_.rt_context = pipe_config_->rt_context; | ||||
| GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | ||||
| GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); | GE_CHK_RT_RET(rtStreamCreate(&stream_, RT_STREAM_PRIORITY_DEFAULT)); | ||||
| GE_CHK_RT_RET(rtStreamCreate(&hccl_stream_, RT_STREAM_PRIORITY_DEFAULT)); | |||||
| context_.stream = stream_; | context_.stream = stream_; | ||||
| context_.hccl_stream = hccl_stream_; | |||||
| root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | ||||
| GE_CHECK_NOTNULL(root_graph_executor_); | GE_CHECK_NOTNULL(root_graph_executor_); | ||||
| @@ -78,11 +90,11 @@ Status StageExecutor::Start(const std::vector<TensorValue> &inputs, const std::v | |||||
| if (task_info.event != nullptr) { | if (task_info.event != nullptr) { | ||||
| GELOGD("[%d] Add StreamWaitEvent", id_); | GELOGD("[%d] Add StreamWaitEvent", id_); | ||||
| GE_CHK_RT_RET(rtStreamWaitEvent(stream_, task_info.event)); | GE_CHK_RT_RET(rtStreamWaitEvent(stream_, task_info.event)); | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] [Stage = %d] End", task_info.iteration - 1, | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] [Stage = %d] EventWait End", task_info.iteration, | |||||
| task_info.stage); | task_info.stage); | ||||
| } | } | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %lld] [Stage = %d] Start", task_info.iteration, | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] [Stage = %d] Start", task_info.iteration, | |||||
| task_info.stage); | task_info.stage); | ||||
| if (task_info.stage == 0) { | if (task_info.stage == 0) { | ||||
| @@ -102,6 +114,10 @@ Status StageExecutor::Start(const std::vector<TensorValue> &inputs, const std::v | |||||
| StageTask next_task; | StageTask next_task; | ||||
| next_task.stage = task_info.stage; | next_task.stage = task_info.stage; | ||||
| next_task.iteration = task_info.iteration + 1; | next_task.iteration = task_info.iteration + 1; | ||||
| if ((task_info.iteration + 1) % iteration_count > 0) { | |||||
| GE_CHK_RT_RET(rtEventCreate(&next_task.event)); | |||||
| GE_CHK_RT_RET(rtEventRecord(next_task.event, context_.hccl_stream)); | |||||
| } | |||||
| auto sync_result = Synchronize(); | auto sync_result = Synchronize(); | ||||
| if (sync_result != SUCCESS) { | if (sync_result != SUCCESS) { | ||||
| @@ -110,15 +126,22 @@ Status StageExecutor::Start(const std::vector<TensorValue> &inputs, const std::v | |||||
| id_, sync_result, task_info.iteration); | id_, sync_result, task_info.iteration); | ||||
| REPORT_CALL_ERROR("E19999", "[Executor: %d] Failed to sync result:%d. iteration = %ld", | REPORT_CALL_ERROR("E19999", "[Executor: %d] Failed to sync result:%d. iteration = %ld", | ||||
| id_, sync_result, task_info.iteration); | id_, sync_result, task_info.iteration); | ||||
| context_.profiler->Dump(std::cout); | |||||
| if (context_.profiler != nullptr) { | |||||
| context_.profiler->Dump(std::cout); | |||||
| } | |||||
| context_.callback_manager->Destroy(); | context_.callback_manager->Destroy(); | ||||
| RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | RuntimeInferenceContext::DestroyContext(std::to_string(context_.context_id)); | ||||
| return sync_result; | return sync_result; | ||||
| } | } | ||||
| if (task_info.event != nullptr) { | |||||
| GE_CHK_RT_RET(rtEventDestroy(task_info.event)); | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] [Stage = %d] EventDestroy End", task_info.iteration, | |||||
| task_info.stage); | |||||
| } | |||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] [Stage = %d] End", task_info.iteration, task_info.stage); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] [Stage = %d] End", task_info.iteration, task_info.stage); | ||||
| // if not end stage | |||||
| // if end stage | |||||
| if (task_info.stage >= pipe_config_->num_stages - 1) { | if (task_info.stage >= pipe_config_->num_stages - 1) { | ||||
| RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] Schedule End", task_info.iteration); | RECORD_MODEL_EXECUTION_EVENT(&context_, "[iteration = %ld] Schedule End", task_info.iteration); | ||||
| GELOGD("[Executor: %d] End of iteration [%ld]", id_, task_info.iteration); | GELOGD("[Executor: %d] End of iteration [%ld]", id_, task_info.iteration); | ||||
| @@ -163,6 +186,7 @@ Status StageExecutor::InitExecutionContext() { | |||||
| context_.callback_manager = std::unique_ptr<CallbackManager>(new (std::nothrow) CallbackManager()); | context_.callback_manager = std::unique_ptr<CallbackManager>(new (std::nothrow) CallbackManager()); | ||||
| GE_CHECK_NOTNULL(context_.callback_manager); | GE_CHECK_NOTNULL(context_.callback_manager); | ||||
| context_.dump_properties = DumpManager::GetInstance().GetDumpProperties(context_.session_id); | context_.dump_properties = DumpManager::GetInstance().GetDumpProperties(context_.session_id); | ||||
| context_.is_eos_ = false; | |||||
| if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { | if (IsLogEnable(GE_MODULE_NAME, DLOG_DEBUG)) { | ||||
| context_.trace_enabled = true; | context_.trace_enabled = true; | ||||
| } | } | ||||
| @@ -63,6 +63,7 @@ class StageExecutor { | |||||
| StageExecutor *next_executor_ = nullptr; | StageExecutor *next_executor_ = nullptr; | ||||
| rtStream_t stream_ = nullptr; | rtStream_t stream_ = nullptr; | ||||
| rtStream_t hccl_stream_ = nullptr; | |||||
| }; | }; | ||||
| class HybridModelPipelineExecutor { | class HybridModelPipelineExecutor { | ||||
| @@ -121,5 +121,10 @@ void NodeDoneManager::Reset(const NodePtr &node) { | |||||
| GELOGD("[%s] Node reset.", node->GetName().c_str()); | GELOGD("[%s] Node reset.", node->GetName().c_str()); | ||||
| } | } | ||||
| } | } | ||||
| void NodeDoneManager::Reset() { | |||||
| subjects_.clear(); | |||||
| destroyed_ = false; | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -35,6 +35,8 @@ class NodeDoneManager { | |||||
| void Destroy(); | void Destroy(); | ||||
| void Reset(); | |||||
| private: | private: | ||||
| class Cond { | class Cond { | ||||
| public: | public: | ||||
| @@ -104,11 +104,47 @@ void ShapeInferenceState::UpdateInputShapeFuture(int idx, ShapeFuture &&future) | |||||
| } | } | ||||
| } | } | ||||
| Status ShapeInferenceState::UpdateInputForMerge(const GraphExecutionContext &context) { | |||||
| int merge_index = -1; | |||||
| const auto &guard = node_item.MutexGuard("UpdateInputForMerge"); | |||||
| if (!AttrUtils::GetInt(node_item.op_desc, ATTR_NAME_MERGE_INPUT_INDEX, merge_index)) { | |||||
| GELOGE(FAILED, "[%s] Get attr %s failed", node_item.NodeName().c_str(), ATTR_NAME_MERGE_INPUT_INDEX.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| if (merge_index < 0 || static_cast<size_t>(merge_index) >= input_tensor_desc.size()) { | |||||
| GELOGE(FAILED, "[%s] merge index: %d invalid, should in range[0, %zu)", | |||||
| node_item.NodeName().c_str(), merge_index, input_tensor_desc.size()); | |||||
| return FAILED; | |||||
| } | |||||
| auto dst_tensor_desc = node_item.MutableInputDesc(merge_index); | |||||
| GE_CHECK_NOTNULL(dst_tensor_desc); | |||||
| int64_t tensor_size = -1; | |||||
| auto &tensor_desc = input_tensor_desc[merge_index]; | |||||
| (void)TensorUtils::GetSize(tensor_desc, tensor_size); | |||||
| dst_tensor_desc->SetShape(tensor_desc.MutableShape()); | |||||
| dst_tensor_desc->SetOriginShape(tensor_desc.GetOriginShape()); | |||||
| (void)TensorUtils::SetSize(*dst_tensor_desc, tensor_size); | |||||
| (void)guard; | |||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", | |||||
| node_item.NodeName().c_str(), merge_index, dst_tensor_desc->GetShape().ToString().c_str(), | |||||
| dst_tensor_desc->GetOriginShape().ToString().c_str(), tensor_size); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { | Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &context) { | ||||
| if (!node_item.is_dynamic) { | if (!node_item.is_dynamic) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| std::unique_lock<std::mutex> lk(mu_); | std::unique_lock<std::mutex> lk(mu_); | ||||
| if (node_item.IsMergeOp()) { | |||||
| return UpdateInputForMerge(context); | |||||
| } | |||||
| if (num_pending_shapes_ > 0) { | if (num_pending_shapes_ > 0) { | ||||
| GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); | GELOGD("[%s] Await pending shape or shape future start.", node_item.NodeName().c_str()); | ||||
| int try_count = 0; | int try_count = 0; | ||||
| @@ -169,7 +205,7 @@ Status ShapeInferenceState::AwaitShapesReady(const GraphExecutionContext &contex | |||||
| int64_t tensor_size = -1; | int64_t tensor_size = -1; | ||||
| (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | (void) TensorUtils::GetSize(*src_tensor_desc, tensor_size); | ||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], index = %zu", | |||||
| GELOGD("[%s] Update input shape [%u] with shape: [%s] and ori_shape: [%s], tensor size = %ld", | |||||
| node_item.NodeName().c_str(), | node_item.NodeName().c_str(), | ||||
| idx, | idx, | ||||
| src_tensor_desc->GetShape().ToString().c_str(), | src_tensor_desc->GetShape().ToString().c_str(), | ||||
| @@ -283,11 +319,8 @@ void NodeState::ResetContext(int group) { | |||||
| } | } | ||||
| switch_index_ = -1; | switch_index_ = -1; | ||||
| const auto &guard = node_item_->MutexGuard("ResetContext"); | |||||
| shape_inference_state_.InitShapeState(); | |||||
| subgraph_context_->ResetContext(node_item_->node); | subgraph_context_->ResetContext(node_item_->node); | ||||
| GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | GELOGD("Node[%s] in while loop, current loop: %lu, merge index: %d", GetName().c_str(), loop_count_, merge_index_); | ||||
| (void)guard; | |||||
| } | } | ||||
| void NodeState::ResetSchedule() { | void NodeState::ResetSchedule() { | ||||
| @@ -67,6 +67,8 @@ struct ShapeInferenceState { | |||||
| const NodeItem &node_item; | const NodeItem &node_item; | ||||
| private: | private: | ||||
| Status UpdateInputForMerge(const GraphExecutionContext &context); | |||||
| friend struct NodeState; | friend struct NodeState; | ||||
| std::vector<std::pair<int, ShapeFuture>> shape_futures; | std::vector<std::pair<int, ShapeFuture>> shape_futures; | ||||
| // do not directly update op_desc, in case race condition across pipelines | // do not directly update op_desc, in case race condition across pipelines | ||||
| @@ -15,8 +15,6 @@ | |||||
| */ | */ | ||||
| #include "subgraph_context.h" | #include "subgraph_context.h" | ||||
| #include "common/debug/log.h" | |||||
| #include "hybrid/executor/hybrid_model_executor.h" | #include "hybrid/executor/hybrid_model_executor.h" | ||||
| namespace ge { | namespace ge { | ||||
| @@ -25,6 +23,13 @@ SubgraphContext::SubgraphContext(const GraphItem *graph_item, const GraphExecuti | |||||
| : graph_item_(graph_item), execution_context_(execution_context) { | : graph_item_(graph_item), execution_context_(execution_context) { | ||||
| } | } | ||||
| SubgraphContext::~SubgraphContext() { | |||||
| if (mmRWLockDestroy(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "Destroy rw_lock failed"); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Destroy] Destroy rw_lock failed"); | |||||
| } | |||||
| } | |||||
| Status SubgraphContext::Init() { | Status SubgraphContext::Init() { | ||||
| GE_CHECK_NOTNULL(graph_item_); | GE_CHECK_NOTNULL(graph_item_); | ||||
| GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d", | GELOGD("[%s] Start to init subgraph context. total inputs = %d, total outputs = %d", | ||||
| @@ -33,7 +38,11 @@ Status SubgraphContext::Init() { | |||||
| graph_item_->TotalOutputs()); | graph_item_->TotalOutputs()); | ||||
| all_inputs_.resize(static_cast<unsigned long>(graph_item_->TotalInputs())); | all_inputs_.resize(static_cast<unsigned long>(graph_item_->TotalInputs())); | ||||
| all_outputs_.resize(static_cast<unsigned long>(graph_item_->TotalOutputs())); | all_outputs_.resize(static_cast<unsigned long>(graph_item_->TotalOutputs())); | ||||
| if (mmRWLockInit(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "Init rw_lock failed"); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Init] Init rw_lock failed"); | |||||
| return INTERNAL_ERROR; | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -42,13 +51,48 @@ void SubgraphContext::ResetContext(const NodePtr &node) { | |||||
| } | } | ||||
| NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | NodeStatePtr SubgraphContext::GetOrCreateNodeState(const NodeItem *node_item) { | ||||
| std::lock_guard<std::mutex> lk(mu_); | |||||
| GELOGD("[%s] lock for read", node_item->NodeName().c_str()); | |||||
| if (mmRWLockRDLock(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for read failed", node_item->NodeName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for read failed", node_item->NodeName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| const auto &iter = node_states_.find(node_item); | |||||
| if (iter != node_states_.end()) { | |||||
| auto state = iter->second; | |||||
| GELOGD("[%s] unlock for read", node_item->NodeName().c_str()); | |||||
| if (mmRDLockUnLock(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for read failed", node_item->NodeName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Unlock][Node:%s] Unlock for read failed", node_item->NodeName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| return state; | |||||
| } | |||||
| GELOGD("[%s] unlock for read", node_item->NodeName().c_str()); | |||||
| if (mmRDLockUnLock(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for read failed", node_item->NodeName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Unlock][Node:%s] Unlock for read failed", node_item->NodeName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| GELOGD("[%s] lock for write", node_item->NodeName().c_str()); | |||||
| if (mmRWLockWRLock(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Lock for write failed", node_item->NodeName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Lock][Node:%s] Lock for write failed", node_item->NodeName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| auto &node_state = node_states_[node_item]; | auto &node_state = node_states_[node_item]; | ||||
| if (node_state == nullptr) { | if (node_state == nullptr) { | ||||
| const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | const auto &guard = node_item->MutexGuard("GetOrCreateNodeState"); | ||||
| node_state.reset(new(std::nothrow)NodeState(*node_item, this)); | |||||
| node_state = std::move(std::unique_ptr<NodeState>(new(std::nothrow)NodeState(*node_item, this))); | |||||
| (void)guard; | (void)guard; | ||||
| } | } | ||||
| GELOGD("[%s] unlock for write", node_item->NodeName().c_str()); | |||||
| if (mmWRLockUnLock(&rw_lock_) != EN_OK) { | |||||
| REPORT_CALL_ERROR("E19999", "[Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | |||||
| GELOGE(INTERNAL_ERROR, "[RWLock][Unlock][Node:%s] Unlock for write failed", node_item->NodeName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| return node_state; | return node_state; | ||||
| } | } | ||||
| @@ -144,5 +188,13 @@ void SubgraphContext::OnError(Status error) { | |||||
| void SubgraphContext::NodeDone(const NodePtr &node) { | void SubgraphContext::NodeDone(const NodePtr &node) { | ||||
| node_done_manager_.NodeDone(node); | node_done_manager_.NodeDone(node); | ||||
| } | } | ||||
| void SubgraphContext::Reset() { | |||||
| node_done_manager_.Reset(); | |||||
| if (mmRWLockWRLock(&rw_lock_) == EN_OK) { | |||||
| node_states_.clear(); | |||||
| (void)mmWRLockUnLock(&rw_lock_); | |||||
| } | |||||
| } | |||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -18,7 +18,7 @@ | |||||
| #define GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ | #define GE_HYBRID_EXECUTOR_ITERATION_CONTEXT_H_ | ||||
| #include <vector> | #include <vector> | ||||
| #include "mmpa/mmpa_api.h" | |||||
| #include "hybrid/common/tensor_value.h" | #include "hybrid/common/tensor_value.h" | ||||
| #include "hybrid/executor/hybrid_execution_context.h" | #include "hybrid/executor/hybrid_execution_context.h" | ||||
| #include "hybrid/executor/node_state.h" | #include "hybrid/executor/node_state.h" | ||||
| @@ -31,10 +31,11 @@ namespace hybrid { | |||||
| class SubgraphContext { | class SubgraphContext { | ||||
| public: | public: | ||||
| explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); | explicit SubgraphContext(const GraphItem *graph_item, const GraphExecutionContext *execution_context); | ||||
| ~SubgraphContext() = default; | |||||
| ~SubgraphContext(); | |||||
| Status Init(); | Status Init(); | ||||
| void ResetContext(const NodePtr &node); | void ResetContext(const NodePtr &node); | ||||
| void Reset(); | |||||
| NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | NodeStatePtr GetOrCreateNodeState(const NodeItem *node_item); | ||||
| void OnError(Status error); | void OnError(Status error); | ||||
| @@ -52,7 +53,7 @@ class SubgraphContext { | |||||
| friend class TaskContext; | friend class TaskContext; | ||||
| const GraphItem *graph_item_; | const GraphItem *graph_item_; | ||||
| const GraphExecutionContext *execution_context_; | const GraphExecutionContext *execution_context_; | ||||
| std::mutex mu_; | |||||
| mmRWLock_t rw_lock_; | |||||
| std::vector<TensorValue> all_inputs_; | std::vector<TensorValue> all_inputs_; | ||||
| std::vector<TensorValue> all_outputs_; | std::vector<TensorValue> all_outputs_; | ||||
| NodeDoneManager node_done_manager_; | NodeDoneManager node_done_manager_; | ||||
| @@ -704,7 +704,21 @@ Status SubgraphExecutor::PartialExecuteAsync(int task_group) { | |||||
| Status SubgraphExecutor::InitForPartialExecution(const vector<TensorValue> &inputs, | Status SubgraphExecutor::InitForPartialExecution(const vector<TensorValue> &inputs, | ||||
| const vector<ConstGeTensorDescPtr> &input_desc) { | const vector<ConstGeTensorDescPtr> &input_desc) { | ||||
| return Init(inputs, input_desc); | |||||
| if (subgraph_context_ == nullptr) { | |||||
| return Init(inputs, input_desc); | |||||
| } | |||||
| subgraph_context_->Reset(); | |||||
| if (graph_item_->IsDynamic()) { | |||||
| GE_CHK_STATUS_RET(InitInputsForUnknownShape(inputs, input_desc), | |||||
| "[%s] Failed to set inputs.", | |||||
| graph_item_->GetName().c_str()); | |||||
| } else { | |||||
| GE_CHK_STATUS_RET(InitInputsForKnownShape(inputs), | |||||
| "[Invoke][InitInputsForKnownShape][%s] Failed to init subgraph executor for known shape subgraph", | |||||
| graph_item_->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | } | ||||
| } // namespace hybrid | } // namespace hybrid | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -300,7 +300,7 @@ Status NodeDoneCallback::OnNodeDone() { | |||||
| GE_CHK_STATUS_RET(SaveDumpOpInfo(), "[Save][DumpOpInfo] Failed to dump op info."); | GE_CHK_STATUS_RET(SaveDumpOpInfo(), "[Save][DumpOpInfo] Failed to dump op info."); | ||||
| } | } | ||||
| if (ProfilingManager::Instance().ProfilingModelExecuteOn()) { | |||||
| if (ProfilingManager::Instance().ProfilingModelLoadOn()) { | |||||
| GE_CHK_STATUS_RET(ProfilingReport(), "[Report][Profiling] of node[%s] failed.", node_item.NodeName().c_str()); | GE_CHK_STATUS_RET(ProfilingReport(), "[Report][Profiling] of node[%s] failed.", node_item.NodeName().c_str()); | ||||
| } | } | ||||
| @@ -26,8 +26,7 @@ | |||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/manager/host_mem_manager.h" | #include "graph/manager/host_mem_manager.h" | ||||
| #include "graph/manager/trans_var_data_utils.h" | #include "graph/manager/trans_var_data_utils.h" | ||||
| #include "graph/manager/graph_mem_allocator.h" | |||||
| #include "graph/manager/host_mem_allocator.h" | |||||
| #include "graph/manager/graph_mem_manager.h" | |||||
| #include "graph/utils/graph_utils.h" | #include "graph/utils/graph_utils.h" | ||||
| #include "hybrid/common/npu_memory_allocator.h" | #include "hybrid/common/npu_memory_allocator.h" | ||||
| #include "hybrid/node_executor/node_executor.h" | #include "hybrid/node_executor/node_executor.h" | ||||
| @@ -260,6 +259,10 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| if (node->GetType() == MEMCPYASYNC) { // Convert MemcpyAsync to Identity. | |||||
| node->GetOpDesc()->SetType(IDENTITY); | |||||
| } | |||||
| std::unique_ptr<NodeItem> new_node; | std::unique_ptr<NodeItem> new_node; | ||||
| GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | GE_CHK_STATUS_RET(NodeItem::Create(node, new_node), "[Invoke][Create] failed, model_name_:[%s]", GetGraphName()); | ||||
| GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | GE_CHK_STATUS_RET_NOLOG(NodeExecutorManager::GetInstance().GetExecutor(*node, &new_node->node_executor)); | ||||
| @@ -1002,14 +1005,18 @@ Status HybridModelBuilder::InitConstantOps() { | |||||
| // Tensors return by api GetWeights share data with proto, whose addr is not confirmed to be aligned | // Tensors return by api GetWeights share data with proto, whose addr is not confirmed to be aligned | ||||
| GeTensor aligned_tensor = ge_tensor->Clone(); | GeTensor aligned_tensor = ge_tensor->Clone(); | ||||
| GELOGD("Init tensor with host constant %s size = %zu", var_name.c_str(), aligned_tensor.MutableData().GetSize()); | GELOGD("Init tensor with host constant %s size = %zu", var_name.c_str(), aligned_tensor.MutableData().GetSize()); | ||||
| if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(aligned_tensor.GetAlignedPtr(), | |||||
| aligned_tensor.GetData().size()) == nullptr) { | |||||
| GELOGE(MEMALLOC_FAILED, "[Malloc][HostMemory] for an existed GeTensor failed, model_name_:%s.", | |||||
| GetGraphName()); | |||||
| return MEMALLOC_FAILED; | |||||
| if (aligned_tensor.GetData().size() > 0) { | |||||
| if (MemManager::Instance().HostMemInstance(RT_MEMORY_HBM).Malloc(aligned_tensor.GetAlignedPtr(), | |||||
| aligned_tensor.GetData().size()) == nullptr) { | |||||
| GELOGE(MEMALLOC_FAILED, "[Malloc][HostMemory] for an existed GeTensor failed, model_name_:%s.", | |||||
| GetGraphName()); | |||||
| return MEMALLOC_FAILED; | |||||
| } | |||||
| var_tensor.reset(new(std::nothrow)TensorValue(aligned_tensor.MutableData().data(), | |||||
| aligned_tensor.GetData().size())); | |||||
| } else { | |||||
| var_tensor.reset(new(std::nothrow)TensorValue(nullptr, 0)); | |||||
| } | } | ||||
| var_tensor.reset(new(std::nothrow)TensorValue(aligned_tensor.MutableData().data(), | |||||
| aligned_tensor.GetData().size())); | |||||
| } else { | } else { | ||||
| GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | GE_CHK_STATUS_RET_NOLOG(VarNodeToTensor(var_node, var_tensor)); | ||||
| GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | GELOGD("Init const op tensor. name = %s, size = %ld", var_name.c_str(), var_tensor->GetSize()); | ||||