From ec7bb516652e0d3f631bbf48586f6f0e6168a507 Mon Sep 17 00:00:00 2001 From: zhaozhixuan Date: Thu, 6 May 2021 20:19:07 +0800 Subject: [PATCH] MemcpyAsync in aicore executor. --- .../node_executor/aicore/aicore_op_task.cc | 13 ++++++------- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 18 +++++++++++++++--- 2 files changed, 21 insertions(+), 10 deletions(-) diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 8bb871fb..36f65bbe 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -354,8 +354,6 @@ Status AiCoreOpTask::PrepareWithShape(TaskContext &context) { Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { auto node = context.GetNodeItem().node; GE_CHECK_NOTNULL(node); - auto op_desc = node->GetOpDesc(); - GE_CHECK_NOTNULL(op_desc); GELOGD("[%s] Start to update tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); OpRunInfo tiling_info; @@ -370,12 +368,14 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { // update op args by tiling info block_dim_ = static_cast(tiling_info.block_dim); - op_desc->SetWorkspaceBytes(tiling_info.workspaces); clear_atomic_ = tiling_info.clear_atomic; - tiling_data_ = tiling_info.tiling_data.str(); tiling_key_ = tiling_info.tiling_key; GELOGD("Successfully getting [tiling_key] : %u", tiling_key_); + + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + op_desc->SetWorkspaceBytes(tiling_info.workspaces); if (tiling_data_.empty()) { GELOGD("[%s] Tiling data is empty.", op_desc->GetName().c_str()); return SUCCESS; @@ -401,9 +401,8 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { } RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CopyTilingInfo] Start"); - GE_CHK_RT_RET(rtMemcpy(tiling_buffer_->GetData(), tiling_buffer_->GetSize(), - tiling_data_.c_str(), tiling_data_.size(), - RT_MEMCPY_HOST_TO_DEVICE)); + GE_CHK_RT_RET(rtMemcpyAsync(tiling_buffer_->GetData(), tiling_buffer_->GetSize(), tiling_data_.c_str(), + tiling_data_.size(), RT_MEMCPY_HOST_TO_DEVICE_EX, context.GetStream())); RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CopyTilingInfo] End"); GELOGD("[%s] Done updating tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index b5aac527..4eae475d 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -111,14 +111,26 @@ TEST_F(UtestGeHybrid, aicore_op_task_init_success) { TEST_F(UtestGeHybrid, task_update_tiling_info) { auto aicore_task = std::unique_ptr(new(std::nothrow)hybrid::AiCoreOpTask()); - aicore_task->is_single_op_ = true; auto graph = make_shared("graph"); OpDescPtr op_desc = CreateOpDesc("Add", "Add"); ge::AttrUtils::SetStr(op_desc, "compile_info_key", "key"); ge::AttrUtils::SetStr(op_desc, "compile_info_json", "json"); + ge::AttrUtils::SetBool(op_desc, "support_dynamicshape", true); + ge::AttrUtils::SetInt(op_desc, "op_para_size", 1); auto node = graph->AddNode(op_desc); - optiling::OpRunInfo tiling_info; - ASSERT_EQ(aicore_task->CalcTilingInfo(node, tiling_info), SUCCESS); + + std::unique_ptr node_item; + NodeItem::Create(node, node_item); + node_item->input_start = 0; + node_item->output_start = 0; + + GraphExecutionContext execution_context; + SubgraphContext subgraph_context(nullptr, &execution_context); + NodeState node_state(*node_item, &subgraph_context); + auto task_context = TaskContext::Create(&node_state, &execution_context, &subgraph_context); + ASSERT_TRUE(task_context != nullptr); + ASSERT_EQ(aicore_task->InitTilingInfo(*op_desc), SUCCESS); + ASSERT_EQ(aicore_task->UpdateTilingInfo(*task_context), SUCCESS); } TEST_F(UtestGeHybrid, index_taskdefs_failed) {