Browse Source

!41 Support data dump

Merge pull request !41 from caifubi/data-dump-2
tags/v0.6.0-beta
mindspore-ci-bot Gitee 4 years ago
parent
commit
89ad354691
10 changed files with 135 additions and 57 deletions
  1. +4
    -1
      inc/framework/ge_runtime/model_runner.h
  2. +55
    -40
      inc/framework/ge_runtime/task_info.h
  3. +20
    -0
      src/ge/ge_runtime/model_runner.cc
  4. +24
    -6
      src/ge/ge_runtime/runtime_model.cc
  5. +4
    -1
      src/ge/ge_runtime/runtime_model.h
  6. +9
    -5
      src/ge/ge_runtime/task/aicpu_task.cc
  7. +6
    -0
      src/ge/ge_runtime/task/aicpu_task.h
  8. +6
    -0
      src/ge/ge_runtime/task/task.h
  9. +3
    -4
      src/ge/ge_runtime/task/tbe_task.cc
  10. +4
    -0
      src/ge/ge_runtime/task/tbe_task.h

+ 4
- 1
inc/framework/ge_runtime/model_runner.h View File

@@ -28,18 +28,21 @@
namespace ge { namespace ge {
namespace model_runner { namespace model_runner {
class RuntimeModel; class RuntimeModel;
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class ModelRunner { class ModelRunner {
public: public:
static ModelRunner &Instance(); static ModelRunner &Instance();


bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id, bool LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint32_t model_id,
std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener); std::shared_ptr<DavinciModel> davinci_model, std::shared_ptr<ModelListener> listener);
bool LoadModelComplete(uint32_t model_id);


const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const; const std::vector<uint32_t> &GetTaskIdList(uint32_t model_id) const;


const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const; const std::vector<uint32_t> &GetStreamIdList(uint32_t model_id) const;


const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap(uint32_t model_id) const;

bool UnloadModel(uint32_t model_id); bool UnloadModel(uint32_t model_id);


bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data); bool RunModel(uint32_t model_id, const InputData &input_data, OutputData *output_data);


+ 55
- 40
inc/framework/ge_runtime/task_info.h View File

@@ -21,6 +21,7 @@
#include <functional> #include <functional>
#include <memory> #include <memory>
#include <string> #include <string>
#include <utility>
#include <vector> #include <vector>


#include "cce/taskdown_api.h" #include "cce/taskdown_api.h"
@@ -52,21 +53,27 @@ class TaskInfo {
virtual ~TaskInfo() {} virtual ~TaskInfo() {}
uint32_t stream_id() const { return stream_id_; } uint32_t stream_id() const { return stream_id_; }
TaskInfoType type() const { return type_; } TaskInfoType type() const { return type_; }
std::string op_name() const { return op_name_; }
bool dump_flag() const { return dump_flag_; }


protected: protected:
TaskInfo(uint32_t stream_id, TaskInfoType type) : stream_id_(stream_id), type_(type) {}
TaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, bool dump_flag)
: op_name_(op_name), stream_id_(stream_id), type_(type), dump_flag_(dump_flag) {}


private: private:
std::string op_name_;
uint32_t stream_id_; uint32_t stream_id_;
TaskInfoType type_; TaskInfoType type_;
bool dump_flag_;
}; };


class CceTaskInfo : public TaskInfo { class CceTaskInfo : public TaskInfo {
public: public:
CceTaskInfo(uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc,
const std::vector<uint8_t> &flow_table, const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(stream_id, TaskInfoType::CCE),
CceTaskInfo(const std::string &op_name, uint32_t stream_id, const cce::ccOpContext &ctx, const std::string &stub_func,
uint32_t block_dim, const std::vector<uint8_t> &args, uint32_t args_size,
const std::vector<uint8_t> &sm_desc, const std::vector<uint8_t> &flow_table,
const std::vector<uint8_t> &args_offset, bool is_flowtable)
: TaskInfo(op_name, stream_id, TaskInfoType::CCE, false),
ctx_(ctx), ctx_(ctx),
stub_func_(stub_func), stub_func_(stub_func),
block_dim_(block_dim), block_dim_(block_dim),
@@ -102,11 +109,11 @@ class CceTaskInfo : public TaskInfo {


class TbeTaskInfo : public TaskInfo { class TbeTaskInfo : public TaskInfo {
public: public:
TbeTaskInfo(uint32_t stream_id, const std::string &stub_func, uint32_t block_dim, const std::vector<uint8_t> &args,
uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary, uint32_t binary_size,
const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs)
: TaskInfo(stream_id, TaskInfoType::TBE),
TbeTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string &stub_func, uint32_t block_dim,
const std::vector<uint8_t> &args, uint32_t args_size, const std::vector<uint8_t> &sm_desc, void *binary,
uint32_t binary_size, const std::vector<uint8_t> &meta_data, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, const std::vector<void *> &workspace_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::TBE, dump_flag),
stub_func_(stub_func), stub_func_(stub_func),
block_dim_(block_dim), block_dim_(block_dim),
args_(args), args_(args),
@@ -153,9 +160,10 @@ class TbeTaskInfo : public TaskInfo {


class AicpuTaskInfo : public TaskInfo { class AicpuTaskInfo : public TaskInfo {
public: public:
AicpuTaskInfo(uint32_t stream_id, const string &so_name, const std::string &kernel_name, const std::string &node_def,
const std::vector<void *> &input_data_addrs, const std::vector<void *> &output_data_addrs)
: TaskInfo(stream_id, TaskInfoType::AICPU),
AicpuTaskInfo(const std::string &op_name, uint32_t stream_id, const string &so_name, const std::string &kernel_name,
const std::string &node_def, const std::vector<void *> &input_data_addrs,
const std::vector<void *> &output_data_addrs, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::AICPU, dump_flag),
so_name_(so_name), so_name_(so_name),
kernel_name_(kernel_name), kernel_name_(kernel_name),
node_def_(node_def), node_def_(node_def),
@@ -179,8 +187,8 @@ class AicpuTaskInfo : public TaskInfo {


class LabelSetTaskInfo : public TaskInfo { class LabelSetTaskInfo : public TaskInfo {
public: public:
LabelSetTaskInfo(uint32_t stream_id, uint32_t label_id)
: TaskInfo(stream_id, TaskInfoType::LABEL_SET), label_id_(label_id) {}
LabelSetTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SET, false), label_id_(label_id) {}
~LabelSetTaskInfo() override {} ~LabelSetTaskInfo() override {}
uint32_t label_id() const { return label_id_; } uint32_t label_id() const { return label_id_; }


@@ -190,8 +198,8 @@ class LabelSetTaskInfo : public TaskInfo {


class LabelGotoTaskInfo : public TaskInfo { class LabelGotoTaskInfo : public TaskInfo {
public: public:
LabelGotoTaskInfo(uint32_t stream_id, uint32_t label_id)
: TaskInfo(stream_id, TaskInfoType::LABEL_GOTO), label_id_(label_id) {}
LabelGotoTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_id)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_GOTO, false), label_id_(label_id) {}
~LabelGotoTaskInfo() override {} ~LabelGotoTaskInfo() override {}
uint32_t label_id() const { return label_id_; } uint32_t label_id() const { return label_id_; }


@@ -201,8 +209,9 @@ class LabelGotoTaskInfo : public TaskInfo {


class LabelSwitchTaskInfo : public TaskInfo { class LabelSwitchTaskInfo : public TaskInfo {
public: public:
LabelSwitchTaskInfo(uint32_t stream_id, uint32_t label_size, const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(stream_id, TaskInfoType::LABEL_SWITCH),
LabelSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t label_size,
const std::vector<uint32_t> &label_list, void *cond)
: TaskInfo(op_name, stream_id, TaskInfoType::LABEL_SWITCH, false),
label_size_(label_size), label_size_(label_size),
label_list_(label_list), label_list_(label_list),
cond_(cond) {} cond_(cond) {}
@@ -222,8 +231,8 @@ class EventTaskInfo : public TaskInfo {
uint32_t event_id() const { return event_id_; } uint32_t event_id() const { return event_id_; }


protected: protected:
EventTaskInfo(uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(stream_id, type), event_id_(event_id) {}
EventTaskInfo(const std::string &op_name, uint32_t stream_id, TaskInfoType type, uint32_t event_id)
: TaskInfo(op_name, stream_id, type, false), event_id_(event_id) {}
virtual ~EventTaskInfo() override {} virtual ~EventTaskInfo() override {}


uint32_t event_id_; uint32_t event_id_;
@@ -231,39 +240,41 @@ class EventTaskInfo : public TaskInfo {


class EventRecordTaskInfo : public EventTaskInfo { class EventRecordTaskInfo : public EventTaskInfo {
public: public:
EventRecordTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
EventRecordTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_RECORD, event_id) {}
~EventRecordTaskInfo() override {} ~EventRecordTaskInfo() override {}
}; };


class EventWaitTaskInfo : public EventTaskInfo { class EventWaitTaskInfo : public EventTaskInfo {
public: public:
EventWaitTaskInfo(uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
EventWaitTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t event_id)
: EventTaskInfo(op_name, stream_id, TaskInfoType::EVENT_WAIT, event_id) {}
~EventWaitTaskInfo() override {} ~EventWaitTaskInfo() override {}
}; };


class FusionStartTaskInfo : public TaskInfo { class FusionStartTaskInfo : public TaskInfo {
public: public:
explicit FusionStartTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_START) {}
explicit FusionStartTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_START, false) {}
~FusionStartTaskInfo() override {} ~FusionStartTaskInfo() override {}
}; };


class FusionEndTaskInfo : public TaskInfo { class FusionEndTaskInfo : public TaskInfo {
public: public:
explicit FusionEndTaskInfo(uint32_t stream_id) : TaskInfo(stream_id, TaskInfoType::FUSION_END) {}
explicit FusionEndTaskInfo(const std::string &op_name, uint32_t stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::FUSION_END, false) {}
~FusionEndTaskInfo() override {} ~FusionEndTaskInfo() override {}
}; };


class HcclTaskInfo : public TaskInfo { class HcclTaskInfo : public TaskInfo {
public: public:
HcclTaskInfo(uint32_t stream_id, const std::string hccl_type, void *input_data_addr, void *output_data_addr,
void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
HcclTaskInfo(const std::string &op_name, uint32_t stream_id, const std::string hccl_type, void *input_data_addr,
void *output_data_addr, void *workspace_addr, int64_t workspace_size, int64_t hccl_stream_num,
const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id, const std::vector<uint8_t> &private_def, void *ops_kernel_store, int32_t count, int64_t root_id,
int64_t op_type, int64_t data_type, const std::string &group, int64_t op_type, int64_t data_type, const std::string &group,
std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model, std::function<bool(void *, void *)> hcom_bind_model, std::function<bool(void *)> hcom_unbind_model,
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task)
: TaskInfo(stream_id, TaskInfoType::HCCL),
std::function<bool(std::shared_ptr<HcclTaskInfo>, void *)> hcom_distribute_task, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::HCCL, dump_flag),
hccl_type_(hccl_type), hccl_type_(hccl_type),
input_data_addr_(input_data_addr), input_data_addr_(input_data_addr),
output_data_addr_(output_data_addr), output_data_addr_(output_data_addr),
@@ -322,8 +333,11 @@ class HcclTaskInfo : public TaskInfo {


class ProfilerTraceTaskInfo : public TaskInfo { class ProfilerTraceTaskInfo : public TaskInfo {
public: public:
ProfilerTraceTaskInfo(uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(stream_id, TaskInfoType::PROFILER_TRACE), log_id_(log_id), notify_(notify), flat_(flat) {}
ProfilerTraceTaskInfo(const std::string &op_name, uint32_t stream_id, uint64_t log_id, bool notify, uint32_t flat)
: TaskInfo(op_name, stream_id, TaskInfoType::PROFILER_TRACE, false),
log_id_(log_id),
notify_(notify),
flat_(flat) {}
~ProfilerTraceTaskInfo() override {} ~ProfilerTraceTaskInfo() override {}


uint64_t log_id() const { return log_id_; } uint64_t log_id() const { return log_id_; }
@@ -338,8 +352,9 @@ class ProfilerTraceTaskInfo : public TaskInfo {


class MemcpyAsyncTaskInfo : public TaskInfo { class MemcpyAsyncTaskInfo : public TaskInfo {
public: public:
MemcpyAsyncTaskInfo(uint32_t stream_id, void *dst, uint64_t dst_max, void *src, uint64_t count, uint32_t kind)
: TaskInfo(stream_id, TaskInfoType::MEMCPY_ASYNC),
MemcpyAsyncTaskInfo(const std::string &op_name, uint32_t stream_id, void *dst, uint64_t dst_max, void *src,
uint64_t count, uint32_t kind, bool dump_flag)
: TaskInfo(op_name, stream_id, TaskInfoType::MEMCPY_ASYNC, dump_flag),
dst_(dst), dst_(dst),
dst_max_(dst_max), dst_max_(dst_max),
src_(src), src_(src),
@@ -363,9 +378,9 @@ class MemcpyAsyncTaskInfo : public TaskInfo {


class StreamSwitchTaskInfo : public TaskInfo { class StreamSwitchTaskInfo : public TaskInfo {
public: public:
StreamSwitchTaskInfo(uint32_t stream_id, int64_t true_stream_id, void *input_addr, void *value_addr, int64_t cond,
int64_t data_type)
: TaskInfo(stream_id, TaskInfoType::STREAM_SWITCH),
StreamSwitchTaskInfo(const std::string &op_name, uint32_t stream_id, int64_t true_stream_id, void *input_addr,
void *value_addr, int64_t cond, int64_t data_type)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_SWITCH, false),
true_stream_id_(true_stream_id), true_stream_id_(true_stream_id),
input_addr_(input_addr), input_addr_(input_addr),
value_addr_(value_addr), value_addr_(value_addr),
@@ -389,8 +404,8 @@ class StreamSwitchTaskInfo : public TaskInfo {


class StreamActiveTaskInfo : public TaskInfo { class StreamActiveTaskInfo : public TaskInfo {
public: public:
StreamActiveTaskInfo(uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(stream_id, TaskInfoType::STREAM_ACTIVE), active_stream_id_(active_stream_id) {}
StreamActiveTaskInfo(const std::string &op_name, uint32_t stream_id, uint32_t active_stream_id)
: TaskInfo(op_name, stream_id, TaskInfoType::STREAM_ACTIVE, false), active_stream_id_(active_stream_id) {}
~StreamActiveTaskInfo() override {} ~StreamActiveTaskInfo() override {}


uint32_t active_stream_id() const { return active_stream_id_; } uint32_t active_stream_id() const { return active_stream_id_; }


+ 20
- 0
src/ge/ge_runtime/model_runner.cc View File

@@ -49,6 +49,15 @@ bool ModelRunner::LoadDavinciModel(uint32_t device_id, uint64_t session_id, uint
return true; return true;
} }


bool ModelRunner::LoadModelComplete(uint32_t model_id) {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGE(PARAM_INVALID, "Model id %u not found.", model_id);
return false;
}
return model_iter->second->LoadComplete();
}

const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const { const std::vector<uint32_t> &ModelRunner::GetTaskIdList(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id); auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) { if (model_iter == runtime_models_.end()) {
@@ -71,6 +80,17 @@ const std::vector<uint32_t> &ModelRunner::GetStreamIdList(uint32_t model_id) con
return model_iter->second->GetStreamIdList(); return model_iter->second->GetStreamIdList();
} }


const std::map<std::string, std::shared_ptr<RuntimeInfo>> &ModelRunner::GetRuntimeInfoMap(uint32_t model_id) const {
auto model_iter = runtime_models_.find(model_id);
if (model_iter == runtime_models_.end()) {
GELOGW("Model id %u not found.", model_id);
static const std::map<std::string, std::shared_ptr<RuntimeInfo>> empty_ret;
return empty_ret;
}

return model_iter->second->GetRuntimeInfoMap();
}

bool ModelRunner::UnloadModel(uint32_t model_id) { bool ModelRunner::UnloadModel(uint32_t model_id) {
auto iter = runtime_models_.find(model_id); auto iter = runtime_models_.find(model_id);
if (iter != runtime_models_.end()) { if (iter != runtime_models_.end()) {


+ 24
- 6
src/ge/ge_runtime/runtime_model.cc View File

@@ -28,7 +28,6 @@


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {

RuntimeModel::~RuntimeModel() { RuntimeModel::~RuntimeModel() {
GELOGI("RuntimeModel destructor start"); GELOGI("RuntimeModel destructor start");


@@ -221,21 +220,40 @@ bool RuntimeModel::LoadTask() {
} }
task_id_list_.push_back(task_id); task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id); stream_id_list_.push_back(stream_id);
if (task->Args() != nullptr) {
std::shared_ptr<RuntimeInfo> runtime_tuple = nullptr;
GE_MAKE_SHARED(runtime_tuple = std::make_shared<RuntimeInfo>(task_id, stream_id, task->Args()), return false);
auto emplace_ret = runtime_info_map_.emplace(task->task_name(), runtime_tuple);
if (!emplace_ret.second) {
GELOGW("Task name exist:%s", task->task_name().c_str());
}
}
} }
if (task_list_.empty()) { if (task_list_.empty()) {
GELOGE(FAILED, "Task list is empty"); GELOGE(FAILED, "Task list is empty");
return false; return false;
} }
GELOGI("Distribute task succ.");


auto rt_ret = rtModelLoadComplete(rt_model_handle_);
GELOGI("LoadTask succ.");
return true;
}

bool RuntimeModel::LoadComplete() {
uint32_t task_id = 0;
uint32_t stream_id = 0;
auto rt_ret = rtModelGetTaskId(rt_model_handle_, &task_id, &stream_id);
if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rtModelGetTaskId failed, ret:0x%X", rt_ret);
return RT_FAILED;
}
task_id_list_.push_back(task_id);
stream_id_list_.push_back(stream_id);

rt_ret = rtModelLoadComplete(rt_model_handle_);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret); GELOGE(RT_FAILED, "Call rt api rtModelLoadComplete failed, ret: 0x%X.", rt_ret);
return false; return false;
} }

GELOGI("LoadTask succ.");
return true;
} }


bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) { bool RuntimeModel::Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model) {


+ 4
- 1
src/ge/ge_runtime/runtime_model.h View File

@@ -27,7 +27,7 @@


namespace ge { namespace ge {
namespace model_runner { namespace model_runner {
using RuntimeInfo = std::tuple<uint32_t, uint32_t, void *>;
class Task; class Task;
class RuntimeModel { class RuntimeModel {
public: public:
@@ -35,8 +35,10 @@ class RuntimeModel {
~RuntimeModel(); ~RuntimeModel();


bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model); bool Load(uint32_t device_id, uint64_t session_id, std::shared_ptr<DavinciModel> &davinci_model);
bool LoadComplete();
const std::vector<uint32_t> &GetTaskIdList() const; const std::vector<uint32_t> &GetTaskIdList() const;
const std::vector<uint32_t> &GetStreamIdList() const; const std::vector<uint32_t> &GetStreamIdList() const;
const std::map<std::string, std::shared_ptr<RuntimeInfo>> &GetRuntimeInfoMap() const { return runtime_info_map_; }
bool Run(); bool Run();
bool CopyInputData(const InputData &input_data); bool CopyInputData(const InputData &input_data);
bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc, bool GetInputOutputDescInfo(bool zero_copy, std::vector<InputOutputDescInfo> *input_desc,
@@ -79,6 +81,7 @@ class RuntimeModel {


std::vector<uint32_t> task_id_list_{}; std::vector<uint32_t> task_id_list_{};
std::vector<uint32_t> stream_id_list_{}; std::vector<uint32_t> stream_id_list_{};
std::map<std::string, std::shared_ptr<RuntimeInfo>> runtime_info_map_;
}; };


} // namespace model_runner } // namespace model_runner


+ 9
- 5
src/ge/ge_runtime/task/aicpu_task.cc View File

@@ -85,11 +85,15 @@ bool AicpuTask::Distribute() {
return false; return false;
} }


GELOGI("Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s.", args_size,
io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data());
rt_ret = rtCpuKernelLaunch(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_, args_size,
nullptr, stream_);
input_output_addr_ = reinterpret_cast<void *>(reinterpret_cast<uint8_t *>(args_) + io_addr_offset);

auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
GELOGI(
"Distribute AicpuTask start, args_size = %u, io_addrs_num = %u, so_name = %s, kernel_name = %s, dump_flag = %d.",
args_size, io_addrs_num, task_info_->so_name().data(), task_info_->kernel_name().data(), dump_flag);
rt_ret = rtCpuKernelLaunchWithFlag(reinterpret_cast<const void *>(task_info_->so_name().data()),
reinterpret_cast<const void *>(task_info_->kernel_name().data()), 1, args_,
args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api failed, ret: 0x%X", rt_ret);
return false; return false;


+ 6
- 0
src/ge/ge_runtime/task/aicpu_task.h View File

@@ -18,6 +18,7 @@
#define GE_GE_RUNTIME_TASK_AICPU_TASK_H_ #define GE_GE_RUNTIME_TASK_AICPU_TASK_H_


#include <memory> #include <memory>
#include <string>
#include "ge_runtime/task/task.h" #include "ge_runtime/task/task.h"


namespace ge { namespace ge {
@@ -30,12 +31,17 @@ class AicpuTask : public TaskRepeater<AicpuTaskInfo> {


bool Distribute() override; bool Distribute() override;


void *Args() override { return input_output_addr_; }

std::string task_name() const override { return task_info_->op_name(); }

private: private:
static void ReleaseRtMem(void **ptr) noexcept; static void ReleaseRtMem(void **ptr) noexcept;


std::shared_ptr<AicpuTaskInfo> task_info_; std::shared_ptr<AicpuTaskInfo> task_info_;
void *stream_; void *stream_;
void *args_; void *args_;
void *input_output_addr_;
}; };
} // namespace model_runner } // namespace model_runner
} // namespace ge } // namespace ge


+ 6
- 0
src/ge/ge_runtime/task/task.h View File

@@ -18,7 +18,9 @@
#define GE_GE_RUNTIME_TASK_TASK_H_ #define GE_GE_RUNTIME_TASK_TASK_H_


#include <memory> #include <memory>
#include <utility>
#include <vector> #include <vector>
#include <string>
#include "runtime/rt_model.h" #include "runtime/rt_model.h"
#include "ge_runtime/model_context.h" #include "ge_runtime/model_context.h"
#include "ge_runtime/task_info.h" #include "ge_runtime/task_info.h"
@@ -32,6 +34,10 @@ class Task {
virtual ~Task() {} virtual ~Task() {}


virtual bool Distribute() = 0; virtual bool Distribute() = 0;

virtual void *Args() { return nullptr; }

virtual std::string task_name() const { return ""; }
}; };


template <class T> template <class T>


+ 3
- 4
src/ge/ge_runtime/task/tbe_task.cc View File

@@ -95,15 +95,14 @@ bool TbeTask::Distribute() {
return false; return false;
} }


GELOGI("InitTbeTask end.");
GELOGI("DistributeTbeTask start."); GELOGI("DistributeTbeTask start.");
rt_ret = rtKernelLaunch(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_);
auto dump_flag = task_info_->dump_flag() ? RT_KERNEL_DUMPFLAG : RT_KERNEL_DEFAULT;
rt_ret = rtKernelLaunchWithFlag(stub_func_, task_info_->block_dim(), args_, args_size, nullptr, stream_, dump_flag);
if (rt_ret != RT_ERROR_NONE) { if (rt_ret != RT_ERROR_NONE) {
GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret); GELOGE(RT_FAILED, "Call rt api rtKernelLaunch failed, ret: 0x%X", rt_ret);
return false; return false;
} }

GELOGI("DistributeTbeTask end.");
GELOGI("[DataDump] task name:%s, dump_flag:%d", task_info_->op_name().c_str(), dump_flag);
return true; return true;
} }




+ 4
- 0
src/ge/ge_runtime/task/tbe_task.h View File

@@ -30,6 +30,10 @@ class TbeTask : public TaskRepeater<TbeTaskInfo> {


bool Distribute() override; bool Distribute() override;


void *Args() override { return args_; }

std::string task_name() const override { return task_info_->op_name(); }

private: private:
std::shared_ptr<TbeTaskInfo> task_info_; std::shared_ptr<TbeTaskInfo> task_info_;
void *stream_; void *stream_;


Loading…
Cancel
Save