Browse Source

Delete deprecated hccl calling methods

tags/v1.1.0
yanghaoran 3 years ago
parent
commit
fa21658965
3 changed files with 197 additions and 87 deletions
  1. +166
    -57
      ge/ge_runtime/task/hccl_task.cc
  2. +24
    -6
      ge/ge_runtime/task/hccl_task.h
  3. +7
    -24
      inc/framework/ge_runtime/task_info.h

+ 166
- 57
ge/ge_runtime/task/hccl_task.cc View File

@@ -15,83 +15,56 @@
*/

#include "ge_runtime/task/hccl_task.h"
#include <algorithm>
#include "ge_runtime/task/task_factory.h"
#include "common/opskernel/ops_kernel_info_store.h"
#include "common/opskernel/ge_task_info.h"

namespace ge {
namespace model_runner {
std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<HcclTask::StreamGuard>>>>
HcclTask::model_stream_mapping_;
std::mutex HcclTask::model_stream_mapping_mutex_;

HcclTask::HcclTask(const ModelContext &model_context, const std::shared_ptr<HcclTaskInfo> &task_info)
: TaskRepeater<HcclTaskInfo>(model_context, task_info),
task_info_(task_info),
stream_(nullptr),
workspace_mem_(nullptr),
rt_model_handle_(nullptr),
priority_(0),
slave_stream_list_(),
hcom_bind_model_(nullptr),
hcom_unbind_model_(nullptr),
hcom_distribute_task_(nullptr) {
secondary_stream_list_() {
if (task_info_ == nullptr) {
GELOGW("task_info_ is null!");
}

hcom_bind_model_ = task_info->hcom_bind_model();
hcom_unbind_model_ = task_info->hcom_unbind_model();

priority_ = model_context.priority();
rt_model_handle_ = model_context.rt_model_handle();
auto stream_list = model_context.stream_list();

if (hcom_bind_model_ != nullptr) {
if (rt_model_handle_list_.insert(rt_model_handle_).second) {
for (auto stream : stream_list) {
(void)hcom_bind_model_(rt_model_handle_, stream);
}
}
}

if (stream_list.size() == 1) {
stream_ = stream_list[0];
} else if (stream_list.size() > task_info->stream_id()) {
stream_ = stream_list[task_info->stream_id()];
} else {
GELOGW("index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size());
GELOGW("Index: %u >= stream_list.size(): %zu.", task_info->stream_id(), stream_list.size());
}
}

HcclTask::~HcclTask() {
for (size_t i = 0; i < slave_stream_list_.size(); ++i) {
rtError_t rt_ret = rtModelUnbindStream(rt_model_handle_, slave_stream_list_[i]);
if (workspace_mem_ != nullptr) {
rtError_t rt_ret = rtFree(workspace_mem_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Unbind stream from model failed! Index: %zu", i);
}
}

for (size_t i = 0; i < slave_stream_list_.size(); ++i) {
rtError_t rt_ret = rtStreamDestroy(slave_stream_list_[i]);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy stream failed! Index: %zu", i);
}
}

if (hcom_unbind_model_ != nullptr) {
if (rt_model_handle_list_.find(rt_model_handle_) != rt_model_handle_list_.end()) {
(void)hcom_unbind_model_(rt_model_handle_);
(void)rt_model_handle_list_.erase(rt_model_handle_);
GELOGE(RT_FAILED, "rtFree workspace_mem_ failed! ret: 0x%X.", rt_ret);
}
workspace_mem_ = nullptr;
}
}

bool HcclTask::Distribute() {
// No ops kernel info store
hcom_distribute_task_ = task_info_->hcom_distribute_task();
if (hcom_distribute_task_ != nullptr) {
return hcom_distribute_task_(task_info_, stream_);
}

// Ops kernel info store
// Get privateDef and opsKernelStorePtr
GELOGI("get custom info in modelTaskDef");
GELOGI("Get custom info in modelTaskDef");
void *ops_kernel_store = task_info_->ops_kernel_store();
OpsKernelInfoStore *ops_kernel_info_store = reinterpret_cast<OpsKernelInfoStore *>(ops_kernel_store);
if (ops_kernel_store == nullptr) {
@@ -101,25 +74,15 @@ bool HcclTask::Distribute() {

char *private_def = reinterpret_cast<char *>(const_cast<char unsigned *>(task_info_->private_def().data()));
auto private_def_len = static_cast<uint32_t>(task_info_->private_def().size());
GELOGI("the first address of the custom info, privateDef=%p", private_def);

GELOGI("hcclStreamNum =%ld", task_info_->hccl_stream_num());
for (int64_t i = 0; i < task_info_->hccl_stream_num(); ++i) {
rtStream_t stream = nullptr;
rtError_t rt_ret = rtStreamCreateWithFlags(&stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
GELOGI("The first address of the custom info, privateDef=%p", private_def);
SetSecondaryStream();

rt_ret = rtModelBindStream(rt_model_handle_, stream, RT_HEAD_STREAM);
if (task_info_->workspace_size() > 0) {
rtError_t rt_ret = rtMalloc(&workspace_mem_, task_info_->workspace_size(), RT_MEMORYINFO_HBM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}

GELOGI("hccl_stream addr is=%p", stream);
slave_stream_list_.push_back(stream);
}

GELOGI("HcclTaskInfo Distribute Start. begin to call function LoadTask in hccl.");
@@ -128,17 +91,22 @@ bool HcclTask::Distribute() {
ge_task.type = static_cast<uint16_t>(RT_MODEL_TASK_HCCL);
ge_task.stream = stream_;

ge_task.kernelHcclInfo = std::vector<GETaskKernelHcclInfo>(1);
ge_task.kernelHcclInfo[0].hccl_type = task_info_->hccl_type();
ge_task.kernelHcclInfo[0].inputDataAddr = task_info_->input_data_addr();
ge_task.kernelHcclInfo[0].outputDataAddr = task_info_->output_data_addr();
ge_task.kernelHcclInfo[0].workSpaceAddr = task_info_->workspace_addr();
ge_task.kernelHcclInfo[0].workSpaceAddr = workspace_mem_;
ge_task.kernelHcclInfo[0].workSpaceMemSize = task_info_->workspace_size();
ge_task.kernelHcclInfo[0].count = task_info_->count();
ge_task.kernelHcclInfo[0].dataType = static_cast<int32_t>(task_info_->data_type());
ge_task.kernelHcclInfo[0].opType = static_cast<int32_t>(task_info_->op_type());
ge_task.kernelHcclInfo[0].rootId = task_info_->root_id();

ge_task.kernelHcclInfo[0].hcclStreamList = slave_stream_list_;
std::vector<rtStream_t> secondary_stream_list;
std::transform(secondary_stream_list_.begin(), secondary_stream_list_.end(),
std::back_inserter(secondary_stream_list),
[](const std::shared_ptr<StreamGuard> &stream) -> rtStream_t { return stream->GetStream(); });
ge_task.kernelHcclInfo[0].hcclStreamList = secondary_stream_list;

ge_task.privateDef = private_def;
ge_task.privateDefLen = private_def_len;
@@ -151,10 +119,151 @@ bool HcclTask::Distribute() {
return false;
}

GELOGI("call function LoadTask end.");
GELOGI("Call function LoadTask end.");
return true;
}

bool HcclTask::SetSecondaryStream() {
const uint32_t master_stream_id = task_info_->stream_id();
const int64_t hccl_secondary_stream_num = task_info_->hccl_stream_num();
Status ret;
std::lock_guard<std::mutex> lock(model_stream_mapping_mutex_);
if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
GELOGI("Need to create map for rt_model_handle_:%p with new mainstream %ld.", rt_model_handle_, master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
return false;
}
return true;
}

std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map =
model_stream_mapping_.at(rt_model_handle_);
if (auto iter = master_secondary_stream_map.find(master_stream_id); iter != master_secondary_stream_map.end()) {
std::vector<std::weak_ptr<StreamGuard>> &secondary_stream_vec = iter->second;
auto lock_weak_ptr = [&secondary_stream_vec, this](int64_t index) -> bool {
auto stream = secondary_stream_vec[index].lock();
if (stream == nullptr) {
rtStream_t new_stream = nullptr;
bool ret = CreateStream(rt_model_handle_, &new_stream);
if (!ret) {
GELOGE(FAILED, "CreateStream failed.");
return false;
}
stream = std::make_shared<HcclTask::StreamGuard>(rt_model_handle_, new_stream);
if (stream == nullptr) {
GELOGE(FAILED, "MakeShared failed.");
return false;
}
secondary_stream_vec[index] = stream;
}
secondary_stream_list_.push_back(stream);
return true;
};

if (static_cast<size_t>(hccl_secondary_stream_num) <= secondary_stream_vec.size()) {
GELOGI("Number of secondary stream is enough to be reused.");
for (int64_t i = 0; i < hccl_secondary_stream_num; ++i) {
if (!lock_weak_ptr(i)) {
GELOGE(FAILED, "Lock weak ptr failed.");
return false;
}
}
} else {
GELOGI("Need to reuse secondary stream and create new secondary stream.");
size_t created_stream_num = secondary_stream_vec.size();
for (size_t i = 0; i < secondary_stream_vec.size(); ++i) {
if (!lock_weak_ptr(i)) {
GELOGE(FAILED, "Lock weak ptr failed.");
return false;
}
}
ret = CreateStream(hccl_secondary_stream_num - created_stream_num, master_stream_id);
if (ret != SUCCESS) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
return false;
}
}
GELOGI("Initialize hccl secondary stream success, hccl_secondary_stream_num =%ld", hccl_secondary_stream_num);
} else {
GELOGI("Need to create secondary stream for %s with new mainstream %ld.", task_info_->op_name().c_str(),
master_stream_id);
ret = CreateStream(hccl_secondary_stream_num, master_stream_id);
if (!ret) {
GELOGE(RT_FAILED, "Create hccl stream failed.");
return false;
}
}
return true;
}

bool HcclTask::CreateStream(int64_t stream_num, int64_t master_stream_id) {
GELOGI("Start to create %ld hccl secondary stream.", stream_num);
for (int64_t i = 0; i < stream_num; ++i) {
rtStream_t stream = nullptr;
bool ret = CreateStream(rt_model_handle_, &stream);
if (!ret) {
GELOGE(FAILED, "CreateStream failed.");
return false;
}

GELOGD("hccl_stream addr is=%p", stream);
auto shared_stream = std::make_shared<StreamGuard>(rt_model_handle_, stream);
if (shared_stream == nullptr) {
GELOGE(FAILED, "MakeShared failed.");
return false;
}
SaveHcclSecondaryStream(master_stream_id, shared_stream);
secondary_stream_list_.push_back(shared_stream);
}
GELOGI("CreateStream success.");
return true;
}

bool HcclTask::CreateStream(rtModel_t model, rtStream_t *stream) const {
if (stream == nullptr) {
GELOGE(FAILED, "Output param stream is null.");
return false;
}

rtError_t rt_ret = rtStreamCreateWithFlags(stream, priority_, RT_STREAM_PERSISTENT | RT_STREAM_FORCE_COPY);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
// Create secondary stream, inactive by default, activated by hccl
rt_ret = rtModelBindStream(model, *stream, RT_MODEL_WAIT_ACTIVE_STREAM);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false;
}
return true;
}

void HcclTask::SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream) {
if (model_stream_mapping_.find(rt_model_handle_) == model_stream_mapping_.end()) {
model_stream_mapping_.emplace(rt_model_handle_, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>());
}
std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>> &master_secondary_stream_map =
model_stream_mapping_.at(rt_model_handle_);
master_secondary_stream_map[master_stream_id].emplace_back(stream);
}

HcclTask::StreamGuard::~StreamGuard() {
rtError_t rt_ret = rtModelUnbindStream(model_, stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Unbind stream from model failed!");
return;
}

rt_ret = rtStreamDestroy(stream_);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Destroy stream failed!");
return;
}
}

REGISTER_TASK(TaskInfoType::HCCL, HcclTask, HcclTaskInfo);
} // namespace model_runner
} // namespace ge

+ 24
- 6
ge/ge_runtime/task/hccl_task.h View File

@@ -19,7 +19,9 @@

#include <memory>
#include <set>
#include <map>
#include <vector>
#include <mutex>
#include "ge_runtime/task/task.h"

namespace ge {
@@ -33,18 +35,34 @@ class HcclTask : public TaskRepeater<HcclTaskInfo> {
bool Distribute() override;

private:
class StreamGuard;
bool SetSecondaryStream();
bool CreateStream(int64_t stream_num, int64_t master_stream_id);
bool CreateStream(rtModel_t model, rtStream_t *stream) const;
void SaveHcclSecondaryStream(int64_t master_stream_id, const std::shared_ptr<StreamGuard> &stream);

std::shared_ptr<HcclTaskInfo> task_info_;
void *stream_;
void *workspace_mem_;
rtModel_t rt_model_handle_;
int32_t priority_;
std::vector<void *> slave_stream_list_;
std::function<bool(void *, void *)> hcom_bind_model_;
std::function<bool(void *)> hcom_unbind_model_;
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
static std::set<rtModel_t> rt_model_handle_list_;
std::vector<std::shared_ptr<StreamGuard>> secondary_stream_list_;
// map<key: model pointer, value: map<key: primary stream id, value: vector<secondary stream pointer>>>
static std::map<rtModel_t, std::map<uint32_t, std::vector<std::weak_ptr<StreamGuard>>>> model_stream_mapping_;
static std::mutex model_stream_mapping_mutex_;
};

std::set<rtModel_t> HcclTask::rt_model_handle_list_{};
class HcclTask::StreamGuard {
public:
StreamGuard(rtModel_t model, rtStream_t stream) : model_(model), stream_(stream) {}
~StreamGuard();
rtStream_t GetStream() const { return stream_; }

private:
rtModel_t model_;
rtStream_t stream_;
};
} // namespace model_runner
} // namespace ge



+ 7
- 24
inc/framework/ge_runtime/task_info.h View File

@@ -18,7 +18,6 @@
#define INC_FRAMEWORK_GE_RUNTIME_TASK_INFO_H_

#include <stdint.h>
#include <functional>
#include <memory>
#include <string>
#include <utility>
@@ -219,9 +218,9 @@ class LabelSwitchTaskInfo : public TaskInfo {
label_list_(label_list),
cond_(cond) {}
~LabelSwitchTaskInfo() override {}
uint32_t label_size() { return label_size_; };
const std::vector<uint32_t> &label_list() { return label_list_; };
void *cond() { return cond_; };
uint32_t label_size() const { return label_size_; }
const std::vector<uint32_t> &label_list() const { return label_list_; }
void *cond() const { return cond_; }

private:
uint32_t label_size_;
@@ -236,7 +235,7 @@ class EventTaskInfo : public TaskInfo {
protected:
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
virtual ~EventTaskInfo() override {}
~EventTaskInfo() override {}

uint32_t event_id_;
};
@@ -272,16 +271,13 @@ class FusionEndTaskInfo : public TaskInfo {
class HcclTaskInfo : public TaskInfo {
public:
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
void *output_data_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, const std::string &group,
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
int64_t op_type, int64_t data_type, const std::string &group, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
hccl_type_(hccl_type),
input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr),
workspace_addr_(workspace_addr),
workspace_size_(workspace_size),
hccl_stream_num_(hccl_stream_num),
private_def_(private_def),
@@ -290,16 +286,12 @@ class HcclTaskInfo : public TaskInfo {
root_id_(root_id),
op_type_(op_type),
data_type_(data_type),
group_(group),
hcom_bind_model_(hcom_bind_model),
hcom_unbind_model_(hcom_unbind_model),
hcom_distribute_task_(hcom_distribute_task) {}
group_(group) {}
~HcclTaskInfo() override {}

const std::string &hccl_type() const { return hccl_type_; }
void *input_data_addr() const { return input_data_addr_; }
void *output_data_addr() const { return output_data_addr_; }
void *workspace_addr() const { return workspace_addr_; }
int64_t workspace_size() const { return workspace_size_; }
int64_t hccl_stream_num() const { return hccl_stream_num_; }
const std::vector<uint8_t> &private_def() const { return private_def_; }
@@ -309,17 +301,11 @@ class HcclTaskInfo : public TaskInfo {
int64_t op_type() const { return op_type_; }
int64_t data_type() const { return data_type_; }
const std::string &group() const { return group_; }
std::function<bool(void *, void *)> hcom_bind_model() const { return hcom_bind_model_; }
std::function<bool(void *)> hcom_unbind_model() const { return hcom_unbind_model_; }
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task() const {
return hcom_distribute_task_;
}

private:
std::string hccl_type_;
void *input_data_addr_;
void *output_data_addr_;
void *workspace_addr_;
int64_t workspace_size_;
int64_t hccl_stream_num_;
std::vector<uint8_t> private_def_;
@@ -329,9 +315,6 @@ class HcclTaskInfo : public TaskInfo {
int64_t op_type_;
int64_t data_type_;
std::string group_;
std::function<bool(void *, void *)> hcom_bind_model_;
std::function<bool(void *)> hcom_unbind_model_;
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task_;
};

class ProfilerTraceTaskInfo : public TaskInfo {


Loading…
Cancel
Save