diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index d52f56b9..e3058ee3 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -29,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask); REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask); REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask); REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask); +REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask); Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { auto input_desc = context.MutableInputDesc(index); diff --git a/ge/hybrid/node_executor/rts/rts_node_task.cc b/ge/hybrid/node_executor/rts/rts_node_task.cc index 9af54815..7b95f98a 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.cc +++ b/ge/hybrid/node_executor/rts/rts_node_task.cc @@ -43,7 +43,6 @@ namespace hybrid { REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); -REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); @@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio return SUCCESS; } -Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { - GELOGD("[%s] Start to execute.", task_context.GetNodeName()); - auto input_desc = task_context.MutableInputDesc(0); - GE_CHECK_NOTNULL(input_desc); - int64_t copy_size = 0; - GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size)); - // copy_size would not be negative since GetTensorSizeInBytes returned successfully. - if (copy_size > 0) { - const auto in_v = task_context.MutableInput(0); - const auto out_v = task_context.MutableOutput(0); - GE_CHECK_NOTNULL(in_v); - GE_CHECK_NOTNULL(out_v); - GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(), - in_v->GetSize(), out_v->GetSize(), copy_size); - GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size, - RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream())); - } else { - GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size); - } - - if (done_callback) { - GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback)); - } - - GELOGD("[%s] Done executing successfully.", task_context.GetNodeName()); - return SUCCESS; -} - Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function done_callback) { GELOGD("[%s] Start to execute.", task_context.GetNodeName()); const auto in_x = task_context.GetInput(0); // x diff --git a/ge/hybrid/node_executor/rts/rts_node_task.h b/ge/hybrid/node_executor/rts/rts_node_task.h index d7d63eb5..e18f9a8f 100644 --- a/ge/hybrid/node_executor/rts/rts_node_task.h +++ b/ge/hybrid/node_executor/rts/rts_node_task.h @@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask { Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; }; -class MemcpyAsyncNodeTask : public RtsNodeTask { - public: - Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override; -}; - class PassThroughNodeTask : public RtsNodeTask { public: Status ExecuteAsync(TaskContext &task_context, std::function done_callback) override;