From ff6f8d8e46bb67b643f729a8fe4b4a7096561003 Mon Sep 17 00:00:00 2001
From: chuxing
Date: Mon, 9 Nov 2020 10:09:49 +0800
Subject: [PATCH] skip aicore operators whose output tensors are all empty

---
 ge/hybrid/executor/worker/execution_engine.cc | 12 +++++++-----
 .../aicore/aicore_node_executor.cc            | 20 ++++++++++++++++++++
 .../aicore/aicore_node_executor.h             |  1 +
 3 files changed, 28 insertions(+), 5 deletions(-)

diff --git a/ge/hybrid/executor/worker/execution_engine.cc b/ge/hybrid/executor/worker/execution_engine.cc
index 7dc65433..fad899c6 100755
--- a/ge/hybrid/executor/worker/execution_engine.cc
+++ b/ge/hybrid/executor/worker/execution_engine.cc
@@ -120,11 +120,13 @@ Status NodeDoneCallback::PrepareConstInputs(const NodeItem &node_item) {
              node_item.NodeName().c_str(),
              output_idx,
              output_tensor->GetSize());
-      GE_CHK_RT_RET(rtMemcpy(host_buffer.data(),
-                             tensor_size,
-                             output_tensor->GetData(),
-                             tensor_size,
-                             RT_MEMCPY_DEVICE_TO_HOST));
+      if (tensor_size > 0) {
+        GE_CHK_RT_RET(rtMemcpy(host_buffer.data(),
+                               tensor_size,
+                               output_tensor->GetData(),
+                               tensor_size,
+                               RT_MEMCPY_DEVICE_TO_HOST));
+      }
       tensor.SetData(std::move(host_buffer));
       string session_id = std::to_string(context_->GetSessionId());
       RuntimeInferenceContext *runtime_infer_ctx = nullptr;
diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc
index 09c516fb..4c32f131 100755
--- a/ge/hybrid/node_executor/aicore/aicore_node_executor.cc
+++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.cc
@@ -156,6 +156,13 @@ Status AiCoreNodeExecutor::CompileTask(const HybridModel &model,
 
 Status AiCoreNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
   RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] Start");
+  if (IsNoOp(context)) {
+    GELOGD("[%s] Skipping execution for op with empty outputs", context.GetNodeName());
+    auto ret = context.TryExecuteCallback(done_callback);
+    RECORD_EXECUTION_EVENT(context.GetExecutionContext(), context.GetNodeName(), "[AiCoreNodeTaskExecuteAsync] End");
+    return ret;
+  }
+
   auto op_desc = context.GetNodeItem().op_desc;
   GE_CHECK_NOTNULL(op_desc);
   GELOGI("[%s] ExecuteAsync Start.", op_desc->GetName().c_str());
@@ -219,5 +226,18 @@ bool AiCoreNodeTask::IsSupportDynamicShape() {
 
   return true;
 }
+
+bool AiCoreNodeTask::IsNoOp(TaskContext &task_context) {
+  for (int i = 0; i < task_context.NumOutputs(); ++i) {
+    const auto &tensor_desc = task_context.MutableOutputDesc(i);
+    GE_CHECK_NOTNULL(tensor_desc);
+    const auto &shape = tensor_desc->MutableShape();
+    if (shape.IsScalar() || shape.GetShapeSize() > 0) {
+      return false;
+    }
+  }
+
+  return true;
+}
 }  // namespace hybrid
 }  // namespace ge
diff --git a/ge/hybrid/node_executor/aicore/aicore_node_executor.h b/ge/hybrid/node_executor/aicore/aicore_node_executor.h
index b4afc34c..374782dc 100755
--- a/ge/hybrid/node_executor/aicore/aicore_node_executor.h
+++ b/ge/hybrid/node_executor/aicore/aicore_node_executor.h
@@ -52,6 +52,7 @@ class AiCoreNodeTask : public NodeTask {
   Status UpdateArgs(TaskContext &context) override;
   Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
  private:
+  static bool IsNoOp(TaskContext &task_context);
   std::vector<std::unique_ptr<AiCoreOpTask>> tasks_;
 };
 
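
Note (not part of the patch): the skip condition introduced by IsNoOp treats an output as "empty" only when it is non-scalar and its shape has zero elements; a scalar output still requires a kernel launch. The standalone C++ sketch below mirrors that check for illustration. The Shape struct is a hypothetical stand-in for GE's GeShape, assuming IsScalar()/GetShapeSize() behave as the diff implies; it is not GE code.

// Standalone sketch of the "all outputs empty" check used to skip an AiCore task.
// Shape is a hypothetical stand-in for GE's GeShape (assumption, not the real API).
#include <cstdint>
#include <iostream>
#include <vector>

struct Shape {
  std::vector<int64_t> dims;               // empty dims == scalar
  bool IsScalar() const { return dims.empty(); }
  int64_t GetShapeSize() const {           // element count; assumed 0 for scalars
    if (dims.empty()) {
      return 0;
    }
    int64_t size = 1;
    for (int64_t d : dims) {
      size *= d;                           // a 0 in any dim makes the tensor empty
    }
    return size;
  }
};

// Mirrors the patch logic: the node is a no-op only if every output is
// non-scalar and carries zero elements.
bool AllOutputsEmpty(const std::vector<Shape> &output_shapes) {
  for (const auto &shape : output_shapes) {
    if (shape.IsScalar() || shape.GetShapeSize() > 0) {
      return false;
    }
  }
  return true;
}

int main() {
  std::cout << AllOutputsEmpty({{{0, 16}}, {{0}}}) << '\n';  // 1: every output has 0 elements
  std::cout << AllOutputsEmpty({{{0, 16}}, {{}}}) << '\n';   // 0: second output is a scalar
  std::cout << AllOutputsEmpty({{{4, 16}}}) << '\n';         // 0: non-empty output
  return 0;
}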