// Patch overview: HcclTask no longer goes through the hcom_* callbacks. Distribute() now allocates the HCCL
// workspace with rtMalloc (freed in ~HcclTask), and secondary streams are recorded per model handle and master
// stream id in the static model_stream_mapping_ (locked via model_stream_mapping_mutex_); StreamGuard's
// destructor unbinds and destroys a stream.
// Review fixes folded into the added ("+") lines:
//  - Distribute() returns false when SetSecondaryStream() fails; its result used to be discarded.
//  - SetSecondaryStream() keeps CreateStream()'s bool result in a bool and tests it with !ret in every branch
//    (one branch compared it with SUCCESS while the others used !ret).
//  - master_stream_id is a uint32_t, so it is logged with %u instead of %ld.
// The new-file ranges in the hccl_task.cc hunk headers are shifted for the three lines added in Distribute().
// NOTE(review): the post-image hash on the "index" line was not recomputed, and template arguments / #include
// targets look stripped by extraction; both are left as found - confirm against the real commit.
diff --git a/ge/ge_runtime/task/hccl_task.cc b/ge/ge_runtime/task/hccl_task.cc index 3d5f8504..dfeda94b 100644 --- a/ge/ge_runtime/task/hccl_task.cc +++ b/ge/ge_runtime/task/hccl_task.cc @@ -15,83 +15,56 @@ */ #include "ge_runtime/task/hccl_task.h" +#include #include "ge_runtime/task/task_factory.h" #include "common/opskernel/ops_kernel_info_store.h" #include "common/opskernel/ge_task_info.h" namespace ge { namespace model_runner { +std::map>>> + HcclTask::model_stream_mapping_; +std::mutex HcclTask::model_stream_mapping_mutex_; + HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr &task_info) : TaskRepeater(model_context, task_info), task_info_(task_info), stream_(nullptr), + workspace_mem_(nullptr), rt_model_handle_(nullptr), priority_(0), - slave_stream_list_(), - hcom_bind_model_(nullptr), - hcom_unbind_model_(nullptr), - hcom_distribute_task_(nullptr) { + secondary_stream_list_() { if (task_info_ == nullptr) { GELOGW("task_info_ is null!"); } - hcom_bind_model_ = task_info->hcom_bind_model(); - hcom_unbind_model_ = task_info->hcom_unbind_model(); - priority_ = model_context.priority(); rt_model_handle_ = model_context.rt_model_handle(); auto stream_list = model_context.stream_list(); - if (hcom_bind_model_ != nullptr) { - if (rt_model_handle_list_.insert(rt_model_handle_).second) { - for (auto stream : stream_list) { - (void)hcom_bind_model_(rt_model_handle_, stream); - } - } - } - if (stream_list.size() == 1) { stream_ = stream_list[0]; } else if (stream_list.size() > task_info->stream_id()) { stream_ = stream_list[task_info->stream_id()]; } else { - GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); + GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size()); } } HcclTask::~HcclTask() { - for (size_t i = 0; i < slave_stream_list_.size(); ++i) { - rtError_t rt_ret = rtModelUnbindStream(rt_model_handle_, slave_stream_list_[i]); + if (workspace_mem_ != nullptr) {
+ rtError_t rt_ret = rtFree(workspace_mem_); if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Unbind stream from model failed! Index: %zu", i); - } - } - - for (size_t i = 0; i < slave_stream_list_.size(); ++i) { - rtError_t rt_ret = rtStreamDestroy(slave_stream_list_[i]); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Destroy stream failed! Index: %zu", i); - } - } - - if (hcom_unbind_model_ != nullptr) { - if (rt_model_handle_list_.find(rt_model_handle_) != rt_model_handle_list_.end()) { - (void)hcom_unbind_model_(rt_model_handle_); - (void)rt_model_handle_list_.erase(rt_model_handle_); + GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret); } + workspace_mem_ = nullptr; } } bool HcclTask::Distribute() { - // No ops kernel info store - hcom_distribute_task_ = task_info_->hcom_distribute_task(); - if (hcom_distribute_task_ != nullptr) { - return hcom_distribute_task_(task_info_, stream_); - } - // Ops kernel info store // Get privateDef and opsKernelStorePtr - GELOGI("get custom info in modelTaskDef"); + GELOGI("Get custom info in modelTaskDef"); void *ops_kernel_store = task_info_->ops_kernel_store(); OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast(ops_kernel_store); if (ops_kernel_store == nullptr) { @@ -101,25 +74,18 @@ bool HcclTask::Distribute() { char *private_def = reinterpret_cast(const_cast(task_info_->private_def().data())); auto private_def_len = static_cast(task_info_->private_def().size()); - GELOGI("the first address of the custom info, privateDef=%p", private_def); - - GELOGI("hcclStreamNum =%ld", task_info_->hccl_stream_num()); - for (int64_t i = 0; i < task_info_->hccl_stream_num(); ++i) { - rtStream_t stream = nullptr; - rtError_t rt_ret = rtStreamCreateWithFlags(&stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); - if (rt_ret != RT_ERROR_NONE) { - GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); - return false; - } + GELOGI("The first address of the custom info, privateDef=%p",
private_def); + if (!SetSecondaryStream()) { + GELOGE(RT_FAILED, "Set secondary stream failed."); + return false; + } - rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM); + if (task_info_->workspace_size() > 0) { + rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM); if (rt_ret != RT_ERROR_NONE) { GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); return false; } - - GELOGI("hccl_stream addr is=%p", stream); - slave_stream_list_.push_back(stream); } GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl."); @@ -128,17 +94,22 @@ bool HcclTask::Distribute() { ge_task.type = static_cast(RT_MODEL_TASK_HCCL); ge_task.stream = stream_; + ge_task.kernelHcclInfo = std::vector(1); ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type(); ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr(); ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr(); - ge_task.kernelHcclInfo[0].workSpaceAddr = task_info_->workspace_addr(); + ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_; ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size(); ge_task.kernelHcclInfo[0].count = task_info_->count(); ge_task.kernelHcclInfo[0].dataType = static_cast(task_info_->data_type()); ge_task.kernelHcclInfo[0].opType = static_cast(task_info_->op_type()); ge_task.kernelHcclInfo[0].rootId = task_info_->root_id(); - ge_task.kernelHcclInfo[0].hcclStreamList = slave_stream_list_; + std::vector secondary_stream_list; + std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(), + std::back_inserter(secondary_stream_list), + [](const std::shared_ptr &stream) -> rtStream_t { return stream->GetStream(); }); + ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list; ge_task.privateDef = private_def; ge_task.privateDefLen = private_def_len; @@ -151,10 +122,151 @@ bool HcclTask::Distribute() { return false; } - GELOGI("call function LoadTask end."); + GELOGI("Call function LoadTask end."); return true; } +bool
HcclTask::SetSecondaryStream() { + const uint32_t master_stream_id = task_info_->stream_id(); + const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num(); + bool ret = false; + std::lock_guard lock(model_stream_mapping_mutex_); + if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { + GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %u.", rt_model_handle_, master_stream_id); + ret = CreateStream(hccl_secondary_stream_num, master_stream_id); + if (!ret) { + GELOGE(RT_FAILED, "Create hccl stream failed."); + return false; + } + return true; + } + + std::map>> &master_secondary_stream_map = + model_stream_mapping_.at(rt_model_handle_); + if (auto iter = master_secondary_stream_map.find(master_stream_id); iter != master_secondary_stream_map.end()) { + std::vector> &secondary_stream_vec = iter->second; + auto lock_weak_ptr = [&secondary_stream_vec, this](int64_t index) -> bool { + auto stream = secondary_stream_vec[index].lock(); + if (stream == nullptr) { + rtStream_t new_stream = nullptr; + bool ret = CreateStream(rt_model_handle_, &new_stream); + if (!ret) { + GELOGE(FAILED, "CreateStream failed."); + return false; + } + stream = std::make_shared(rt_model_handle_, new_stream); + if (stream == nullptr) { + GELOGE(FAILED, "MakeShared failed."); + return false; + } + secondary_stream_vec[index] = stream; + } + secondary_stream_list_.push_back(stream); + return true; + }; + + if (static_cast(hccl_secondary_stream_num) <= secondary_stream_vec.size()) { + GELOGI("Number of secondary stream is enough to be reused."); + for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) { + if (!lock_weak_ptr(i)) { + GELOGE(FAILED, "Lock weak ptr failed."); + return false; + } + } + } else { + GELOGI("Need to reuse secondary stream and create new secondary stream."); + size_t created_stream_num = secondary_stream_vec.size(); + for (size_t i = 0; i < secondary_stream_vec.size(); ++i) { + if (!lock_weak_ptr(i)) { + GELOGE(FAILED,
"Lock weak ptr failed."); + return false; + } + } + ret = CreateStream(hccl_secondary_stream_num - created_stream_num, master_stream_id); + if (!ret) { + GELOGE(RT_FAILED, "Create hccl stream failed."); + return false; + } + } + GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num); + } else { + GELOGI("Need to create secondary stream for %s with new mainstream %u.", task_info_->op_name().c_str(), + master_stream_id); + ret = CreateStream(hccl_secondary_stream_num, master_stream_id); + if (!ret) { + GELOGE(RT_FAILED, "Create hccl stream failed."); + return false; + } + } + return true; +} + +bool HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) { + GELOGI("Start to create %ld hccl secondary stream.", stream_num); + for (int64_t i = 0; i < stream_num; ++i) { + rtStream_t stream = nullptr; + bool ret = CreateStream(rt_model_handle_, &stream); + if (!ret) { + GELOGE(FAILED, "CreateStream failed."); + return false; + } + + GELOGD("hccl_stream addr is=%p", stream); + auto shared_stream = std::make_shared(rt_model_handle_, stream); + if (shared_stream == nullptr) { + GELOGE(FAILED, "MakeShared failed."); + return false; + } + SaveHcclSecondaryStream(master_stream_id, shared_stream); + secondary_stream_list_.push_back(shared_stream); + } + GELOGI("CreateStream success."); + return true; +} + +bool HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const { + if (stream == nullptr) { + GELOGE(FAILED, "Output param stream is null."); + return false; + } + + rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); + return false; + } + // Create secondary stream, inactive by default, activated by hccl + rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Call rt
api failed, ret: 0x%X", rt_ret); + return false; + } + return true; +} + +void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr &stream) { + if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) { + model_stream_mapping_.emplace(rt_model_handle_, std::map>>()); + } + std::map>> &master_secondary_stream_map = + model_stream_mapping_.at(rt_model_handle_); + master_secondary_stream_map[master_stream_id].emplace_back(stream); +} + +HcclTask::StreamGuard::~StreamGuard() { + rtError_t rt_ret = rtModelUnbindStream(model_, stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Unbind stream from model failed!"); + return; + } + + rt_ret = rtStreamDestroy(stream_); + if (rt_ret != RT_ERROR_NONE) { + GELOGE(RT_FAILED, "Destroy stream failed!"); + return; + } +} + REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo); } // namespace model_runner } // namespace ge diff --git a/ge/ge_runtime/task/hccl_task.h b/ge/ge_runtime/task/hccl_task.h index 1649a8bd..11d88944 100644 --- a/ge/ge_runtime/task/hccl_task.h +++ b/ge/ge_runtime/task/hccl_task.h @@ -19,7 +19,9 @@ #include #include +#include #include +#include #include "ge_runtime/task/task.h" namespace ge { @@ -33,18 +35,34 @@ class HcclTask : public TaskRepeater { bool Distribute() override; private: + class StreamGuard; + bool SetSecondaryStream(); + bool CreateStream(int64_t stream_num, int64_t master_stream_id); + bool CreateStream(rtModel_t model, rtStream_t *stream) const; + void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr &stream); + std::shared_ptr task_info_; void *stream_; + void *workspace_mem_; rtModel_t rt_model_handle_; int32_t priority_; - std::vector slave_stream_list_; - std::function hcom_bind_model_; - std::function hcom_unbind_model_; - std::function, void *)> hcom_distribute_task_; - static std::set rt_model_handle_list_; + std::vector> secondary_stream_list_; + + // map>> + static std::map>>>
model_stream_mapping_; + static std::mutex model_stream_mapping_mutex_; }; -std::set HcclTask::rt_model_handle_list_{}; +class HcclTask::StreamGuard { + public: + StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {} + ~StreamGuard(); + rtStream_t GetStream() const { return stream_; } + + private: + rtModel_t model_; + rtStream_t stream_; +}; } // namespace model_runner } // namespace ge diff --git a/inc/framework/ge_runtime/task_info.h b/inc/framework/ge_runtime/task_info.h index e36c4333..f59c6454 100644 --- a/inc/framework/ge_runtime/task_info.h +++ b/inc/framework/ge_runtime/task_info.h @@ -18,7 +18,6 @@ #define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_ #include -#include #include #include #include @@ -219,9 +218,9 @@ class LabelSwitchTaskInfo : public TaskInfo { label_list_(label_list), cond_(cond) {} ~LabelSwitchTaskInfo() override {} - uint32_t label_size() { return label_size_; }; - const std::vector &label_list() { return label_list_; }; - void *cond() { return cond_; }; + uint32_t label_size() const { return label_size_; } + const std::vector &label_list() const { return label_list_; } + void *cond() const { return cond_; } private: uint32_t label_size_; @@ -236,7 +235,7 @@ class EventTaskInfo : public TaskInfo { protected: EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id) : TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {} - virtual ~EventTaskInfo() override {} + ~EventTaskInfo() override {} uint32_t event_id_; }; @@ -272,16 +271,13 @@ class FusionEndTaskInfo : public TaskInfo { class HcclTaskInfo : public TaskInfo { public: HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr, - void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num, + void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num, const std::vector &private_def, void *ops_kernel_store,
int32_t count, int64_t root_id, - int64_t op_type, int64_t data_type, const std::string &group, - std::function hcom_bind_model, std::function hcom_unbind_model, - std::function, void *)> hcom_distribute_task, bool dump_flag) + int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag) : TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag), hccl_type_(hccl_type), input_data_addr_(input_data_addr), output_data_addr_(output_data_addr), - workspace_addr_(workspace_addr), workspace_size_(workspace_size), hccl_stream_num_(hccl_stream_num), private_def_(private_def), @@ -290,16 +286,12 @@ class HcclTaskInfo : public TaskInfo { root_id_(root_id), op_type_(op_type), data_type_(data_type), - group_(group), - hcom_bind_model_(hcom_bind_model), - hcom_unbind_model_(hcom_unbind_model), - hcom_distribute_task_(hcom_distribute_task) {} + group_(group) {} ~HcclTaskInfo() override {} const std::string &hccl_type() const { return hccl_type_; } void *input_data_addr() const { return input_data_addr_; } void *output_data_addr() const { return output_data_addr_; } - void *workspace_addr() const { return workspace_addr_; } int64_t workspace_size() const { return workspace_size_; } int64_t hccl_stream_num() const { return hccl_stream_num_; } const std::vector &private_def() const { return private_def_; } @@ -309,17 +301,11 @@ class HcclTaskInfo : public TaskInfo { int64_t op_type() const { return op_type_; } int64_t data_type() const { return data_type_; } const std::string &group() const { return group_; } - std::function hcom_bind_model() const { return hcom_bind_model_; } - std::function hcom_unbind_model() const { return hcom_unbind_model_; } - std::function, void *)> hcom_distribute_task() const { - return hcom_distribute_task_; - } private: std::string hccl_type_; void *input_data_addr_; void *output_data_addr_; - void *workspace_addr_; int64_t workspace_size_; int64_t hccl_stream_num_; std::vector private_def_; @@ -329,9 +315,6 @@ class HcclTaskInfo :
public TaskInfo { int64_t op_type_; int64_t data_type_; std::string group_; - std::function hcom_bind_model_; - std::function hcom_unbind_model_; - std::function, void *)> hcom_distribute_task_; }; class ProfilerTraceTaskInfo : public TaskInfo {