Browse Source

!1655 fix add tbe kernel failed

From: @lichun30
Reviewed-by: @sheng-nan,@xchu42
Signed-off-by: @ji_chen
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
e1f1728f03
2 changed files with 36 additions and 2 deletions
  1. +3
    -2
      ge/hybrid/node_executor/aicore/aicore_op_task.cc
  2. +33
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 3
- 2
ge/hybrid/node_executor/aicore/aicore_op_task.cc View File

@@ -23,6 +23,7 @@
#include "graph/load/model_manager/tbe_handle_store.h"
#include "graph/types.h"
#include "single_op/task/build_task_utils.h"
#include "single_op/task/tbe_task_builder.h"

using optiling::OpRunInfo;

@@ -131,8 +132,8 @@ Status AiCoreOpTask::RegisterTbeHandle(const OpDesc &op_desc) {
GE_IF_BOOL_EXEC(AttrUtils::GetStr(op_desc_ptr, GetKeyForKernelName(op_desc), kernel_name),
GELOGI("Get original type of kernel_name"));
GELOGI("TBE: binfile_key=%s, kernel_name=%s", stub_name_.c_str(), kernel_name.c_str());
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, stub_name_.c_str(),
stub_name_.c_str(), kernel_name.c_str(), 0));
auto stub_func = KernelBinRegistry::GetInstance().GetUnique(stub_name_);
GE_CHK_RT_RET(rtFunctionRegister(bin_handle, stub_func, stub_name_.c_str(), kernel_name.c_str(), 0));
}
return SUCCESS;
}


+ 33
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -109,6 +109,39 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) {
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
}

TEST_F(UtestGeHybrid, aicore_op_task_init_success2) {
// build aicore task
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
aicore_task->is_single_op_ = true;
domi::TaskDef task_def;
task_def.set_type(RT_MODEL_TASK_KERNEL);
domi::KernelDef *kernel = task_def.mutable_kernel();
kernel->set_block_dim(32);
kernel->set_args_size(64);
string args(64, '1');
kernel->set_args(args.data(), 64);
domi::KernelContext *context = kernel->mutable_context();
context->set_op_index(1);
context->set_kernel_type(2); // ccKernelType::TE
uint16_t args_offset[9] = {0};
context->set_args_offset(args_offset, 9 * sizeof(uint16_t));

OpDescPtr op_desc = CreateOpDesc("Add", "Add");
std::vector<char> kernelBin;
TBEKernelPtr tbe_kernel = std::make_shared<ge::OpKernelBin>("name/Add", std::move(kernelBin));
op_desc->SetExtAttr(ge::OP_EXTATTR_NAME_TBE_KERNEL, tbe_kernel);
std::string kernel_name("kernel/Add");
AttrUtils::SetStr(op_desc, op_desc->GetName() + "_kernelname", kernel_name);
ASSERT_EQ(aicore_task->InitWithTaskDef(*op_desc.get(), task_def), SUCCESS);
rtStream_t stream = nullptr;
rtStreamCreate(&stream, 0);
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
char *handle = "";
aicore_task->handle_ = handle;
aicore_task->tiling_key_ = 1;
ASSERT_EQ(aicore_task->LaunchKernel(stream), SUCCESS);
}

TEST_F(UtestGeHybrid, task_update_tiling_info) {
auto aicore_task = std::unique_ptr<hybrid::AiCoreOpTask>(new(std::nothrow)hybrid::AiCoreOpTask());
auto graph = make_shared<ComputeGraph>("graph");


Loading…
Cancel
Save