diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc index 8c5e413b..09c516fb 100755 --- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc +++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc @@ -159,9 +159,13 @@ Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str()); - for (auto &task : tasks_) { + for (auto it = tasks_.begin(); it != tasks_.end(); ++it) { + // AtomicAddrClean has 2 tasks + if (tasks_.size() == 2 && it == tasks_.begin() && !(*(tasks_.rbegin()))->GetClearAtomic()) { + continue; + } RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] Start"); - GE_CHK_STATUS_RET_NOLOG(task->LaunchKernel(context.GetStream())); + GE_CHK_STATUS_RET_NOLOG((*it)->LaunchKernel(context.GetStream())); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeLaunchKernel] End"); } @@ -181,8 +185,12 @@ Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { auto op_desc = context.GetNodeItem().op_desc; GE_CHECK_NOTNULL(op_desc); GELOGI("[%s] AiCoreNodeTask UpdateArgs Start.", op_desc->GetName().c_str()); - for (auto &task : tasks_) { - GE_CHK_STATUS_RET_NOLOG(task->UpdateArgs(context)); + for (auto it = tasks_.rbegin(); it != tasks_.rend(); ++it) { + GE_CHK_STATUS_RET_NOLOG((*it)->UpdateArgs(context)); + // AtomicAddrClean has 2 tasks + if (tasks_.size() == 2 && it == tasks_.rbegin() && !(*it)->GetClearAtomic()) { + break; + } } GELOGI("[%s] AiCoreNodeTask UpdateArgs End.", op_desc->GetName().c_str()); return SUCCESS; @@ -190,8 +198,12 @@ Status AiCoreNodeTask::UpdateArgs(TaskContext &context) { Status AiCoreNodeTask::UpdateTilingData(TaskContext &context) { GELOGD("[%s] PrepareWithShape started", context.GetNodeName()); - for (auto &task : tasks_) { - GE_CHK_STATUS_RET_NOLOG(task->PrepareWithShape(context)); + for (auto it = tasks_.rbegin(); it != tasks_.rend(); ++it) { + GE_CHK_STATUS_RET_NOLOG((*it)->PrepareWithShape(context)); + // AtomicAddrClean has 2 tasks + if (tasks_.size() == 2 && it == tasks_.rbegin() && !(*it)->GetClearAtomic()) { + break; + } } GELOGD("[%s] Done PrepareWithShape successfully.", context.GetNodeName()); return SUCCESS; diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.cc b/ge/hybrid/node_executor/aicore/aicore_op_task.cc index 7f69acd4..fd6387e6 100644 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.cc +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.cc @@ -121,6 +121,7 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { GELOGD("[%s] Start to update tiling info for task: [%s]", node->GetName().c_str(), stub_name_.c_str()); OpRunInfo tiling_info; tiling_info.block_dim = -1; // codex: Using uninitialized value + tiling_info.clear_atomic = true; auto execution_context = context.GetExecutionContext(); RECORD_EXECUTION_EVENT(execution_context, context.GetNodeName(), "[CalcTilingInfo] Start"); @@ -130,6 +131,7 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { // update op args by tiling info block_dim_ = static_cast(tiling_info.block_dim); op_desc->SetWorkspaceBytes(tiling_info.workspaces); + clear_atomic_ = tiling_info.clear_atomic; tiling_data_ = tiling_info.tiling_data.str(); if (tiling_data_.empty()) { diff --git a/ge/hybrid/node_executor/aicore/aicore_op_task.h b/ge/hybrid/node_executor/aicore/aicore_op_task.h index eaa821e3..0447ade7 100755 --- a/ge/hybrid/node_executor/aicore/aicore_op_task.h +++ b/ge/hybrid/node_executor/aicore/aicore_op_task.h @@ -46,6 +46,8 @@ class AiCoreOpTask { const std::string& GetName() const; + bool GetClearAtomic() const {return clear_atomic_;} + protected: Status UpdateTilingInfo(TaskContext &context); virtual std::string GetKeyForOpParamSize() const; @@ -66,6 +68,7 @@ class AiCoreOpTask { std::unique_ptr args_ = nullptr; uint32_t args_size_ = 0; uint32_t block_dim_ = 1; + bool clear_atomic_ = true; }; class AtomicAddrCleanOpTask : public AiCoreOpTask {