| @@ -168,6 +168,11 @@ set(EXECUTOR_SRC_LIST | |||||
| "graph/manager/util/debug.cc" | "graph/manager/util/debug.cc" | ||||
| #"graph/manager/util/hcom_util.cc" # Just for runner. | #"graph/manager/util/hcom_util.cc" # Just for runner. | ||||
| "graph/passes/pass_utils.cc" | "graph/passes/pass_utils.cc" | ||||
| "graph/passes/mds_pass.cc" | |||||
| "graph/passes/mds_kernels/mds_utils.cc" | |||||
| "graph/passes/mds_kernels/variable_mds_kernel.cc" | |||||
| "graph/passes/mds_kernels/conv2d_mds_kernel.cc" | |||||
| "graph/passes/mds_kernels/base_mds_kernel.cc" | |||||
| "host_kernels/add_kernel.cc" | "host_kernels/add_kernel.cc" | ||||
| "host_kernels/broadcast_args_kernel.cc" | "host_kernels/broadcast_args_kernel.cc" | ||||
| "host_kernels/broadcast_gradient_args_kernel.cc" | "host_kernels/broadcast_gradient_args_kernel.cc" | ||||
| @@ -76,5 +76,13 @@ Status GeLocalOpsKernelInfoStore::DestroySession(const map<string, string> &sess | |||||
| // Do nothing | // Do nothing | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GeLocalOpsKernelInfoStore::SetCutSupportedInfo(const NodePtr &node) { | |||||
| // TODO: | |||||
| // 1. Whether the variable type is identified as a trainable variable | |||||
| // 2, whether to turn on smdp1 and 3 | |||||
| // To meet the above two points, set the current variable | |||||
| // node to be tangent in the variable segmentation information | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge_local | } // namespace ge_local | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -86,6 +86,8 @@ class GE_FUNC_VISIBILITY GeLocalOpsKernelInfoStore : public OpsKernelInfoStore { | |||||
| */ | */ | ||||
| Status DestroySession(const std::map<std::string, std::string> &session_options) override; | Status DestroySession(const std::map<std::string, std::string> &session_options) override; | ||||
| Status SetCutSupportedInfo(const ge::NodePtr &node) override; | |||||
| // Copy prohibited | // Copy prohibited | ||||
| GeLocalOpsKernelInfoStore(const GeLocalOpsKernelInfoStore &ops_kernel_store) = delete; | GeLocalOpsKernelInfoStore(const GeLocalOpsKernelInfoStore &ops_kernel_store) = delete; | ||||
| @@ -22,8 +22,18 @@ | |||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "graph/load/model_manager/davinci_model.h" | #include "graph/load/model_manager/davinci_model.h" | ||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "common/thread_pool.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| //deploy info | |||||
| const char *const kAttrDeviceType = "_device_type"; | |||||
| const char *const kAttrDeviceId = "_device_id"; | |||||
| const char *const kAttrGraphName = "_graph_name"; | |||||
| const char *const kAttrGraphInputs = "_graph_inputs"; | |||||
| const char *const kAttrNeedReturnResult = "_need_return_result"; | |||||
| } | |||||
| using Uint32Pair = pair<uint32_t, uint32_t>; | using Uint32Pair = pair<uint32_t, uint32_t>; | ||||
| const uint32_t kInvalidModelId = UINT32_MAX; | const uint32_t kInvalidModelId = UINT32_MAX; | ||||
| GraphExecutor::GraphExecutor() | GraphExecutor::GraphExecutor() | ||||
| @@ -386,7 +396,14 @@ Status GraphExecutor::ExecuteGraphAsync(GraphId graph_id, const GeRootModelPtr & | |||||
| } | } | ||||
| last_graph_id_ = graph_id; | last_graph_id_ = graph_id; | ||||
| GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | GE_CHECK_NOTNULL_EXEC(ge_root_model, return FAILED); | ||||
| Status ret = AsyncExecuteModel(ge_root_model, input_tensor, callback); | |||||
| vector<GeAttrValue::NAMED_ATTRS> deployInfo; | |||||
| ModelIdInfo model_id_info; | |||||
| Status ret; | |||||
| if (ge::AttrUtils::GetListNamedAttrs(ge_root_model->GetRootGraph(), ATTR_NAME_DEPLOY_INFO, deployInfo)) { | |||||
| ret = AsyncMultiExecuteModel(ge_root_model, input_tensor, callback); | |||||
| } else { | |||||
| ret = AsyncExecuteModel(ge_root_model, GetExecuteModelId(ge_root_model), input_tensor, callback); | |||||
| } | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[AsyncExecute][Model] Error! graph id:%u", graph_id); | GELOGE(GE_GRAPH_SYNC_MODEL_FAILED, "[AsyncExecute][Model] Error! graph id:%u", graph_id); | ||||
| return GE_GRAPH_SYNC_MODEL_FAILED; | return GE_GRAPH_SYNC_MODEL_FAILED; | ||||
| @@ -522,10 +539,67 @@ Status GraphExecutor::SetCallback(uint32_t model_id, const GeRootModelPtr &ge_ro | |||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphExecutor::AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &inputs, | |||||
| const RunAsyncCallback &callback) { | |||||
| // get deploy number of model instance | |||||
| auto root_graph = ge_root_model->GetRootGraph(); | |||||
| vector<GeAttrValue::NAMED_ATTRS> deploy_info; | |||||
| if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { | |||||
| GELOGE(FAILED, "[AsyncMultiExecuteModel] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), | |||||
| ATTR_NAME_DEPLOY_INFO.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto thread_instances_size = deploy_info.size(); | |||||
| auto model_ids = ge_root_model->GetAllModelId(); | |||||
| if (model_ids.size() != thread_instances_size) { | |||||
| GELOGE(FAILED, | |||||
| "[AsyncMultiExecuteModel] something wrong, attr deploy numbers %zu should be equal to loaded models %zu", | |||||
| thread_instances_size, model_ids.size()); | |||||
| return FAILED; | |||||
| } | |||||
| ThreadPool executor(thread_instances_size); | |||||
| std::vector<std::future<Status>> vector_future; | |||||
| for (size_t i = 0; i < thread_instances_size; ++i) { | |||||
| auto thread_instance = deploy_info[i]; | |||||
| std::vector<GeTensorPtr> graph_inputs; | |||||
| if (ge::AttrUtils::MutableListTensor(thread_instance, kAttrGraphInputs, graph_inputs)) { | |||||
| std::vector<ge::Tensor> graph_input_updated(inputs.begin(), inputs.end()); | |||||
| for (const auto &ge_tensor_ptr : graph_inputs) { | |||||
| graph_input_updated.push_back(TensorAdapter::AsTensor(*ge_tensor_ptr)); | |||||
| } | |||||
| GraphExecutor graph_executor; | |||||
| ExecuteModelFunc execute_model_func(&GraphExecutor::AsyncExecuteModel); | |||||
| std::future<Status> f; | |||||
| bool need_return_result = false; | |||||
| if ((ge::AttrUtils::GetBool(thread_instance, kAttrNeedReturnResult, need_return_result) && need_return_result)) { | |||||
| f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, | |||||
| callback); | |||||
| } else { | |||||
| RunAsyncCallback callback_stub; | |||||
| f = executor.commit(execute_model_func, &graph_executor, ge_root_model, model_ids[i], graph_input_updated, | |||||
| callback_stub); | |||||
| } | |||||
| if (!f.valid()) { | |||||
| GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); | |||||
| return FAILED; | |||||
| } | |||||
| vector_future.emplace_back(std::move(f)); | |||||
| } | |||||
| } | |||||
| for (size_t i = 0; i < vector_future.size(); ++i) { | |||||
| Status ret_status = vector_future[i].get(); | |||||
| if (ret_status != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", " Execute multi model %zu failed", i); | |||||
| GELOGE(ret_status, "[AsyncMultiExecuteModel] Execute multi model failed", i); | |||||
| return ret_status; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &inputs, | |||||
| const RunAsyncCallback &callback) { | |||||
| uint32_t model_id = GetExecuteModelId(ge_root_model); | |||||
| Status GraphExecutor::AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, | |||||
| const std::vector<ge::Tensor> &inputs, const RunAsyncCallback &callback) { | |||||
| if (model_id == kInvalidModelId) { | if (model_id == kInvalidModelId) { | ||||
| GELOGE(INTERNAL_ERROR, "No valid model id."); | GELOGE(INTERNAL_ERROR, "No valid model id."); | ||||
| return INTERNAL_ERROR; | return INTERNAL_ERROR; | ||||
| @@ -136,8 +136,10 @@ class GraphExecutor { | |||||
| Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor, | Status SyncExecuteModel(uint32_t model_id, const std::vector<GeTensor> &input_tensor, | ||||
| std::vector<GeTensor> &output_tensor); | std::vector<GeTensor> &output_tensor); | ||||
| Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &input_tensor, | |||||
| Status AsyncExecuteModel(const GeRootModelPtr &ge_root_model, uint32_t model_id, const std::vector<ge::Tensor> &input_tensor, | |||||
| const RunAsyncCallback &callback); | const RunAsyncCallback &callback); | ||||
| Status AsyncMultiExecuteModel(const GeRootModelPtr &ge_root_model, const std::vector<ge::Tensor> &input_tensor, | |||||
| const RunAsyncCallback &callback); | |||||
| void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec, | void InitModelIdInfo(std::vector<uint32_t> &out_model_id_info, std::vector<SubGraphInfoPtr> &sub_graph_vec, | ||||
| uint32_t output_size); | uint32_t output_size); | ||||
| @@ -170,6 +172,11 @@ class GraphExecutor { | |||||
| std::vector<void *> buffer_addr_; | std::vector<void *> buffer_addr_; | ||||
| std::vector<uint64_t> buffer_size_; | std::vector<uint64_t> buffer_size_; | ||||
| }; | }; | ||||
| using ExecuteModelFunc = std::function<Status(GraphExecutor *, | |||||
| const GeRootModelPtr &ge_root_model, | |||||
| uint32_t model_id, | |||||
| const std::vector<ge::Tensor> &inputs, | |||||
| const RunAsyncCallback &callback)>; | |||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_ | #endif // GE_GRAPH_EXECUTE_GRAPH_EXECUTE_H_ | ||||
| @@ -325,34 +325,45 @@ Status ModelExecutor::RunGraphWithStream(const GraphNodePtr &graph_node, GraphId | |||||
| Status ModelExecutor::ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | Status ModelExecutor::ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | ||||
| ge_root_model->SetIsSpecificStream(graph_node->IsSpecificStream()); | ge_root_model->SetIsSpecificStream(graph_node->IsSpecificStream()); | ||||
| return ModelLoad(ge_root_model, graph_node, graph_run_listener_); | |||||
| return ModelLoad(ge_root_model, graph_node, false); | |||||
| } | } | ||||
| Status ModelExecutor::ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | Status ModelExecutor::ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node) { | ||||
| auto listener = MakeShared<RunAsyncListener>(); | |||||
| GE_CHECK_NOTNULL(listener); | |||||
| return ModelLoad(ge_root_model, graph_node, listener); | |||||
| return ModelLoad(ge_root_model, graph_node, true); | |||||
| } | } | ||||
| Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, | Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, | ||||
| const std::shared_ptr<ModelListener> &listener) { | |||||
| bool is_async) { | |||||
| ge_root_model->SetTrainFlag(train_graph_flag_); | ge_root_model->SetTrainFlag(train_graph_flag_); | ||||
| bool is_unknown_shape = false; | bool is_unknown_shape = false; | ||||
| GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape)); | GE_CHK_STATUS_RET(ge_root_model->CheckIsUnknownShape(is_unknown_shape)); | ||||
| auto root_graph = ge_root_model->GetRootGraph(); | |||||
| if (!is_unknown_shape) { | if (!is_unknown_shape) { | ||||
| if (getenv(kEnvGeuseStaticMemory) != nullptr) { | if (getenv(kEnvGeuseStaticMemory) != nullptr) { | ||||
| GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted."); | GELOGI("[LoadGraph] GE_USE_STATIC_MEMORY is seted."); | ||||
| } else { | } else { | ||||
| auto root_graph = ge_root_model->GetRootGraph(); | |||||
| GE_CHECK_NOTNULL(root_graph); | GE_CHECK_NOTNULL(root_graph); | ||||
| auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | auto name_to_model = ge_root_model->GetSubgraphInstanceNameToModel(); | ||||
| GeModelPtr ge_model = name_to_model[root_graph->GetName()]; | GeModelPtr ge_model = name_to_model[root_graph->GetName()]; | ||||
| GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); | GE_CHK_STATUS_RET(CheckAndReleaseMemory(ge_model, graph_node)); | ||||
| } | } | ||||
| } | } | ||||
| std::shared_ptr<ModelListener> listener = | |||||
| is_async ? std::dynamic_pointer_cast<ModelListener>(MakeShared<RunAsyncListener>()) : std::dynamic_pointer_cast< | |||||
| ModelListener>(graph_run_listener_); | |||||
| GE_TIMESTAMP_START(LoadModelOnline); | GE_TIMESTAMP_START(LoadModelOnline); | ||||
| uint32_t model_id = INVALID_MODEL_ID; | uint32_t model_id = INVALID_MODEL_ID; | ||||
| Status ret = GraphLoader::LoadModelOnline(model_id, ge_root_model, listener); | |||||
| vector<GeAttrValue::NAMED_ATTRS> deployInfo; | |||||
| Status ret; | |||||
| if (ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deployInfo)) { | |||||
| ret = GraphLoader::LoadMultiModelOnline(ge_root_model, is_async); | |||||
| } else { | |||||
| ret = GraphLoader::LoadModelOnline(model_id, | |||||
| ge_root_model, | |||||
| listener, | |||||
| GetContext().DeviceId(), | |||||
| kInvalidDieId); | |||||
| } | |||||
| GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline"); | GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline"); | ||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[Load][ModelOnline] Failed, model_id:%u", model_id); | GELOGE(ret, "[Load][ModelOnline] Failed, model_id:%u", model_id); | ||||
| @@ -360,7 +371,6 @@ Status ModelExecutor::ModelLoad(const GeRootModelPtr &ge_root_model, const Graph | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| graph_node->SetLoadFlag(true); | graph_node->SetLoadFlag(true); | ||||
| ge_root_model->SetModelId(model_id); | |||||
| graph_node->SetGeRootModel(ge_root_model); | graph_node->SetGeRootModel(ge_root_model); | ||||
| AddGraphNode(graph_node->GetGraphId(), graph_node); | AddGraphNode(graph_node->GetGraphId(), graph_node); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -98,8 +98,7 @@ class ModelExecutor : public Executor { | |||||
| Status ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); | Status ModelLoadSync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); | ||||
| Status ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); | Status ModelLoadAsync(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node); | ||||
| Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, | |||||
| const std::shared_ptr<ModelListener> &listener); | |||||
| Status ModelLoad(const GeRootModelPtr &ge_root_model, const GraphNodePtr &graph_node, bool is_async); | |||||
| Status UnloadModel(const GeRootModelPtr &ge_root_model, uint32_t graph_id); | Status UnloadModel(const GeRootModelPtr &ge_root_model, uint32_t graph_id); | ||||
| @@ -18,14 +18,24 @@ | |||||
| #include <string> | #include <string> | ||||
| #include <vector> | #include <vector> | ||||
| #include <thread> | |||||
| #include "framework/common/helper/model_helper.h" | #include "framework/common/helper/model_helper.h" | ||||
| #include "common/model_parser/model_parser.h" | #include "common/model_parser/model_parser.h" | ||||
| #include "graph/ge_context.h" | #include "graph/ge_context.h" | ||||
| #include "graph/load/model_manager/model_manager.h" | #include "graph/load/model_manager/model_manager.h" | ||||
| #include "graph/manager/graph_var_manager.h" | #include "graph/manager/graph_var_manager.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "common/thread_pool.h" | |||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| //deploy info | |||||
| const char *const kAttrDeviceType = "_device_type"; | |||||
| const char *const kAttrDeviceId = "_device_id"; | |||||
| const char *const kAttrGraphName = "_graph_name"; | |||||
| const char *const kAttrGraphInputs = "_graph_inputs"; | |||||
| } | |||||
| Status GraphLoader::UnloadModel(uint32_t model_id) { | Status GraphLoader::UnloadModel(uint32_t model_id) { | ||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| @@ -45,43 +55,81 @@ Status GraphLoader::UnloadModel(uint32_t model_id) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr, | |||||
| const std::shared_ptr<ModelListener> &listener) { | |||||
| GELOGI("Load model online begin."); | |||||
| rtError_t rt_ret = rtSetDevice(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| Status GraphLoader::SetDevice(uint32_t device_id, int64_t die_id) { | |||||
| if (device_id != kInvalidDeviceId && die_id != kInvalidDieId) { | |||||
| rtError_t rt_ret = rtSetDevice(device_id, kMultiMode); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][rtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| rt_ret = rtSetDieId(die_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDieId failed, device_id:%u, ret:0x%X", die_id, rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] rtSetDieId, device_id:%u, ret:0x%X", die_id, rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| } else if (device_id != kInvalidDeviceId && die_id == kInvalidDieId) { | |||||
| rtError_t rt_ret = rtSetDevice(device_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| } else { | |||||
| REPORT_CALL_ERROR("E19999", "Call SetDevice failed, device_id:%u, die_id:%ld", device_id, die_id); | |||||
| GELOGE(RT_FAILED, "[Call][SetDevice] failed, device_id:%u, die_id:%ld", device_id, die_id); | |||||
| return RT_FAILED; | return RT_FAILED; | ||||
| } | } | ||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphLoader::ResetDevice(uint32_t device_id, int64_t die_id) { | |||||
| if (die_id != kInvalidDieId) { | |||||
| rtError_t rt_ret = rtDieReset(die_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", die_id, rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", die_id, rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| } else { | |||||
| rtError_t rt_ret = rtDeviceReset(device_id); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtSetDevice failed, device_id:%u, ret:0x%X", device_id, rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtSetDevice] failed, device_id:%u, ret:0x%X", device_id, rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphLoader::LoadModelOnline(uint32_t &model_id, | |||||
| const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr, | |||||
| const std::shared_ptr<ModelListener> &listener, | |||||
| uint32_t device_id, | |||||
| int64_t die_id) { | |||||
| GELOGI("Load model online begin."); | |||||
| if (ge_root_model_ptr == nullptr) { | if (ge_root_model_ptr == nullptr) { | ||||
| REPORT_INNER_ERROR("E19999", "Check param ge_root_model_ptr nullptr, check invalid"); | REPORT_INNER_ERROR("E19999", "Check param ge_root_model_ptr nullptr, check invalid"); | ||||
| GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph][Check][Param] GE load graph model_ptr is nullptr."); | GELOGE(GE_GRAPH_PARAM_NULLPTR, "[LoadGraph][Check][Param] GE load graph model_ptr is nullptr."); | ||||
| return GE_GRAPH_PARAM_NULLPTR; | return GE_GRAPH_PARAM_NULLPTR; | ||||
| } | } | ||||
| if (SetDevice(device_id, die_id) != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", "Call SetDevice failed, device_id:%u", device_id); | |||||
| GELOGE(RT_FAILED, "[Call][SetDevice] failed, device_id:%u", device_id); | |||||
| return RT_FAILED; | |||||
| } | |||||
| GE_MAKE_GUARD(reset_device, [&] { GE_CHK_RT(ResetDevice(device_id, die_id)); }); | |||||
| auto model_manager = ModelManager::GetInstance(); | auto model_manager = ModelManager::GetInstance(); | ||||
| GE_CHECK_NOTNULL(model_manager); | GE_CHECK_NOTNULL(model_manager); | ||||
| Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener); | |||||
| Status ret = model_manager->LoadModelOnline(model_id, ge_root_model_ptr, listener,device_id, die_id); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| GELOGE(ret, "[Load][Model] Online failed. ret = %u, model_id:%u", ret, model_id); | GELOGE(ret, "[Load][Model] Online failed. ret = %u, model_id:%u", ret, model_id); | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| } | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| ge_root_model_ptr->SetModelId(model_id); | |||||
| if (ge_root_model_ptr->IsSpecificStream()) { | if (ge_root_model_ptr->IsSpecificStream()) { | ||||
| GELOGI("No need to start a new thread to run model in specific scene."); | GELOGI("No need to start a new thread to run model in specific scene."); | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ret = model_manager->Start(model_id); | ret = model_manager->Start(model_id); | ||||
| @@ -89,25 +137,77 @@ Status GraphLoader::LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge | |||||
| if (model_manager->Unload(model_id) != SUCCESS) { | if (model_manager->Unload(model_id) != SUCCESS) { | ||||
| GELOGE(ret, "[Unload][Model] failed while trying to unload after a failed start, model_id:%u.", model_id); | GELOGE(ret, "[Unload][Model] failed while trying to unload after a failed start, model_id:%u.", model_id); | ||||
| } | } | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| } | |||||
| GELOGE(ret, "[Start][Model] failed, model_id:%u.", model_id); | GELOGE(ret, "[Start][Model] failed, model_id:%u.", model_id); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| rt_ret = rtDeviceReset(GetContext().DeviceId()); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| REPORT_CALL_ERROR("E19999", "Call rtDeviceReset failed, device_id:%u, ret:0x%X", | |||||
| GetContext().DeviceId(), rt_ret); | |||||
| GELOGE(RT_FAILED, "[Call][RtDeviceReset] failed, device_id:%u, ret:0x%X", GetContext().DeviceId(), rt_ret); | |||||
| return RT_FAILED; | |||||
| } | |||||
| GELOGI("Load model online success, model_id:%u.", model_id); | GELOGI("Load model online success, model_id:%u.", model_id); | ||||
| return SUCCESS; | |||||
| } | |||||
| Status GraphLoader::LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel> &ge_root_model, bool is_async) { | |||||
| // get deploy number of model instance | |||||
| auto root_graph = ge_root_model->GetRootGraph(); | |||||
| vector<GeAttrValue::NAMED_ATTRS> deploy_info; | |||||
| if (!ge::AttrUtils::GetListNamedAttrs(root_graph, ATTR_NAME_DEPLOY_INFO, deploy_info) || deploy_info.empty()) { | |||||
| GELOGE(FAILED, "[LoadMultiModelOnline] Load multi model failed, graph %s has invalid deploy attr %s", | |||||
| root_graph->GetName().c_str(), ATTR_NAME_DEPLOY_INFO.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| auto thread_instances_size = deploy_info.size(); | |||||
| auto device_id_fission_from = GetContext().DeviceId(); | |||||
| GELOGI("Graph %s need to load model %zu times, and fission from device %u.", root_graph->GetName().c_str(), | |||||
| thread_instances_size, device_id_fission_from); | |||||
| ThreadPool executor(thread_instances_size); | |||||
| std::vector<std::future<Status>> vector_future; | |||||
| GE_TIMESTAMP_START(LoadModelOnline); | |||||
| for (size_t i = 0; i < thread_instances_size; ++i) { | |||||
| auto thread_instance = deploy_info[i]; | |||||
| std::string device_type; | |||||
| ModelIdInfo model_id_info; | |||||
| std::shared_ptr<ModelListener> listener; | |||||
| if (is_async) { | |||||
| listener = MakeShared<RunAsyncListener>(); | |||||
| GE_CHECK_NOTNULL(listener); | |||||
| } else { | |||||
| // TODO: GraphModelListener for sync | |||||
| } | |||||
| int64_t device_id_fissioned = kInvalidDieId; | |||||
| if (!ge::AttrUtils::GetInt(thread_instance, kAttrDeviceId, device_id_fissioned) || | |||||
| device_id_fissioned == kInvalidDieId) { | |||||
| REPORT_CALL_ERROR("E19999", "graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), | |||||
| ATTR_NAME_DEPLOY_INFO.c_str()); | |||||
| GELOGE(GRAPH_FAILED, "[LoadMultiModelOnline] graph %s has invalid deploy attr %s", root_graph->GetName().c_str(), | |||||
| ATTR_NAME_DEPLOY_INFO.c_str()); | |||||
| return GRAPH_FAILED; | |||||
| }; | |||||
| if (ge::AttrUtils::GetStr(thread_instance, kAttrDeviceType, device_type) && device_type == kMultiMode) { | |||||
| std::future<Status> f = executor.commit(GraphLoader::LoadModelOnline, model_id_info.model_id, ge_root_model, | |||||
| listener, device_id_fission_from, device_id_fissioned); | |||||
| if (!f.valid()) { | |||||
| GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); | |||||
| return FAILED; | |||||
| } | |||||
| vector_future.emplace_back(std::move(f)); | |||||
| } else { | |||||
| std::future<Status> f = executor.commit(GraphLoader::LoadModelOnline, model_id_info.model_id, ge_root_model, | |||||
| listener, device_id_fissioned, kInvalidDieId); | |||||
| if (!f.valid()) { | |||||
| GELOGE(FAILED, "[Call][Commit] failed, Future is invalid"); | |||||
| return FAILED; | |||||
| } | |||||
| vector_future.emplace_back(std::move(f)); | |||||
| } | |||||
| } | |||||
| GE_TIMESTAMP_EVENT_END(LoadModelOnline, "GraphLoader::LoadModelOnline"); | |||||
| for (size_t i = 0; i < vector_future.size(); ++i) { | |||||
| Status ret_status = vector_future[i].get(); | |||||
| if (ret_status != SUCCESS) { | |||||
| REPORT_CALL_ERROR("E19999", " Load multi model %zu failed", i); | |||||
| GELOGE(ret_status, "[LoadMultiModelOnline] Load multi model failed", i); | |||||
| return ret_status; | |||||
| } | |||||
| } | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -30,6 +30,13 @@ | |||||
| #include "runtime/mem.h" | #include "runtime/mem.h" | ||||
| namespace ge { | namespace ge { | ||||
| namespace { | |||||
| const int64_t kInvalidDieId = -1; | |||||
| const uint32_t kInvalidDeviceId = UINT32_MAX; | |||||
| const char* kMultiMode ="MultiMode"; | |||||
| const char* kSingleMode ="SingleMode"; | |||||
| } | |||||
| class GraphLoader { | class GraphLoader { | ||||
| public: | public: | ||||
| GraphLoader() = default; | GraphLoader() = default; | ||||
| @@ -64,9 +71,12 @@ class GraphLoader { | |||||
| static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id); | static Status DestroyAicpuKernel(uint64_t session_id, uint32_t model_id, uint32_t sub_model_id); | ||||
| static Status DestroyAicpuSessionForInfer(uint32_t model_id); | static Status DestroyAicpuSessionForInfer(uint32_t model_id); | ||||
| static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | static Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | ||||
| const std::shared_ptr<ModelListener> &listener); | |||||
| const std::shared_ptr<ModelListener> &listener, uint32_t device_id, | |||||
| int64_t die_id = kInvalidDieId); | |||||
| static Status SetDevice(uint32_t device_id, int64_t die_id); | |||||
| static Status ResetDevice(uint32_t device_id, int64_t die_id); | |||||
| static Status LoadMultiModelOnline(const std::shared_ptr<ge::GeRootModel> &ge_root_model_ptr, bool is_async); | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| #endif // GE_GRAPH_LOAD_GRAPH_LOADER_H_ | #endif // GE_GRAPH_LOAD_GRAPH_LOADER_H_ | ||||
| @@ -444,16 +444,16 @@ Status DavinciModel::InitFeatureMapAndP2PMem(void *dev_ptr, size_t mem_size) { | |||||
| Status DavinciModel::InitVariableMem() { | Status DavinciModel::InitVariableMem() { | ||||
| // malloc variable memory base | // malloc variable memory base | ||||
| var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | |||||
| var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM, GetDeviceId()); | |||||
| if (TotalVarMemSize() && (var_mem_base_ == nullptr)) { | if (TotalVarMemSize() && (var_mem_base_ == nullptr)) { | ||||
| Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize()); | |||||
| Status ret = VarManager::Instance(session_id_)->MallocVarMemory(TotalVarMemSize(), GetDeviceId()); | |||||
| if (ret != SUCCESS) { | if (ret != SUCCESS) { | ||||
| REPORT_CALL_ERROR("E19999", "MallocVarMemory fail, var_size:%zu, model_id:%u, check invalid", | REPORT_CALL_ERROR("E19999", "MallocVarMemory fail, var_size:%zu, model_id:%u, check invalid", | ||||
| TotalVarMemSize(), model_id_); | TotalVarMemSize(), model_id_); | ||||
| GELOGE(ret, "[Malloc][VarMemory] failed, var_size:%zu, model_id:%u", TotalVarMemSize(), model_id_); | GELOGE(ret, "[Malloc][VarMemory] failed, var_size:%zu, model_id:%u", TotalVarMemSize(), model_id_); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM); | |||||
| var_mem_base_ = VarManager::Instance(session_id_)->GetVarMemoryBase(RT_MEMORY_HBM, GetDeviceId()); | |||||
| GEEVENT("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | GEEVENT("[IMAS]InitVariableMem graph_%u MallocMemory type[V] memaddr[%p] mem_size[%zu]", runtime_param_.graph_id, | ||||
| var_mem_base_, TotalVarMemSize()); | var_mem_base_, TotalVarMemSize()); | ||||
| } | } | ||||
| @@ -2819,18 +2819,16 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
| bool seq_end_flag = false; | bool seq_end_flag = false; | ||||
| uint32_t model_id = model->Id(); | uint32_t model_id = model->Id(); | ||||
| uint32_t device_id = model->GetDeviceId(); | uint32_t device_id = model->GetDeviceId(); | ||||
| int64_t die_id = model->GetDieId(); | |||||
| ErrorManager::GetInstance().SetErrorContext(model->GetErrorContext()); | ErrorManager::GetInstance().SetErrorContext(model->GetErrorContext()); | ||||
| GELOGI("Model Run thread start, model_id:%u.", model_id); | GELOGI("Model Run thread start, model_id:%u.", model_id); | ||||
| rtError_t rt_ret = rtSetDevice(static_cast<int32_t>(device_id)); | |||||
| if (rt_ret != RT_ERROR_NONE) { | |||||
| GELOGE(FAILED, "[Run][Rtsetdevice] failed, model_id:%u, device_id:%u.", model_id, device_id); | |||||
| if (GraphLoader::SetDevice(device_id, die_id) != SUCCESS) { | |||||
| GELOGE(FAILED, "[Run][Setdevice] failed, model_id:%u, device_id:%u die_id%ld.", model_id, device_id, die_id); | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| // DeviceReset before thread run finished! | // DeviceReset before thread run finished! | ||||
| GE_MAKE_GUARD(not_used_var, [&] { GE_CHK_RT(rtDeviceReset(device_id)); }); | |||||
| GE_MAKE_GUARD(reset_device, [&] { GE_CHK_RT(GraphLoader::ResetDevice(device_id, model->GetDieId())); }); | |||||
| ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); | ErrorManager::GetInstance().SetStage(error_message::kModelExecute, error_message::kModelExecute); | ||||
| while (model->RunFlag()) { | while (model->RunFlag()) { | ||||
| // Model hasn't truly started runing before received data | // Model hasn't truly started runing before received data | ||||
| @@ -2886,7 +2884,7 @@ void *DavinciModel::Run(DavinciModel *model) { | |||||
| GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); | GE_IF_BOOL_EXEC(ProfilingManager::Instance().ProfilingModelExecuteOn(), model->SetProfileTime(MODEL_INFER_START)); | ||||
| GE_TIMESTAMP_START(rtModelExecute); | GE_TIMESTAMP_START(rtModelExecute); | ||||
| GELOGI("rtModelExecute start."); | GELOGI("rtModelExecute start."); | ||||
| rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); | |||||
| auto rt_ret = rtModelExecute(model->rt_model_handle_, model->rt_model_stream_, 0); | |||||
| GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; | GE_IF_BOOL_EXEC(rt_ret != RT_ERROR_NONE, rslt_flg = false; | ||||
| (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); | (void)model->ReturnResult(current_data.index, false, false, data_wrapper->GetOutput()); | ||||
| continue); | continue); | ||||
| @@ -59,7 +59,9 @@ namespace ge { | |||||
| // op debug need 2048 bits buffer | // op debug need 2048 bits buffer | ||||
| const size_t kOpDebugMemorySize = 2048UL; | const size_t kOpDebugMemorySize = 2048UL; | ||||
| const size_t kDebugP2pSize = 8UL; | const size_t kDebugP2pSize = 8UL; | ||||
| const size_t kDebugP2pSize = 8UL; | |||||
| const int64_t kInvalidDieId = -1; | |||||
| typedef enum tagModelProcStage { | typedef enum tagModelProcStage { | ||||
| MODEL_LOAD_START = 1, | MODEL_LOAD_START = 1, | ||||
| MODEL_LOAD_END, | MODEL_LOAD_END, | ||||
| @@ -441,13 +443,17 @@ class DavinciModel { | |||||
| /// @return void | /// @return void | ||||
| /// | /// | ||||
| void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } | void SetDeviceId(uint32_t device_id) { device_id_ = device_id; } | ||||
| void SetDieId(int64_t die_id) { die_id_ = die_id; } | |||||
| /// | /// | ||||
| /// @ingroup ge | /// @ingroup ge | ||||
| /// @brief Get device Id | /// @brief Get device Id | ||||
| /// @return device id | /// @return device id | ||||
| /// | /// | ||||
| uint32_t GetDeviceId() const { return device_id_; } | |||||
| uint32_t GetDeviceId() const { | |||||
| return die_id_ == kInvalidDieId ? device_id_ : die_id_; | |||||
| } | |||||
| int64_t GetDieId() const { return die_id_; } | |||||
| bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; } | bool NeedDestroyAicpuKernel() const { return need_destroy_aicpu_kernel_; } | ||||
| @@ -1010,6 +1016,7 @@ class DavinciModel { | |||||
| struct error_message::Context error_context_; | struct error_message::Context error_context_; | ||||
| uint32_t device_id_; | uint32_t device_id_; | ||||
| int64_t die_id_ = kInvalidDieId; | |||||
| mutex flowctrl_op_index_internal_map_mutex_; | mutex flowctrl_op_index_internal_map_mutex_; | ||||
| map<uint32_t, uint32_t> flowctrl_op_index_internal_map_; | map<uint32_t, uint32_t> flowctrl_op_index_internal_map_; | ||||
| @@ -324,7 +324,7 @@ bool ModelManager::IsNeedHybridLoad(ge::GeRootModel &ge_root_model) { | |||||
| /// @return Status run result | /// @return Status run result | ||||
| /// | /// | ||||
| Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::GeRootModel> &ge_root_model, | Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::GeRootModel> &ge_root_model, | ||||
| std::shared_ptr<ModelListener> listener) { | |||||
| std::shared_ptr<ModelListener> listener, uint32_t &device_id, int64_t die_id) { | |||||
| GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "[Check][Param] Param incorrect, listener is null"); | GE_CHK_BOOL_RET_STATUS(listener.get() != nullptr, PARAM_INVALID, "[Check][Param] Param incorrect, listener is null"); | ||||
| if (model_id == INVALID_MODEL_ID) { | if (model_id == INVALID_MODEL_ID) { | ||||
| GenModelId(&model_id); | GenModelId(&model_id); | ||||
| @@ -342,7 +342,8 @@ Status ModelManager::LoadModelOnline(uint32_t &model_id, const shared_ptr<ge::Ge | |||||
| davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | davinci_model->SetProfileTime(MODEL_LOAD_START, (timespec.tv_sec * kTimeSpecNano + | ||||
| timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | timespec.tv_nsec)); // 1000 ^ 3 converts second to nanosecond | ||||
| davinci_model->SetId(model_id); | davinci_model->SetId(model_id); | ||||
| davinci_model->SetDeviceId(GetContext().DeviceId()); | |||||
| davinci_model->SetDeviceId(device_id); | |||||
| davinci_model->SetDieId(die_id); | |||||
| auto root_graph = ge_root_model->GetRootGraph(); | auto root_graph = ge_root_model->GetRootGraph(); | ||||
| GE_CHECK_NOTNULL(root_graph); | GE_CHECK_NOTNULL(root_graph); | ||||
| @@ -71,7 +71,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY ModelManager { | |||||
| /// @author @ | /// @author @ | ||||
| /// | /// | ||||
| ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | ge::Status LoadModelOnline(uint32_t &model_id, const std::shared_ptr<ge::GeRootModel> &ge_root_model, | ||||
| std::shared_ptr<ModelListener> listener); | |||||
| std::shared_ptr<ModelListener> listener,uint32_t &device_id, int64_t die_id); | |||||
| ge::Status DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, | ge::Status DoLoadHybridModelOnline(uint32_t model_id, const string &model_name, | ||||
| const shared_ptr<ge::GeRootModel> &ge_root_model, | const shared_ptr<ge::GeRootModel> &ge_root_model, | ||||
| @@ -98,6 +98,7 @@ | |||||
| #include "graph/passes/hccl_continuous_memcpy_pass.h" | #include "graph/passes/hccl_continuous_memcpy_pass.h" | ||||
| #include "graph/passes/parallel_group_pass.h" | #include "graph/passes/parallel_group_pass.h" | ||||
| #include "graph/passes/buffer_pool_memory_pass.h" | #include "graph/passes/buffer_pool_memory_pass.h" | ||||
| #include "graph/passes/mds_pass.h" | |||||
| #include "graph/build/label_allocator.h" | #include "graph/build/label_allocator.h" | ||||
| #include "graph/utils/tensor_adapter.h" | #include "graph/utils/tensor_adapter.h" | ||||
| #include "inc/pass_manager.h" | #include "inc/pass_manager.h" | ||||
| @@ -110,6 +111,7 @@ | |||||
| #include "external/graph/types.h" | #include "external/graph/types.h" | ||||
| #include "common/util/error_manager/error_manager.h" | #include "common/util/error_manager/error_manager.h" | ||||
| #include "common/profiling/profiling_manager.h" | #include "common/profiling/profiling_manager.h" | ||||
| #include "graph/debug/ge_attr_define.h" | |||||
| namespace { | namespace { | ||||
| const char *const kSummary = "Summary"; | const char *const kSummary = "Summary"; | ||||
| @@ -1087,7 +1089,6 @@ Status GraphManager::LoadGraph(const GeRootModelPtr &ge_root_model, const GraphN | |||||
| if (!options_.run_graph_flag) { | if (!options_.run_graph_flag) { | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); | ErrorManager::GetInstance().SetStage(error_message::kModelLoad, error_message::kModelLoad); | ||||
| GE_CHECK_NOTNULL(executor_); | GE_CHECK_NOTNULL(executor_); | ||||
| return executor_->LoadGraph(ge_root_model, graph_node); | return executor_->LoadGraph(ge_root_model, graph_node); | ||||
| @@ -2816,9 +2817,40 @@ const map<std::string, std::string> *GraphManager::GetGraphOptions(uint32_t grap | |||||
| } | } | ||||
| void GraphManager::SetOptionsRunGraphFlag(bool run_graph_flag) { options_.run_graph_flag = run_graph_flag; } | void GraphManager::SetOptionsRunGraphFlag(bool run_graph_flag) { options_.run_graph_flag = run_graph_flag; } | ||||
| Status GraphManager::SetNodeCutInfo(ComputeGraphPtr &compute_graph) { | |||||
| auto instance_ptr = ge::GELib::GetInstance(); | |||||
| if (instance_ptr == nullptr || !instance_ptr->InitFlag()) { | |||||
| REPORT_INNER_ERROR("E19999", "GeLib is not init before, check invalid"); | |||||
| GELOGE(GE_CLI_GE_NOT_INITIALIZED, "[Check][Param] GE is not initialized"); | |||||
| return FAILED; | |||||
| } | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto kernel_lib_name = node->GetOpDesc()->GetOpKernelLibName(); | |||||
| OpsKernelInfoStorePtr kernel_info = instance_ptr->OpsKernelManagerObj().GetOpsKernelInfoStore(kernel_lib_name); | |||||
| if (kernel_info == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Find ops kernel by name:%s failed", | |||||
| kernel_lib_name.c_str()); | |||||
| GELOGE(FAILED, "[Get][OpsKernelInfoStore] by name:%s failed", kernel_lib_name.c_str()); | |||||
| return FAILED; | |||||
| } | |||||
| GE_CHK_STATUS_RET(kernel_info->SetCutSupportedInfo(node)); | |||||
| } | |||||
| } | |||||
| Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | Status GraphManager::OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | ||||
| uint64_t session_id) { | uint64_t session_id) { | ||||
| GE_TIMESTAMP_START(MDS); | |||||
| // Set the cut support information based on the engine of the node | |||||
| EnginePlacer engine_placer; | |||||
| engine_placer.SetComputeGraph(compute_graph); | |||||
| GE_CHK_STATUS_RET(engine_placer.Run()); | |||||
| GE_CHK_STATUS_RET(SetNodeCutInfo(compute_graph)); | |||||
| // mds pass | |||||
| PassManager graph_pass; | |||||
| GE_CHK_STATUS_RET(graph_pass.AddPass("OptimizeSubgraph::MDS", new (std::nothrow) ModelDeploySchedulerPass)) | |||||
| GE_CHK_STATUS_RET(graph_pass.Run(compute_graph)); | |||||
| GE_TIMESTAMP_EVENT_END(MDS, "OptimizeSubgraph::MDS"); | |||||
| // graph partition | // graph partition | ||||
| // Stage partition, only for root graph | // Stage partition, only for root graph | ||||
| GE_TIMESTAMP_START(StagePartition); | GE_TIMESTAMP_START(StagePartition); | ||||
| @@ -242,7 +242,7 @@ class GraphManager { | |||||
| uint64_t session_id = INVALID_SESSION_ID); | uint64_t session_id = INVALID_SESSION_ID); | ||||
| Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); | Status OptimizeSubgraph(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, uint64_t session_id); | ||||
| Status SetNodeCutInfo (ComputeGraphPtr &compute_graph); | |||||
| Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | Status Build(const GraphNodePtr &graph_node, ComputeGraphPtr &compute_graph, | ||||
| GeRootModelPtr &ge_root_model, uint64_t session_id); | GeRootModelPtr &ge_root_model, uint64_t session_id); | ||||
| @@ -23,12 +23,16 @@ Status MemoryAllocator::Initialize(uint32_t device_id) { | |||||
| GELOGI("MemoryAllocator::Initialize"); | GELOGI("MemoryAllocator::Initialize"); | ||||
| // when redo Initialize free memory | // when redo Initialize free memory | ||||
| for (auto &it : memory_base_map_) { | |||||
| if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { | |||||
| GELOGW("Initialize: FreeMemory failed"); | |||||
| for (auto &it_map : deviceid_2_memory_bases_map_) { | |||||
| for (auto &it : it_map.second) { | |||||
| if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { | |||||
| GELOGW("Initialize: FreeMemory failed"); | |||||
| } | |||||
| } | } | ||||
| it_map.second.clear(); | |||||
| } | } | ||||
| memory_base_map_.clear(); | |||||
| deviceid_2_memory_bases_map_.clear(); | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -36,12 +40,16 @@ void MemoryAllocator::Finalize(uint32_t device_id) { | |||||
| GELOGI("MemoryAllocator::Finalize"); | GELOGI("MemoryAllocator::Finalize"); | ||||
| // free memory | // free memory | ||||
| for (auto &it : memory_base_map_) { | |||||
| if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { | |||||
| GELOGW("Finalize: FreeMemory failed"); | |||||
| for (auto &it_map : deviceid_2_memory_bases_map_) { | |||||
| for (auto &it : it_map.second) { | |||||
| if (FreeMemory(it.second.memory_addr_, device_id) != ge::SUCCESS) { | |||||
| GELOGW("Finalize: FreeMemory failed"); | |||||
| } | |||||
| } | } | ||||
| it_map.second.clear(); | |||||
| } | } | ||||
| memory_base_map_.clear(); | |||||
| deviceid_2_memory_bases_map_.clear(); | |||||
| } | } | ||||
| uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id) const { | uint8_t *MemoryAllocator::MallocMemory(const string &purpose, size_t memory_size, uint32_t device_id) const { | ||||
| @@ -75,12 +83,16 @@ Status MemoryAllocator::FreeMemory(uint8_t *memory_addr, uint32_t device_id) con | |||||
| uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, | uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memory_key, size_t memory_size, | ||||
| uint32_t device_id) { | uint32_t device_id) { | ||||
| auto it = memory_base_map_.find(memory_key); | |||||
| if (it != memory_base_map_.end()) { | |||||
| it->second.memory_used_num_++; | |||||
| return it->second.memory_addr_; | |||||
| map<string, MemoryInfo> memory_base_map; | |||||
| auto it_map = deviceid_2_memory_bases_map_.find(device_id); | |||||
| if (it_map != deviceid_2_memory_bases_map_.end()) { | |||||
| memory_base_map = it_map->second; | |||||
| auto it = it_map->second.find(memory_key); | |||||
| if (it != it_map->second.end()) { | |||||
| it->second.memory_used_num_++; | |||||
| return it->second.memory_addr_; | |||||
| } | |||||
| } | } | ||||
| uint8_t *memory_addr = MallocMemory(purpose, memory_size, device_id); | uint8_t *memory_addr = MallocMemory(purpose, memory_size, device_id); | ||||
| if (memory_addr == nullptr) { | if (memory_addr == nullptr) { | ||||
| @@ -91,16 +103,27 @@ uint8_t *MemoryAllocator::MallocMemory(const string &purpose, const string &memo | |||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| MemoryInfo memory_info(memory_addr, memory_size); | |||||
| MemoryInfo memory_info(memory_addr, memory_size, device_id); | |||||
| memory_info.memory_used_num_++; | memory_info.memory_used_num_++; | ||||
| memory_base_map_[memory_key] = memory_info; | |||||
| memory_base_map[memory_key] = memory_info; | |||||
| deviceid_2_memory_bases_map_[device_id] = memory_base_map; | |||||
| mem_malloced_ = true; | mem_malloced_ = true; | ||||
| return memory_addr; | return memory_addr; | ||||
| } | } | ||||
| Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) { | Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) { | ||||
| auto it = memory_base_map_.find(memory_key); | |||||
| if (it == memory_base_map_.end()) { | |||||
| auto it_map = deviceid_2_memory_bases_map_.find(device_id); | |||||
| if (it_map == deviceid_2_memory_bases_map_.end()){ | |||||
| if (mem_malloced_) { | |||||
| GELOGW( | |||||
| "MemoryAllocator::FreeMemory failed," | |||||
| " memory_key[%s] was not exist, device_id = %u.", | |||||
| memory_key.c_str(), device_id); | |||||
| } | |||||
| return ge::INTERNAL_ERROR; | |||||
| } | |||||
| auto it = it_map->second.find(memory_key); | |||||
| if (it == it_map->second.end()) { | |||||
| if (mem_malloced_) { | if (mem_malloced_) { | ||||
| GELOGW( | GELOGW( | ||||
| "MemoryAllocator::FreeMemory failed," | "MemoryAllocator::FreeMemory failed," | ||||
| @@ -109,7 +132,6 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) | |||||
| } | } | ||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| } | } | ||||
| if (it->second.memory_used_num_ > 1) { | if (it->second.memory_used_num_ > 1) { | ||||
| GELOGW("MemoryAllocator::FreeMemory memory_key[%s] should not be released, reference count %d", memory_key.c_str(), | GELOGW("MemoryAllocator::FreeMemory memory_key[%s] should not be released, reference count %d", memory_key.c_str(), | ||||
| it->second.memory_used_num_); | it->second.memory_used_num_); | ||||
| @@ -129,20 +151,28 @@ Status MemoryAllocator::FreeMemory(const string &memory_key, uint32_t device_id) | |||||
| GELOGI("MemoryAllocator::FreeMemory device_id = %u", device_id); | GELOGI("MemoryAllocator::FreeMemory device_id = %u", device_id); | ||||
| memory_base_map_.erase(it); | |||||
| it_map->second.erase(it); | |||||
| return ge::SUCCESS; | return ge::SUCCESS; | ||||
| } | } | ||||
| uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t device_id) { | uint8_t *MemoryAllocator::GetMemoryAddr(const string &memory_key, uint32_t device_id) { | ||||
| auto it = memory_base_map_.find(memory_key); | |||||
| if (it == memory_base_map_.end()) { | |||||
| auto it_map = deviceid_2_memory_bases_map_.find(device_id); | |||||
| if (it_map == deviceid_2_memory_bases_map_.end()) { | |||||
| GELOGW( | |||||
| "MemoryAllocator::GetMemoryAddr failed," | |||||
| " memory_key[%s] was not exist, device_id = %u.", | |||||
| memory_key.c_str(), device_id); | |||||
| return nullptr; | |||||
| } | |||||
| auto it = it_map->second.find(memory_key); | |||||
| if (it == it_map->second.end()) { | |||||
| GELOGW( | GELOGW( | ||||
| "MemoryAllocator::GetMemoryAddr failed," | "MemoryAllocator::GetMemoryAddr failed," | ||||
| " memory_key[%s] was not exist, device_id = %u.", | " memory_key[%s] was not exist, device_id = %u.", | ||||
| memory_key.c_str(), device_id); | memory_key.c_str(), device_id); | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| return it->second.memory_addr_; | return it->second.memory_addr_; | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -32,10 +32,13 @@ | |||||
| namespace ge { | namespace ge { | ||||
| class MemoryInfo { | class MemoryInfo { | ||||
| public: | public: | ||||
| MemoryInfo() : memory_addr_(nullptr), memory_size_(0), memory_used_num_(0) {} | |||||
| MemoryInfo() : memory_addr_(nullptr), memory_size_(0), memory_used_num_(0), device_id_(0) {} | |||||
| MemoryInfo(uint8_t *memory_addr, size_t memory_size) | MemoryInfo(uint8_t *memory_addr, size_t memory_size) | ||||
| : memory_addr_(memory_addr), memory_size_(memory_size), memory_used_num_(0) {} | |||||
| : memory_addr_(memory_addr), memory_size_(memory_size), memory_used_num_(0), device_id_(0) {} | |||||
| MemoryInfo(uint8_t *memory_addr, size_t memory_size, uint32_t device_id) | |||||
| : memory_addr_(memory_addr), memory_size_(memory_size), device_id_(device_id), memory_used_num_(0) {} | |||||
| MemoryInfo &operator=(const MemoryInfo &op) { | MemoryInfo &operator=(const MemoryInfo &op) { | ||||
| if (&op == this) { | if (&op == this) { | ||||
| @@ -44,7 +47,7 @@ class MemoryInfo { | |||||
| this->memory_addr_ = op.memory_addr_; | this->memory_addr_ = op.memory_addr_; | ||||
| this->memory_size_ = op.memory_size_; | this->memory_size_ = op.memory_size_; | ||||
| this->memory_used_num_ = op.memory_used_num_; | |||||
| this->device_id_ = op.device_id_; | |||||
| return *this; | return *this; | ||||
| } | } | ||||
| @@ -52,12 +55,14 @@ class MemoryInfo { | |||||
| this->memory_addr_ = op.memory_addr_; | this->memory_addr_ = op.memory_addr_; | ||||
| this->memory_size_ = op.memory_size_; | this->memory_size_ = op.memory_size_; | ||||
| this->memory_used_num_ = op.memory_used_num_; | this->memory_used_num_ = op.memory_used_num_; | ||||
| this->device_id_ = op.device_id_; | |||||
| } | } | ||||
| virtual ~MemoryInfo() = default; | virtual ~MemoryInfo() = default; | ||||
| uint8_t *memory_addr_; | uint8_t *memory_addr_; | ||||
| uint64_t memory_size_; | uint64_t memory_size_; | ||||
| int32_t memory_used_num_; | int32_t memory_used_num_; | ||||
| uint32_t device_id_; | |||||
| }; | }; | ||||
| class MemoryAllocator { | class MemoryAllocator { | ||||
| @@ -133,7 +138,7 @@ class MemoryAllocator { | |||||
| private: | private: | ||||
| rtMemType_t memory_type_; | rtMemType_t memory_type_; | ||||
| bool mem_malloced_; | bool mem_malloced_; | ||||
| map<string, MemoryInfo> memory_base_map_; | |||||
| map<uint32_t, map<string, MemoryInfo>> deviceid_2_memory_bases_map_; | |||||
| }; | }; | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -348,7 +348,7 @@ ge::Status VarManager::Init(const uint32_t &version, const uint64_t &session_id, | |||||
| device_id_ = device_id; | device_id_ = device_id; | ||||
| session_id_ = session_id; | session_id_ = session_id; | ||||
| job_id_ = job_id; | job_id_ = job_id; | ||||
| var_resource_ = std::unique_ptr<VarResource>(new (std::nothrow) VarResource(session_id_)); | |||||
| var_resource_ = std::unique_ptr<VarResource>(new(std::nothrow) VarResource(session_id_)); | |||||
| if (var_resource_ == nullptr) { | if (var_resource_ == nullptr) { | ||||
| GELOGW("VarManager init failed session id = %lu.", session_id); | GELOGW("VarManager init failed session id = %lu.", session_id); | ||||
| return ge::INTERNAL_ERROR; | return ge::INTERNAL_ERROR; | ||||
| @@ -637,7 +637,7 @@ rtMemType_t VarManager::GetVarMemType(const int64_t &offset) { | |||||
| return var_resource_->GetVarMemType(offset); | return var_resource_->GetVarMemType(offset); | ||||
| } | } | ||||
| ge::Status VarManager::MallocVarMemory(size_t memory_size) { | |||||
| ge::Status VarManager::MallocVarMemory(size_t memory_size, uint32_t device_id) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| uint8_t *var_mem_base = nullptr; | uint8_t *var_mem_base = nullptr; | ||||
| string memory_key = std::to_string(session_id_); | string memory_key = std::to_string(session_id_); | ||||
| @@ -649,7 +649,7 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) { | |||||
| var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; | var_memory_size = (var_memory_size + kSessionMemAlignSize - 1) / kSessionMemAlignSize * kSessionMemAlignSize; | ||||
| const string purpose("variables and constant op memory in training network."); | const string purpose("variables and constant op memory in training network."); | ||||
| var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size); | |||||
| var_mem_base = MemManager::Instance().MemInstance(RT_MEMORY_HBM).MallocMemory(purpose, memory_key, var_memory_size, device_id); | |||||
| if (var_mem_base == nullptr) { | if (var_mem_base == nullptr) { | ||||
| GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s", | GELOGE(ge::INTERNAL_ERROR, "[Malloc][VarMemory] failed, size:%zu, session_id:%s", | ||||
| var_memory_size, memory_key.c_str()); | var_memory_size, memory_key.c_str()); | ||||
| @@ -658,22 +658,22 @@ ge::Status VarManager::MallocVarMemory(size_t memory_size) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type) { | |||||
| uint8_t *VarManager::GetVarMemoryBase(rtMemType_t memory_type, uint32_t device_id) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| if (memory_type == RT_MEMORY_RDMA_HBM) { | if (memory_type == RT_MEMORY_RDMA_HBM) { | ||||
| return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr(); | return MemManager::Instance().RdmaPoolInstance(RT_MEMORY_HBM).GetRdmaBaseAddr(); | ||||
| } | } | ||||
| string memory_key = std::to_string(session_id_); | string memory_key = std::to_string(session_id_); | ||||
| return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key); | |||||
| return MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(memory_key, device_id); | |||||
| } | } | ||||
| uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type) { | |||||
| uint8_t *VarManager::GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type, uint32_t device_id) { | |||||
| std::lock_guard<std::recursive_mutex> lock(mutex_); | std::lock_guard<std::recursive_mutex> lock(mutex_); | ||||
| if (memory_type == RT_MEMORY_RDMA_HBM) { | if (memory_type == RT_MEMORY_RDMA_HBM) { | ||||
| return logic_addr; | return logic_addr; | ||||
| } | } | ||||
| string mem_key = std::to_string(session_id_); | string mem_key = std::to_string(session_id_); | ||||
| uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key); | |||||
| uint8_t *mem_base = MemManager::Instance().MemInstance(memory_type).GetMemoryAddr(mem_key, device_id); | |||||
| if (mem_base == nullptr) { | if (mem_base == nullptr) { | ||||
| return nullptr; | return nullptr; | ||||
| } | } | ||||
| @@ -231,7 +231,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ge::Status RenewCurVarDesc(const std::string &var_name, ge::OpDescPtr op_desc); | ||||
| ge::Status MallocVarMemory(size_t memory_size = kMemoryVarManagerMallocSize); | |||||
| ge::Status MallocVarMemory(size_t memory_size = kMemoryVarManagerMallocSize, uint32_t device_id = 0); | |||||
| ge::Status FreeVarMemory(); | ge::Status FreeVarMemory(); | ||||
| @@ -277,9 +277,9 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| rtMemType_t GetVarMemType(const int64_t &offset); | rtMemType_t GetVarMemType(const int64_t &offset); | ||||
| uint8_t *GetVarMemoryBase(rtMemType_t memory_type); | |||||
| uint8_t *GetVarMemoryBase(rtMemType_t memory_type, uint32_t device_id = 0); | |||||
| uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type); | |||||
| uint8_t *GetVarMemoryAddr(uint8_t *logic_addr, rtMemType_t memory_type, uint32_t device_id = 0); | |||||
| Status GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables); | Status GetAllVariables(std::map<std::string, GeTensorDesc> &all_variables); | ||||
| @@ -293,6 +293,7 @@ class FMK_FUNC_HOST_VISIBILITY FMK_FUNC_DEV_VISIBILITY VarManager { | |||||
| size_t var_mem_logic_base_; | size_t var_mem_logic_base_; | ||||
| size_t use_max_mem_size_; | size_t use_max_mem_size_; | ||||
| std::unique_ptr<ge::VarResource> var_resource_; | std::unique_ptr<ge::VarResource> var_resource_; | ||||
| // map<uint32_t , std::shared_ptr<ge::VarResource>> var_resource_map_; | |||||
| map<rtMemType_t, MemResource *> mem_resource_map_; | map<rtMemType_t, MemResource *> mem_resource_map_; | ||||
| mutable std::recursive_mutex mutex_; | mutable std::recursive_mutex mutex_; | ||||
| @@ -0,0 +1,142 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "./base_mds_kernel.h" | |||||
| namespace ge { | |||||
| namespace mds_cut_pass { | |||||
| shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node) { | |||||
| if (node == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "Param node is nullptr, check invalid"); | |||||
| GELOGE(FAILED, "[Check][Param] parameter node is nullptr."); | |||||
| return nullptr; | |||||
| } | |||||
| KernelFactory &factory = KernelFactory::Instance(); | |||||
| string type = node->GetType(); | |||||
| if (type == FRAMEWORKOP) { | |||||
| if (!ge::AttrUtils::GetStr(node->GetOpDesc(), ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE, type)) { | |||||
| REPORT_CALL_ERROR("E19999", "Get Attr:%s from op:%s(%s) failed", ATTR_NAME_FRAMEWORK_ORIGINAL_TYPE.c_str(), | |||||
| node->GetName().c_str(), node->GetType().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| } | |||||
| return factory.Create(type); | |||||
| } | |||||
| } // namespace mds_cut_pass | |||||
| shared_ptr<DeploySchedulerKernel> DeploySchedulerKernel::Instance() { | |||||
| static const std::shared_ptr<DeploySchedulerKernel> instance_ptr = | |||||
| shared_ptr<DeploySchedulerKernel>(new (std::nothrow) DeploySchedulerKernel()); | |||||
| return instance_ptr; | |||||
| } | |||||
| Status DeploySchedulerKernel::CutN(const ge::NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| auto src_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (src_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx()); | |||||
| auto src_node = src_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto src_op_desc = src_node->GetOpDesc(); | |||||
| auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | |||||
| // peer out shape is cutted already | |||||
| if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutN)) { | |||||
| if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { | |||||
| tensor_desc->SetShape(src_tensor_desc->GetShape()); | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS( | |||||
| MdsUtils::DataGather(src_anchor, in_anchor), "[CutN] failed to gather between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); | |||||
| } | |||||
| } else { | |||||
| if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutN)) { | |||||
| MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), | |||||
| "[CutN] failed to slice between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||||
| in_anchor->GetIdx()); | |||||
| } else { | |||||
| tensor_desc->SetShape(src_tensor_desc->GetShape()); | |||||
| } | |||||
| } | |||||
| // insert hcomallreduce for cutn | |||||
| bool is_grad_compute_node = false; | |||||
| if (ge::AttrUtils::GetBool(src_node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && | |||||
| is_grad_compute_node) { | |||||
| MDS_REQUIRE_SUCCESS( | |||||
| MdsUtils::DataReduce(src_anchor, in_anchor), "[CutN] failed to reduce between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); | |||||
| } | |||||
| } | |||||
| // call infer shape, update output shape | |||||
| MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutN] %s call infershape failed", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status DeploySchedulerKernel::CutH(const ge::NodePtr &node) { | |||||
| GE_CHECK_NOTNULL(node); | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| for (auto &in_anchor : node->GetAllInDataAnchors()) { | |||||
| GE_CHECK_NOTNULL(in_anchor); | |||||
| auto src_anchor = in_anchor->GetPeerOutAnchor(); | |||||
| if (src_anchor == nullptr) { | |||||
| continue; | |||||
| } | |||||
| auto tensor_desc = op_desc->MutableInputDesc(in_anchor->GetIdx()); | |||||
| auto src_node = src_anchor->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto src_op_desc = src_node->GetOpDesc(); | |||||
| auto src_tensor_desc = src_op_desc->MutableOutputDesc(src_anchor->GetIdx()); | |||||
| GE_CHECK_NOTNULL(src_tensor_desc); | |||||
| // peer out shape is cutted already | |||||
| if (MdsUtils::IsDistributedDeploySupported(src_tensor_desc, kCutH)) { | |||||
| if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { | |||||
| MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx()), | |||||
| "[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||||
| in_anchor->GetIdx()); | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS( | |||||
| MdsUtils::DataGather(src_anchor, in_anchor), "[CutH] failed to gather between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), in_anchor->GetIdx()); | |||||
| } | |||||
| } else { | |||||
| if (MdsUtils::IsDistributedDeploySupported(tensor_desc, kCutH)) { | |||||
| MDS_REQUIRE_SUCCESS(MdsUtils::DataSlice(src_anchor, in_anchor, input_node_), | |||||
| "[CutH] failed to slice between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||||
| in_anchor->GetIdx()); | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS(HaloExchangeProcess(node, in_anchor->GetIdx(), true), | |||||
| "[CutH] failed to do overlap between node[%s][%d] to node[%s][%d]", | |||||
| src_op_desc->GetName().c_str(), src_anchor->GetIdx(), op_desc->GetName().c_str(), | |||||
| in_anchor->GetIdx()); | |||||
| } | |||||
| } | |||||
| } | |||||
| // call infer shape, update output shape | |||||
| MDS_REQUIRE_SUCCESS(node->InferShapeAndType(), "[CutH] call infer shape failed", node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,76 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ | |||||
| #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ | |||||
| #include <vector> | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/graph.h" | |||||
| #include "graph/op_desc.h" | |||||
| #include "graph/debug/ge_op_types.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/shape_refiner.h" | |||||
| #include "../pass_utils.h" | |||||
| #include "./mds_utils.h" | |||||
| #include "./mds_kernel_factory.h" | |||||
| namespace ge { | |||||
| class DeploySchedulerKernel { | |||||
| public: | |||||
| static shared_ptr<DeploySchedulerKernel> Instance(); | |||||
| /// CutN imply | |||||
| /// @param [in] node_ptr | |||||
| virtual Status CutN(const ge::NodePtr &node_ptr); | |||||
| /// CutH imply | |||||
| /// @param [in] node_ptr | |||||
| virtual Status CutH(const ge::NodePtr &node_ptr); | |||||
| /// DynamicCutN imply | |||||
| /// @param [in] node_ptr | |||||
| virtual Status DynamicCutN(const ge::NodePtr &node_ptr); | |||||
| /// DynamicCutH imply | |||||
| /// @param [in] node_ptr | |||||
| virtual Status DynamicCutH(const ge::NodePtr &node_ptr); | |||||
| // halo exchange process | |||||
| Status HaloExchangeProcess(NodePtr node, int64_t index, bool local_slice = false); | |||||
| NodePtr GetInputNode() { | |||||
| return input_node_; | |||||
| } | |||||
| DeploySchedulerKernel &operator=(const DeploySchedulerKernel &kernel) = delete; | |||||
| DeploySchedulerKernel(const DeploySchedulerKernel &kernel) = delete; | |||||
| protected: | |||||
| DeploySchedulerKernel() = default; | |||||
| virtual ~DeploySchedulerKernel() = default; | |||||
| private: | |||||
| NodePtr input_node_ = nullptr; | |||||
| }; | |||||
| namespace mds_cut_pass { | |||||
| shared_ptr<DeploySchedulerKernel> GetKernelByType(const NodePtr &node); | |||||
| } | |||||
| } // namespace ge | |||||
| #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_BASE_MDS_KERNEL_H_ | |||||
| @@ -0,0 +1,30 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "conv2d_mds_kernel.h" | |||||
| #include "mds_kernel_factory.h" | |||||
| namespace ge { | |||||
| Status Conv2dDeploySchedulerKernel::CutN(const ge::NodePtr node_ptr) { | |||||
| return DeploySchedulerKernel::CutN(node_ptr); | |||||
| } | |||||
| Status Conv2dDeploySchedulerKernel::CutH(const ge::NodePtr node_ptr) { | |||||
| return DeploySchedulerKernel::CutH(node_ptr); | |||||
| } | |||||
| REGISTER_MDS_KERNEL(CONV2D, Conv2dDeploySchedulerKernel); | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,29 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_ | |||||
| #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_ | |||||
| #include "base_mds_kernel.h" | |||||
| namespace ge { | |||||
| class Conv2dDeploySchedulerKernel : public DeploySchedulerKernel { | |||||
| public: | |||||
| Status CutN(const ge::NodePtr& node_ptr) override; | |||||
| Status CutH(const ge::NodePtr& node_ptr) override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_CONV2D_MDS_KERNEL_H_ | |||||
| @@ -0,0 +1,102 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_ | |||||
| #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_ | |||||
| #include <functional> | |||||
| #include <map> | |||||
| #include <memory> | |||||
| #include <string> | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "framework/common/debug/ge_log.h" | |||||
| #include "graph/graph.h" | |||||
| using std::string; | |||||
| namespace ge { | |||||
| class DeploySchedulerKernel; | |||||
| /// | |||||
| /// @brief DeploySchedulerKernel create factory | |||||
| /// | |||||
| class KernelFactory { | |||||
| public: | |||||
| // KernelCreator(function), type definition | |||||
| using KERNEL_CREATOR_FUN = std::function<std::shared_ptr<DeploySchedulerKernel>(void)>; | |||||
| /// | |||||
| /// Get singleton instance | |||||
| /// | |||||
| static KernelFactory &Instance() { | |||||
| static KernelFactory instance; | |||||
| return instance; | |||||
| } | |||||
| /// | |||||
| /// create DeploySchedulerKernel | |||||
| /// @param [in] op_type operation type | |||||
| /// | |||||
| std::shared_ptr<DeploySchedulerKernel> Create(const std::string &op_type) { | |||||
| std::map<std::string, KERNEL_CREATOR_FUN>::iterator iter = creator_map_.find(op_type); | |||||
| if (iter != creator_map_.end()) { | |||||
| return iter->second(); | |||||
| } | |||||
| return nullptr; | |||||
| } | |||||
| // DeploySchedulerKernel registration function to register different types of DeploySchedulerKernel to the factory | |||||
| class Register { | |||||
| public: | |||||
| /// | |||||
| /// @brief Constructor | |||||
| /// @param [in] type operation type | |||||
| /// @param [in| fun DeploySchedulerKernel function of the operation | |||||
| /// | |||||
| Register(const string &type, const KERNEL_CREATOR_FUN &fun) { | |||||
| KernelFactory::Instance().RegisterCreator(type, fun); | |||||
| } | |||||
| ~Register() = default; | |||||
| }; | |||||
| protected: | |||||
| KernelFactory() = default; | |||||
| ~KernelFactory() = default; | |||||
| // register creator, this function will call in the constructor | |||||
| void RegisterCreator(const string &type, const KERNEL_CREATOR_FUN &fun) { | |||||
| std::map<std::string, KERNEL_CREATOR_FUN>::iterator iter = creator_map_.find(type); | |||||
| if (iter != creator_map_.end()) { | |||||
| GELOGW("KernelFactory::RegisterCreator: %s creator already exist", type.c_str()); | |||||
| return; | |||||
| } | |||||
| creator_map_[type] = fun; | |||||
| } | |||||
| private: | |||||
| std::map<std::string, KERNEL_CREATOR_FUN> creator_map_{}; | |||||
| }; | |||||
| #define REGISTER_MDS_KERNEL(type, clazz) \ | |||||
| std::shared_ptr<DeploySchedulerKernel> Creator_##type##_Kernel() { \ | |||||
| std::shared_ptr<clazz> ptr = nullptr; \ | |||||
| ptr = MakeShared<clazz>(); \ | |||||
| return ptr; \ | |||||
| } \ | |||||
| KernelFactory::Register g_##type##_Kernel_Creator(type, Creator_##type##_Kernel) | |||||
| } // namespace ge | |||||
| #endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_KERNEL_FACTORY_H_ | |||||
| @@ -0,0 +1,476 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "./mds_utils.h" | |||||
| namespace ge { | |||||
| namespace { | |||||
| // for count | |||||
| thread_local int64_t data_slice_count = 0; | |||||
| thread_local int64_t data_gather_count = 0; | |||||
| thread_local int64_t data_reduce_count = 0; | |||||
| const std::string kPrefix = "mds"; | |||||
| } // namespace | |||||
| int64_t MdsUtils::GetNLocation(Format fmt) { | |||||
| int64_t loc = kNInvalidLocation; | |||||
| switch (fmt) { | |||||
| case FORMAT_NCHW: | |||||
| case FORMAT_NHWC: | |||||
| loc = kNLocation0; | |||||
| break; | |||||
| case FORMAT_CHWN: | |||||
| case FORMAT_HWCN: | |||||
| loc = kNLocation3; | |||||
| break; | |||||
| default: | |||||
| GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); | |||||
| } | |||||
| return loc; | |||||
| } | |||||
| int64_t MdsUtils::GetHLocation(Format fmt) { | |||||
| int64_t loc = kHInvalidLocation; | |||||
| switch (fmt) { | |||||
| case FORMAT_HWCN: | |||||
| loc = kHLocation0; | |||||
| break; | |||||
| case FORMAT_NHWC: | |||||
| case FORMAT_CHWN: | |||||
| loc = kHLocation1; | |||||
| break; | |||||
| case FORMAT_NCHW: | |||||
| loc = kHLocation2; | |||||
| default: | |||||
| GELOGE(FAILED, "[MDS]unsupported format:%d %s", fmt, TypeUtils::FormatToSerialString(fmt).c_str()); | |||||
| } | |||||
| return loc; | |||||
| } | |||||
| int64_t MdsUtils::GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type) { | |||||
| Format fmt = ge_tensor_desc->GetFormat(); | |||||
| switch (type) { | |||||
| case kCutN: | |||||
| return GetNLocation(fmt); | |||||
| case kCutH: | |||||
| return GetHLocation(fmt); | |||||
| default:; | |||||
| } | |||||
| GELOGE(FAILED, "[MDS]invalid CutType:%d", type); | |||||
| return kInvalidIndex; | |||||
| } | |||||
| bool MdsUtils::IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type) { | |||||
| if (ge_tensor_desc == nullptr) { | |||||
| REPORT_INNER_ERROR("E19999", "invalid input param: tensor is null!"); | |||||
| GELOGE(FAILED, "[MDS]invalid input param: tensor is null!"); | |||||
| return false; | |||||
| } | |||||
| if (type != kCutN && type != kCutH) { | |||||
| REPORT_INNER_ERROR("E19999", "invalid CutType:%d", type); | |||||
| GELOGE(FAILED, "[MDS]invalid CutType:%d", type); | |||||
| return false; | |||||
| } | |||||
| int64_t cut_index = GetIndexByFormat(ge_tensor_desc, type); | |||||
| if (cut_index == kInvalidIndex) { | |||||
| REPORT_INNER_ERROR("E19999", "invalid index param:%ld", cut_index); | |||||
| GELOGE(FAILED, "[MDS]", "invalid index param:%ld", cut_index); | |||||
| return false; | |||||
| } | |||||
| auto dims = ge_tensor_desc->GetShape().GetDims(); | |||||
| if (cut_index < 0 || cut_index >= dims.size()) { | |||||
| REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, | |||||
| dims.size()); | |||||
| GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of dims size %zu", cut_index, type, | |||||
| dims.size()); | |||||
| return false; | |||||
| } | |||||
| if (dims[cut_index] % kDeployNumber != 0) { | |||||
| GELOGW("[MDS] cut_index %ld for CutType %d with dim %ld can not deploy", cut_index, type, dims[cut_index]); | |||||
| return false; | |||||
| } | |||||
| vector<int64_t> cut_support_info; | |||||
| if (!(AttrUtils::GetListInt(*ge_tensor_desc, ATTR_NAME_CUT_INFO, cut_support_info))) { | |||||
| REPORT_INNER_ERROR("E19999", "call GetlistInt failed"); | |||||
| GELOGE(FAILED, "[MDS]", "call GetlistInt failed"); | |||||
| return false; | |||||
| } | |||||
| if (cut_index < 0 || cut_index >= cut_support_info.size()) { | |||||
| REPORT_INNER_ERROR("E19999", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, | |||||
| type, cut_support_info.size()); | |||||
| GELOGE(FAILED, "[MDS]", "cut_index %ld for CutType %d is out of range of cut_support_info size %zu", cut_index, | |||||
| type, cut_support_info.size()); | |||||
| return false; | |||||
| } | |||||
| if (cut_support_info[cut_index] < kNotSupport || cut_support_info[cut_index] > kAnyCutSupported) { | |||||
| REPORT_INNER_ERROR("E19999", "invalid cut info value:%ld", cut_support_info[cut_index]); | |||||
| GELOGE(FAILED, "[MDS]", "invalid cut info value:%ld", cut_support_info[cut_index]); | |||||
| return false; | |||||
| } | |||||
| return cut_support_info[cut_index] & kSplitCutSupported; | |||||
| } | |||||
| Status MdsUtils::DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, int64_t deploy_number) { | |||||
| GE_CHECK_NOTNULL(ge_tensor_desc); | |||||
| auto index = MdsUtils::GetIndexByFormat(ge_tensor_desc, type); | |||||
| auto dims = ge_tensor_desc->GetShape().GetDims(); | |||||
| REQUIRE(index < dims.size(), "[DistributedDeploy] failed, index %ld should less than %zu", index, dims.size()); | |||||
| auto dim_after_deploy = dims[index] / deploy_number; | |||||
| MDS_REQUIRE_SUCCESS(ge_tensor_desc->MutableShape().SetDim(index, dim_after_deploy), | |||||
| "[DistributedDeploy] update shape failed"); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MdsUtils::SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, const std::string &group_name) { | |||||
| GE_CHECK_NOTNULL(hcom_op); | |||||
| REQUIRE(fission_factor > kDefaultFissionFactor, "fission_factor %ld need be bigger than %ld", fission_factor, | |||||
| kDefaultFissionFactor); | |||||
| REQUIRE(ge::AttrUtils::SetInt(hcom_op, ATTR_NAME_FISSION_FACTOR, fission_factor), | |||||
| "Failed to set attr fission_factor %ld for op:%s(%s)", fission_factor, hcom_op->GetName().c_str(), | |||||
| hcom_op->GetType().c_str()); | |||||
| if (!group_name.empty()) { | |||||
| REQUIRE(ge::AttrUtils::SetStr(hcom_op, HCOM_ATTR_GROUP, group_name), "Failed to set attr group %s for op:%s(%s)", | |||||
| group_name.c_str(), hcom_op->GetName().c_str(), hcom_op->GetType().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| bool MdsUtils::IsMDSNeeded() { | |||||
| std::string device_type; | |||||
| if (ge::GetContext().GetOption(ge::OPTION_DEVICE_TYPE, device_type) && device_type == kDefaultDeviceType) { | |||||
| GELOGI("[MDS]device type is %s, skip mds", device_type.c_str()); | |||||
| return false; | |||||
| } | |||||
| // TODO: Parse the configuration file of the system to get the sys_config_exe_unit | |||||
| std::string sys_config_exe_unit = "DIE"; | |||||
| return device_type != sys_config_exe_unit; | |||||
| } | |||||
| Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str()); | |||||
| // build deploy info | |||||
| vector<GeAttrValue::NAMED_ATTRS> deploy_info; | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| for (int64_t j = 0; j < kDeployNumber; j++) { | |||||
| int64_t device_id = j; | |||||
| GeAttrValue::LIST_TENSOR graph_inputs; | |||||
| GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0)); | |||||
| vector<uint8_t> data{static_cast<uint8_t>(device_id)}; | |||||
| graph_input->SetData(data); | |||||
| // For now, only one graph_input | |||||
| graph_inputs.push_back(graph_input); | |||||
| GeAttrValue::NAMED_ATTRS thread_instance; | |||||
| thread_instance.SetName(std::to_string(device_id)); | |||||
| (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id)); | |||||
| // TODO:Change to enumeration from RTS header file | |||||
| (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>("MultiMode")); | |||||
| (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName())); | |||||
| (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(graph_inputs)); | |||||
| deploy_info.emplace_back(thread_instance); | |||||
| GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); | |||||
| } | |||||
| // set deploy info | |||||
| REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), | |||||
| "Set attr failed for graph %s", compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| CutType MdsUtils::TryGetGraphCutType(const ComputeGraphPtr &compute_graph) { | |||||
| bool is_unknown_graph = false; | |||||
| if (GraphUtils::IsUnknownShapeGraph(compute_graph)) { | |||||
| GELOGI("Graph %s is unknown shape graph", compute_graph->GetName().c_str()); | |||||
| is_unknown_graph = true; | |||||
| } | |||||
| CutType selected_cut_type = kNoCut; | |||||
| for (const auto &data : compute_graph->GetInputNodes()) { | |||||
| GELOGI("Get graph input %s %s", data->GetName().c_str(), data->GetType().c_str()); | |||||
| auto data_n_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutN); | |||||
| auto data_n_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_n_index); | |||||
| auto data_h_index = MdsUtils::GetIndexByFormat(data->GetOpDesc()->MutableOutputDesc(0), kCutH); | |||||
| auto data_h_dim = data->GetOpDesc()->GetOutputDesc(0).GetShape().GetDim(data_h_index); | |||||
| if (data_n_dim == -1 && data_h_dim == -1) { | |||||
| selected_cut_type = kDynamicCutAll; | |||||
| break; | |||||
| } | |||||
| if (data_n_dim % kDeployNumber == 0) { | |||||
| is_unknown_graph ? selected_cut_type = kDynamicCutN : selected_cut_type = kCutN; | |||||
| break; | |||||
| } | |||||
| if (data_h_dim % kDeployNumber == 0) { | |||||
| is_unknown_graph ? selected_cut_type = kDynamicCutH : selected_cut_type = kCutH; | |||||
| } | |||||
| } | |||||
| return selected_cut_type; | |||||
| } | |||||
| Status MdsUtils::SetDeployInfo(const ComputeGraphPtr &compute_graph, | |||||
| const std::multimap<DeviceId, GraphInputs> &deploys, const std::string &device_type) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| GELOGD("[MDS]%s SetDeployInfo start", compute_graph->GetName().c_str()); | |||||
| // build deploy info | |||||
| vector<GeAttrValue::NAMED_ATTRS> deploy_info; | |||||
| for (const auto &pair : deploys) { | |||||
| int64_t device_id = pair.first; | |||||
| GeAttrValue::NAMED_ATTRS thread_instance; | |||||
| thread_instance.SetName(std::to_string(device_id)); | |||||
| (void)thread_instance.SetAttr(kAttrNeedReturnResult, | |||||
| GeAttrValue::CreateFrom<GeAttrValue::BOOL>(deploy_info.empty() ? true : false)); | |||||
| (void)thread_instance.SetAttr(kAttrDeviceId, GeAttrValue::CreateFrom<GeAttrValue::INT>(device_id)); | |||||
| (void)thread_instance.SetAttr(kAttrDeviceType, GeAttrValue::CreateFrom<GeAttrValue::STR>(device_type)); | |||||
| (void)thread_instance.SetAttr(kAttrGraphName, GeAttrValue::CreateFrom<GeAttrValue::STR>(compute_graph->GetName())); | |||||
| (void)thread_instance.SetAttr(kAttrGraphInputs, GeAttrValue::CreateFrom<GeAttrValue::LIST_TENSOR>(pair.second)); | |||||
| deploy_info.emplace_back(thread_instance); | |||||
| GELOGD("[MDS]%s SetDeployInfo on device id: %d", compute_graph->GetName().c_str(), device_id); | |||||
| } | |||||
| // set deploy info | |||||
| REQUIRE(ge::AttrUtils::SetListNamedAttrs(*compute_graph, ATTR_NAME_DEPLOY_INFO, deploy_info), | |||||
| "Set attr failed for graph %s", compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MdsUtils::DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) { | |||||
| auto src_node = src->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto dst_node = dst->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| auto src_graph = src_node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(src_graph); | |||||
| std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_gather_count)); | |||||
| auto hcom_allgather_node = | |||||
| AddDynamicInputOutputNode(src_graph, HCOMALLGATHER, HCOMALLGATHER + node_name_suffix, 1, 1); | |||||
| GE_CHECK_NOTNULL(hcom_allgather_node); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, hcom_allgather_node), | |||||
| "[DataGather] failed between %s and %s", src_node->GetName().c_str(), | |||||
| dst_node->GetName().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(hcom_allgather_node->GetOpDesc(), kDeployNumber, kDefaultGroup), | |||||
| "[DataGather]set attr for node for %s(%s) failed", hcom_allgather_node->GetName().c_str(), | |||||
| hcom_allgather_node->GetType().c_str()); | |||||
| REQUIRE(ge::AttrUtils::SetInt(hcom_allgather_node->GetOpDesc(), HCOM_ATTR_RANK_SIZE, kDefaultRankSize), | |||||
| "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), | |||||
| hcom_allgather_node->GetName().c_str(), hcom_allgather_node->GetType().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(hcom_allgather_node, false), | |||||
| "[DataGather] %s call infershape failed", hcom_allgather_node->GetName().c_str()); | |||||
| data_gather_count++; | |||||
| return SUCCESS; | |||||
| } | |||||
| // gradients->ApplyMomentum | |||||
| // we want to reduce gradients on different device(die), so graph topo changed to | |||||
| // gradients->hcomallreducemean->ApplyMomentum; Because 'mean' is not currently supported by hcomallreduce, | |||||
| // topo will end up like gradients->hcomallreducesum->div->ApplyMomentum | |||||
| Status MdsUtils::DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst) { | |||||
| auto src_node = src->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto dst_node = dst->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| auto src_graph = src_node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(src_graph); | |||||
| NodePtr all_reduce_node = nullptr; | |||||
| if (NeedInsertHcomAllReduce(src_node, all_reduce_node)) { | |||||
| MDS_REQUIRE_SUCCESS(ConstructReduceNode(src_graph, src, dst, all_reduce_node), | |||||
| "[DataReduce] construct allreduce node for %s failed", all_reduce_node->GetName().c_str()); | |||||
| GE_CHECK_NOTNULL(all_reduce_node); | |||||
| } else { | |||||
| GE_CHECK_NOTNULL(all_reduce_node); | |||||
| MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(all_reduce_node->GetOpDesc(), kDeployNumber), | |||||
| "[DataReduce][Modify] set attr for allreduce node for %s failed", | |||||
| all_reduce_node->GetName().c_str()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| // tensor t with shape like [n,c,h,w], we want get [0:2/n, c, h, w] and [2/n : n, c, h, w] on different | |||||
| // device; To achieve this goal, we use slice nodes. | |||||
| // slice(t, [i * n/2, 0, 0, 0], [n/2, c, h, w]) i=0,1 | |||||
| // slice three input like : t->slice; data(0,1)->mul(n/2)->pack[i*n/2,0,0,0]->slice; const(n,c,h,w)->slice | |||||
| Status MdsUtils::DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node) { | |||||
| auto src_node = src->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(src_node); | |||||
| auto dst_node = dst->GetOwnerNode(); | |||||
| GE_CHECK_NOTNULL(dst_node); | |||||
| auto src_graph = src_node->GetOwnerComputeGraph(); | |||||
| GE_CHECK_NOTNULL(src_graph); | |||||
| if (input_node == nullptr) { | |||||
| std::string input_node_name = std::string(DATA) + "_" + kPrefix + "_" + std::to_string(0); | |||||
| input_node = AddSingleInputOutputNode(src_graph, input_node_name, DATA); | |||||
| AddInputNode(input_node); | |||||
| } | |||||
| GeTensorDesc tensor = src_node->GetOpDesc()->GetOutputDesc(src->GetIdx()); | |||||
| NodePtr slice_node = nullptr; | |||||
| MDS_REQUIRE_SUCCESS(ConstructSliceNode(src_graph, tensor, input_node.get(), slice_node), | |||||
| "[DataSlice] construct slice node for %s failed", src_node->GetName().c_str()); | |||||
| GE_CHECK_NOTNULL(slice_node); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, slice_node), "[DataSlice] failed between %s and %s", | |||||
| src_node->GetName().c_str(), dst_node->GetName().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(slice_node, false), "[DataSlice] %s call infer shape failed", | |||||
| slice_node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status MdsUtils::ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *input_node, | |||||
| NodePtr &slice_node) { | |||||
| vector<int64_t> slice_sizes = tensor.GetShape().GetDims(); | |||||
| // TODO: Express with graph structure | |||||
| slice_sizes[0] /= kDeployNumber; | |||||
| vector<GeTensorPtr> ge_tensors; | |||||
| GeTensorDesc ge_tensor_desc; | |||||
| ge_tensor_desc.SetDataType(DT_INT64); | |||||
| MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors), | |||||
| "[ConstructTensorDescWithData] failed"); | |||||
| GeTensorPtr slice_size_tensor = ge_tensors[0]; | |||||
| auto const_node_slice_size = AddConstNodeToGraph(slice_size_tensor, src_graph); | |||||
| vector<int64_t> slice_offset_other_dim{0}; | |||||
| ge_tensors.clear(); | |||||
| MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_other_dim, ge_tensors, true), | |||||
| "[ConstructTensorDescWithData] failed"); | |||||
| GeTensorPtr slice_offset_tensor = ge_tensors[0]; | |||||
| auto const_node_slice_offset = AddConstNodeToGraph(slice_offset_tensor, src_graph); | |||||
| vector<int64_t> slice_offset_first_dim{slice_sizes[0]}; | |||||
| ge_tensors.clear(); | |||||
| MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_offset_first_dim, ge_tensors, true), | |||||
| "[ConstructTensorDescWithData] failed"); | |||||
| GeTensorPtr slice_offset_first_dim_tensor = ge_tensors[0]; | |||||
| auto const_node_slice_offset_first_dim = AddConstNodeToGraph(slice_offset_first_dim_tensor, src_graph); | |||||
| std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_slice_count)); | |||||
| NodePtr mul_node = AddDynamicInputOutputNode(src_graph, MUL, MUL + node_name_suffix, 2, 1); | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(input_node->GetOutDataAnchor(0), mul_node->GetInDataAnchor(0)), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| MDS_REQUIRE_SUCCESS( | |||||
| GraphUtils::AddEdge(const_node_slice_offset_first_dim->GetOutDataAnchor(0), mul_node->GetInDataAnchor(1)), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(mul_node, false), "[DataSlice] %s call infer shape failed", | |||||
| mul_node->GetName().c_str()); | |||||
| NodePtr pack_node = AddDynamicInputOutputNode(src_graph, PACK, PACK + node_name_suffix, slice_sizes.size(), 1); | |||||
| bool is_first_input = true; | |||||
| for (const auto &in_anchor : pack_node->GetAllInDataAnchors()) { | |||||
| if (is_first_input) { | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(mul_node->GetOutDataAnchor(0), in_anchor), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| is_first_input = false; | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_offset->GetOutDataAnchor(0), in_anchor), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| } | |||||
| } | |||||
| MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(pack_node, false), "[DataSlice] %s call infer shape failed", | |||||
| pack_node->GetName().c_str()); | |||||
| slice_node = AddDynamicInputOutputNode(src_graph, SLICE, SLICE + node_name_suffix, 3, 1); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(pack_node->GetOutDataAnchor(0), slice_node->GetInDataAnchor(1)), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_slice_size->GetOutDataAnchor(0), slice_node->GetInDataAnchor(2)), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| ++data_slice_count; | |||||
| return SUCCESS; | |||||
| } | |||||
| NodePtr MdsUtils::AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, | |||||
| const GeTensorDesc &tensor) { | |||||
| GELOGI("Begin to create op: %s", name.c_str()); | |||||
| OpDescBuilder op_desc_builder(name, type); | |||||
| OpDescPtr op_desc = op_desc_builder.AddInput("x", tensor).AddOutput("y", tensor).Build(); | |||||
| if (op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", name.c_str(), type.c_str()); | |||||
| GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", name.c_str(), type.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| NodePtr node = graph->AddNode(op_desc); | |||||
| if (node == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
| graph->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| return node; | |||||
| } | |||||
| NodePtr MdsUtils::AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const std::string &type, | |||||
| const std::string &node_name, size_t input_num, size_t output_num) { | |||||
| GELOGI("Begin to create op: %s", node_name.c_str()); | |||||
| OpDescBuilder op_desc_builder(node_name, type); | |||||
| OpDescPtr op_desc = op_desc_builder.AddDynamicInput("x", input_num).AddDynamicOutput("y", output_num).Build(); | |||||
| if (op_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Create op_desc:%s(%s) failed", node_name.c_str(), type.c_str()); | |||||
| GELOGE(FAILED, "[Create][OpDesc] failed, name:%s(%s).", node_name.c_str(), type.c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| NodePtr node = graph->AddNode(op_desc); | |||||
| if (node == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Add node:%s(%s) to graph:%s failed", op_desc->GetName().c_str(), | |||||
| op_desc->GetType().c_str(), graph->GetName().c_str()); | |||||
| GELOGE(FAILED, "[Add][Node] %s(%s) to graph:%s failed", op_desc->GetName().c_str(), op_desc->GetType().c_str(), | |||||
| graph->GetName().c_str()); | |||||
| return nullptr; | |||||
| } | |||||
| return node; | |||||
| } | |||||
| NodePtr MdsUtils::AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph) { | |||||
| auto const_desc = OpDescUtils::CreateConstOp(tensor); | |||||
| if (const_desc == nullptr) { | |||||
| REPORT_CALL_ERROR("E19999", "Create Const op failed"); | |||||
| GELOGE(OUT_OF_MEMORY, "[Create][ConstOp] failed"); | |||||
| return nullptr; | |||||
| } | |||||
| if (graph == nullptr) { | |||||
| GELOGW("input param graph is null"); | |||||
| return nullptr; | |||||
| } | |||||
| return graph->AddNodeFront(const_desc); | |||||
| } | |||||
| Status MdsUtils::ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, | |||||
| const InDataAnchorPtr &dst, NodePtr &reduce_node) { | |||||
| std::string node_name_suffix("_" + kPrefix + "_" + std::to_string(data_reduce_count)); | |||||
| reduce_node = AddDynamicInputOutputNode(src_graph, HCOMALLREDUCE, HCOMALLREDUCE + node_name_suffix, 1, 1); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(src, {dst}, reduce_node), | |||||
| "[DataReduce] failed insert %s between %s and %s", reduce_node->GetName().c_str(), | |||||
| src->GetOwnerNode()->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(MdsUtils::SetAttrForHcomNode(reduce_node->GetOpDesc(), kDeployNumber, kDefaultGroup), | |||||
| "[DataReduce][Create] set attr for allreduce node for %s failed", reduce_node->GetName().c_str()); | |||||
| REQUIRE(ge::AttrUtils::SetStr(reduce_node->GetOpDesc(), HCOM_ATTR_REDUCE_TYPE, kDefaultReduction), | |||||
| "Failed to set attr reduction type %s for op:%s(%s)", kDefaultReduction.c_str(), | |||||
| reduce_node->GetName().c_str(), reduce_node->GetType().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(reduce_node, false), "[DataReduce] %s call infershape failed", | |||||
| reduce_node->GetName().c_str()); | |||||
| auto div_node = AddDynamicInputOutputNode(src_graph, REALDIV, REALDIV + node_name_suffix, 2, 1); | |||||
| vector<int64_t> slice_sizes{kDeployNumber}; | |||||
| vector<GeTensorPtr> ge_tensors; | |||||
| GeTensorDesc ge_tensor_desc; | |||||
| ge_tensor_desc.SetDataType(DT_INT64); | |||||
| MDS_REQUIRE_SUCCESS(PassUtils::ConstructTensorDescWithData(ge_tensor_desc, slice_sizes, ge_tensors), | |||||
| "[ConstructReduceNode] failed"); | |||||
| REQUIRE(!ge_tensors.empty(), "[ConstructReduceNode] failed"); | |||||
| auto const_node_div_input = AddConstNodeToGraph(ge_tensors[0], src_graph); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::AddEdge(const_node_div_input->GetOutDataAnchor(0), div_node->GetInDataAnchor(1)), | |||||
| "[ConstructSliceNode] add edge failed"); | |||||
| MDS_REQUIRE_SUCCESS(GraphUtils::InsertNodeAfter(reduce_node->GetOutDataAnchor(0), {dst}, div_node), | |||||
| "[DataReduce] failed insert %s between %s and %s", div_node->GetName().c_str(), | |||||
| reduce_node->GetName().c_str(), dst->GetOwnerNode()->GetName().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(ShapeRefiner::InferShapeAndType(div_node, false), "[DataReduce] %s call infershape failed", | |||||
| div_node->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| bool MdsUtils::NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node) { | |||||
| // TODO: recognize that the graph is originally a multi-p model, that is, there is already an allreduce node, | |||||
| // so there is no need to insert i | |||||
| return true; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,130 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ | |||||
| #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ | |||||
| #include "graph/ge_context.h" | |||||
| #include "common/op/ge_op_utils.h" | |||||
| #include "graph/utils/type_utils.h" | |||||
| #include "graph/utils/graph_utils.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| #include "common/ge/ge_util.h" | |||||
| #include "graph/compute_graph.h" | |||||
| #include "graph/shape_refiner.h" | |||||
| #include "graph/debug/ge_op_types.h" | |||||
| #include "framework/common/types.h" | |||||
| #include "graph/utils/op_desc_utils.h" | |||||
| #include "../pass_utils.h" | |||||
| #define REQUIRE(cond, ...) \ | |||||
| do { \ | |||||
| if (!(cond)) { \ | |||||
| REPORT_INNER_ERROR("E19999", __VA_ARGS__); \ | |||||
| GELOGE(FAILED, "[MDS]" __VA_ARGS__); \ | |||||
| return FAILED; \ | |||||
| } \ | |||||
| } while (0) | |||||
| #define MDS_REQUIRE_NOT_NULL(cond, ...) REQUIRE(((cond) != nullptr), __VA_ARGS__) | |||||
| #define MDS_REQUIRE_SUCCESS(cond, ...) REQUIRE(((cond) == SUCCESS), __VA_ARGS__) | |||||
| #define MDS_REQUIRE_GRAPH_SUCCESS(cond, ...) REQUIRE(((cond) == GRAPH_SUCCESS), __VA_ARGS__) | |||||
| namespace ge { | |||||
| namespace { | |||||
| // Invalid location index | |||||
| const int64_t kInvalidIndex = -1; | |||||
| enum NCutIndex { kNLocation0 = 0, kNLocation1, kNLocation2, kNLocation3, kNInvalidLocation = -1 }; | |||||
| enum HCutIndex { kHLocation0 = 0, kHLocation1, kHLocation2, kHLocation3, kHInvalidLocation = -1 }; | |||||
| // NCHW dim N index | |||||
| const int32_t kNchwDimIdxN = 0; | |||||
| // NCHW dim C index | |||||
| const int32_t kNchwDimIdxC = 1; | |||||
| // NCHW dim H index | |||||
| const int32_t kNchwDimIdxH = 2; | |||||
| // NCHW dim W index | |||||
| const int32_t kNchwDimIdxW = 3; | |||||
| // default die number | |||||
| const uint32_t kDeployNumber = 2; | |||||
| enum CutType { kNoCut = 0, kCutN, kCutH, kDynamicCutN, kDynamicCutH, kDynamicCutAll }; | |||||
| enum TensorCutInfo { kNotSupport = 0, kSplitCutSupported, kAnyCutSupported = 3 }; | |||||
| const int64_t kDefaultFissionFactor = 1; | |||||
| const int64_t kDefaultRankSize = 1; | |||||
| const std::string kDefaultGroup = "hccl_world_group"; | |||||
| const std::string kDefaultReduction = "sum"; | |||||
| const char *const kDefaultDeviceType = "DEFAULT_DEVICE_TYPE"; | |||||
| const char *const kDefaultExecUnit = "DEFAULT_DEVICE_TYPE"; | |||||
| // deploy info | |||||
| const char *const kAttrNeedReturnResult = "_need_return_result"; | |||||
| const char *const kAttrDeviceType = "_device_type"; | |||||
| const char *const kDieDeviceTypeValue = "MultiMode"; | |||||
| const char *const kAttrDeviceId = "_device_id"; | |||||
| const char *const kAttrGraphName = "_graph_name"; | |||||
| const char *const kAttrGraphInputs = "_graph_inputs"; | |||||
| using GraphInputs = vector<GeTensorPtr>; | |||||
| using DeviceId = int64_t; | |||||
| using GraphInputNodes = vector<NodePtr>; | |||||
| } // namespace | |||||
| class MdsUtils { | |||||
| public: | |||||
| // Parse the configuration file and determine whether to enable MDS based on the value of device_type. | |||||
| static bool IsMDSNeeded(); | |||||
| static int64_t GetNLocation(Format fmt); | |||||
| static int64_t GetHLocation(Format fmt); | |||||
| static int64_t GetIndexByFormat(const GeTensorDescPtr &ge_tensor_desc, CutType type); | |||||
| static bool IsDistributedDeploySupported(const GeTensorDescPtr &ge_tensor_desc, CutType type); | |||||
| static Status SetAttrForHcomNode(const OpDescPtr &hcom_op, int64_t fission_factor, | |||||
| const std::string &group_name = ""); | |||||
| /// @param [in] index 切分的轴 | |||||
| /// @param [in] deploy_number 切分的份数 | |||||
| static Status DistributedDeploy(const GeTensorDescPtr &ge_tensor_desc, CutType type, | |||||
| int64_t deploy_number = kDeployNumber); | |||||
| // Sets the information, notifies the number of threads to be started during the | |||||
| // loading phase, the device on which each thread should run, and constructs different input data on each device. | |||||
| static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const NodePtr &input_node); | |||||
| static Status SetDeployInfo(const ComputeGraphPtr &compute_graph, const std::multimap<DeviceId, GraphInputs> &deploys, | |||||
| const std::string &device_type = kDieDeviceTypeValue); | |||||
| // Get cut policy for whole graph | |||||
| static CutType TryGetGraphCutType(const ComputeGraphPtr &compute_graph); | |||||
| static GraphInputNodes GetInputNodes() { | |||||
| return input_nodes_; | |||||
| } | |||||
| static void AddInputNode(const NodePtr &input_node) { | |||||
| input_nodes_.push_back(input_node); | |||||
| } | |||||
| static Status DataGather(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||||
| static Status DataReduce(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst); | |||||
| static Status DataSlice(const OutDataAnchorPtr &src, const InDataAnchorPtr &dst, NodePtr &input_node); | |||||
| private: | |||||
| static GraphInputNodes input_nodes_; | |||||
| static NodePtr AddDynamicInputOutputNode(const ComputeGraphPtr &graph, const string &type, const string &node_name, | |||||
| size_t input_num, size_t output_num); | |||||
| static NodePtr AddSingleInputOutputNode(const ComputeGraphPtr &graph, const string &name, const string &type, | |||||
| const GeTensorDesc &tensor = GeTensorDesc()); | |||||
| static Status ConstructReduceNode(const ComputeGraphPtr &src_graph, const OutDataAnchorPtr &src, | |||||
| const InDataAnchorPtr &dst, NodePtr &reduce_node); | |||||
| static Status ConstructSliceNode(const ComputeGraphPtr &src_graph, const GeTensorDesc &tensor, Node *node, | |||||
| NodePtr &slice_node); | |||||
| static bool NeedInsertHcomAllReduce(const NodePtr &src_node, NodePtr &allreduce_node); | |||||
| static NodePtr AddConstNodeToGraph(GeTensorPtr &tensor, const ComputeGraphPtr &graph); | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_MDS_UTILS_H_ | |||||
| @@ -0,0 +1,41 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "variable_mds_kernel.h" | |||||
| #include "mds_kernel_factory.h" | |||||
| namespace ge { | |||||
| Status VariableDeploySchedulerKernel::CutN(const ge::NodePtr& node_ptr) { | |||||
| GE_CHECK_NOTNULL(node_ptr); | |||||
| if (MdsUtils::IsDistributedDeploySupported(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutN)) { | |||||
| return MdsUtils::DistributedDeploy(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutN); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status VariableDeploySchedulerKernel::CutH(const ge::NodePtr& node_ptr) { | |||||
| GE_CHECK_NOTNULL(node_ptr); | |||||
| if (MdsUtils::IsDistributedDeploySupported(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutH)) { | |||||
| return MdsUtils::DistributedDeploy(node_ptr->GetOpDesc()->MutableOutputDesc(0), kCutH); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| REGISTER_MDS_KERNEL(VARIABLE, VariableDeploySchedulerKernel); | |||||
| } | |||||
| @@ -0,0 +1,28 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_ | |||||
| #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_ | |||||
| #include "base_mds_kernel.h" | |||||
| namespace ge { | |||||
| class VariableDeploySchedulerKernel : public DeploySchedulerKernel { | |||||
| public: | |||||
| Status CutN(const ge::NodePtr& node_ptr) override; | |||||
| Status CutH(const ge::NodePtr& node_ptr) override; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif //MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_KERNELS_VARIABLE_MDS_KERNEL_H_ | |||||
| @@ -0,0 +1,177 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #include "./mds_pass.h" | |||||
| namespace ge { | |||||
| Status ModelDeploySchedulerPass::Run(ComputeGraphPtr graph) { | |||||
| GE_CHECK_NOTNULL(graph); | |||||
| compute_graph_ = graph; | |||||
| if (!MdsUtils::IsMDSNeeded()) { | |||||
| return SUCCESS; | |||||
| } | |||||
| GELOGI("[MDS][%s] start to deploy.", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(SMDPProcess(), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(CutProcess(), "[MDS][CutProcess] failed, graph_name:[%s]", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(SMDPProcess(false), "[MDS][SMDPProcess] failed, graph_name:[%s]", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(SwapProcess(), "[MDS][SwapProcess] failed, graph_name:[%s]", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(PiplineProcess(), "[MDS][PiplineProcess] failed, graph_name:[%s]", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(SetDeployInfo(), "[MDS][SetDeployInfo] failed, graph_name:[%s]", GetGraphName()); | |||||
| GELOGI("[MDS][%s] deploy successfully.", graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::CutProcess() { | |||||
| GE_CHECK_NOTNULL(compute_graph_); | |||||
| if (!compute_graph_->GetAllSubgraphs().empty() || compute_graph_->GetParentGraph() != nullptr) { | |||||
| GELOGW("[MDS][CutProcess] graph with subgraphs is not supported now. graph_name:[%s]", GetGraphName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| auto type = MdsUtils::TryGetGraphCutType(compute_graph_); | |||||
| switch (type) { | |||||
| case kCutN: | |||||
| MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_), "[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||||
| GetGraphName()); | |||||
| break; | |||||
| case kCutH: | |||||
| MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_), "[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||||
| GetGraphName()); | |||||
| break; | |||||
| case kDynamicCutN: | |||||
| MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph_, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||||
| GetGraphName()); | |||||
| break; | |||||
| case kDynamicCutH: | |||||
| MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph_, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||||
| GetGraphName()); | |||||
| break; | |||||
| case kDynamicCutAll: | |||||
| MDS_REQUIRE_SUCCESS(DynamicCutAll(compute_graph_), "[MDS][DynamicCutAll] failed, graph_name:[%s]", | |||||
| GetGraphName()); | |||||
| break; | |||||
| default: | |||||
| GELOGI("[MDS][CutProcess] could not cut, just return. graph_name:[%s]", GetGraphName()); | |||||
| return SUCCESS; | |||||
| } | |||||
| } | |||||
| Status ModelDeploySchedulerPass::CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| // step 0: Cut | |||||
| for (const auto &node : compute_graph->GetDirectNode()) { | |||||
| auto op_kernel = mds_cut_pass::GetKernelByType(node); | |||||
| if (op_kernel == nullptr) { | |||||
| op_kernel = DeploySchedulerKernel::Instance(); | |||||
| } | |||||
| if (is_dynamic) { | |||||
| MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutN(node), "[MDS][DYNAMIC_CUTN] failed, node:[%s]", | |||||
| node->GetName().c_str()); | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS(op_kernel->CutN(node), "[MDS][CUTN] failed, node:[%s]", node->GetName().c_str()); | |||||
| } | |||||
| bool is_grad_compute_node = false; | |||||
| if (ge::AttrUtils::GetBool(node->GetOpDesc(), ATTR_NAME_GRADIENT_NODE, is_grad_compute_node) && | |||||
| is_grad_compute_node) { | |||||
| grad_compute_nodes_.push_back(node); | |||||
| } | |||||
| } | |||||
| // TODO:for single output multi reference insertion allgather, allreduce nodes, do breadth fusion optimization | |||||
| MDS_REQUIRE_SUCCESS(HcomNodeFusionProcess(), "[MDS][CUTN][HcomNodeFusionProcess] failed, compute graph:[%s]", | |||||
| compute_graph->GetName().c_str()); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| for (NodePtr &node : compute_graph->GetDirectNode()) { | |||||
| auto op_kernel = mds_cut_pass::GetKernelByType(node); | |||||
| if (op_kernel == nullptr) { | |||||
| op_kernel = DeploySchedulerKernel::Instance(); | |||||
| } | |||||
| if (is_dynamic) { | |||||
| MDS_REQUIRE_SUCCESS(op_kernel->DynamicCutH(node), "[MDS][DYNAMIC_CUTH] failed, node:[%s]", | |||||
| node->GetName().c_str()); | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS(op_kernel->CutH(node), "[MDS][CUTH] failed, node:[%s]", node->GetName().c_str()); | |||||
| } | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::DynamicCutAll(const ComputeGraphPtr &compute_graph) { | |||||
| std::vector<NodePtr> input_nodes; | |||||
| std::vector<NodePtr> output_nodes; | |||||
| auto compute_graph0 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); | |||||
| auto compute_graph1 = GraphUtils::CloneGraph(compute_graph, "", input_nodes, output_nodes); | |||||
| MDS_REQUIRE_SUCCESS(CutNProcessImply(compute_graph0, true), "[MDS][CutNProcessImply] failed, graph_name:[%s]", | |||||
| compute_graph0->GetName().c_str()); | |||||
| MDS_REQUIRE_SUCCESS(CutHProcessImply(compute_graph1, true), "[MDS][CutHProcessImply] failed, graph_name:[%s]", | |||||
| compute_graph1->GetName().c_str()); | |||||
| // TODO:Create a case node, put the two graphs under the two branches of case | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::SMDPProcess(bool before_cut) { | |||||
| if (before_cut) { | |||||
| MDS_REQUIRE_SUCCESS(SMDPModelState(), "[SMDPProcess][SMDPModelState] failed, graph_name:[%s]", GetGraphName()); | |||||
| MDS_REQUIRE_SUCCESS(SMDPWeight(), "[SMDPProcess][SMDPWeight] failed, graph_name:[%s]", GetGraphName()); | |||||
| } else { | |||||
| MDS_REQUIRE_SUCCESS(SMDPGradient(), "[SMDPProcess][SMDPGradient] failed, graph_name:[%s]", GetGraphName()); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::SetDeployInfo() { | |||||
| vector<GeAttrValue::NAMED_ATTRS> deployInfo; | |||||
| REQUIRE(!ge::AttrUtils::GetListNamedAttrs(compute_graph_, ATTR_NAME_DEPLOY_INFO, deployInfo), | |||||
| "%s already has deployed before!", GetGraphName()); | |||||
| std::multimap<DeviceId, GraphInputs> deploys; | |||||
| for (int64_t j = 0; j < kDeployNumber; j++) { | |||||
| int64_t device_id = j; | |||||
| GraphInputs graph_inputs; | |||||
| // For now, only one input_node in input_nodes | |||||
| for (const auto &input_node : MdsUtils::GetInputNodes()) { | |||||
| GE_CHECK_NOTNULL(input_node); | |||||
| GeTensorPtr graph_input = MakeShared<GeTensor>(input_node->GetOpDesc()->GetOutputDesc(0)); | |||||
| vector<uint8_t> data{static_cast<uint8_t>(device_id)}; | |||||
| graph_input->SetData(data); | |||||
| graph_inputs.push_back(graph_input); | |||||
| } | |||||
| deploys.emplace(j, graph_inputs); | |||||
| } | |||||
| return MdsUtils::SetDeployInfo(compute_graph_, deploys); | |||||
| } | |||||
| Status ModelDeploySchedulerPass::SwapProcess() { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::PiplineProcess() { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::HcomNodeFusionProcess() { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::SMDPModelState() { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::SMDPWeight() { | |||||
| return SUCCESS; | |||||
| } | |||||
| Status ModelDeploySchedulerPass::SMDPGradient() { | |||||
| return SUCCESS; | |||||
| } | |||||
| } // namespace ge | |||||
| @@ -0,0 +1,71 @@ | |||||
| /** | |||||
| * Copyright 2021 Huawei Technologies Co., Ltd | |||||
| * | |||||
| * Licensed under the Apache License, Version 2.0 (the "License"); | |||||
| * you may not use this file except in compliance with the License. | |||||
| * You may obtain a copy of the License at | |||||
| * | |||||
| * http://www.apache.org/licenses/LICENSE-2.0 | |||||
| * | |||||
| * Unless required by applicable law or agreed to in writing, software | |||||
| * distributed under the License is distributed on an "AS IS" BASIS, | |||||
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||||
| * See the License for the specific language governing permissions and | |||||
| * limitations under the License. | |||||
| */ | |||||
| #ifndef MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ | |||||
| #define MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ | |||||
| #include "graph/types.h" | |||||
| #include "ge/ge_api.h" | |||||
| #include "graph/debug/ge_attr_define.h" | |||||
| #include "inc/graph_pass.h" | |||||
| #include "./mds_kernels/base_mds_kernel.h" | |||||
| #include "ge/ge_api_types.h" | |||||
| #include "./mds_kernels/mds_utils.h" | |||||
| namespace ge { | |||||
| class ModelDeploySchedulerPass : public GraphPass { | |||||
| public: | |||||
| Status Run(ge::ComputeGraphPtr graph) override; | |||||
| private: | |||||
| // Part0:Process Func | |||||
| // cut and dynamic cut | |||||
| Status CutProcess(); | |||||
| Status CutNProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic = false); | |||||
| Status CutHProcessImply(const ComputeGraphPtr &compute_graph, bool is_dynamic = false); | |||||
| Status DynamicCutAll(const ComputeGraphPtr &compute_graph); | |||||
| // smdp | |||||
| Status SMDPProcess(bool before_cut = true); | |||||
| Status SMDPModelState(); | |||||
| Status SMDPGradient(); | |||||
| Status SMDPWeight(); | |||||
| // swap | |||||
| Status SwapProcess(); | |||||
| // pipline | |||||
| Status PiplineProcess(); | |||||
| // set delpoyinfo | |||||
| Status SetDeployInfo(); | |||||
| // Part1: Utils Func | |||||
| // std::vector<bool> GetNodeInputsSupportCut(NodePtr node, uint64_t cut_index); | |||||
| // std::vector<bool> GetNodeOutputsSupportCut(NodePtr node, uint64_t cut_index); | |||||
| Status HcomNodeFusionProcess(); | |||||
| Status GetAllModelStateVar(); | |||||
| Status GetAllWeightVar(); | |||||
| std::vector<NodePtr> GetAllGradComputeNodes() { | |||||
| return grad_compute_nodes_; | |||||
| } | |||||
| const char *GetGraphName() const { | |||||
| return compute_graph_->GetName().c_str(); | |||||
| } | |||||
| // members | |||||
| std::vector<NodePtr> model_state_vars_; | |||||
| std::vector<NodePtr> model_weight_vars_; | |||||
| std::vector<NodePtr> grad_compute_nodes_; | |||||
| ComputeGraphPtr compute_graph_ = nullptr; | |||||
| }; | |||||
| } // namespace ge | |||||
| #endif // MAIN_GRAPHENGINE_GE_GRAPH_PASSES_MDS_H_ | |||||
| @@ -28,6 +28,7 @@ | |||||
| namespace ge { | namespace ge { | ||||
| // Option key: graph run mode | // Option key: graph run mode | ||||
| const char *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; | const char *const OPTION_GRAPH_RUN_MODE = "ge.graphRunMode"; | ||||
| const char *const OPTION_DEVICE_TYPE = "ge.deviceType"; | |||||
| // Option key: ome init | // Option key: ome init | ||||
| const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId"; | const char *const OPTION_EXEC_SESSION_ID = "ge.exec.sessionId"; | ||||