Browse Source

Replace MemcpyAsyncNodeTask

tags/v1.5.1
zhangxiaokun 3 years ago
parent
commit
4fd937eb05
3 changed files with 1 addition and 34 deletions
  1. +1
    -0
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  2. +0
    -29
      ge/hybrid/node_executor/rts/rts_node_task.cc
  3. +0
    -5
      ge/hybrid/node_executor/rts/rts_node_task.h

+ 1
- 0
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -29,6 +29,7 @@ REGISTER_RTS_TASK_CREATOR(IDENTITY, IdentityNodeTask);
// Task-creator registrations: each REGISTER_RTS_TASK_CREATOR line pairs an op-type
// identifier (first argument) with a node task class (second argument).
REGISTER_RTS_TASK_CREATOR(IDENTITYN, IdentityNNodeTask);
REGISTER_RTS_TASK_CREATOR(READVARIABLEOP, ReadVariableOpNodeTask);
REGISTER_RTS_TASK_CREATOR(PROFILINGTRAININGTRACE, ProfilingTraceNodeTask);
// MEMCPYASYNC is paired with IdentityNodeTask, the same task class the hunk header
// above shows registered for IDENTITY.
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, IdentityNodeTask);

Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) {
auto input_desc = context.MutableInputDesc(index);


+ 0
- 29
ge/hybrid/node_executor/rts/rts_node_task.cc View File

@@ -43,7 +43,6 @@ namespace hybrid {
// Stream-control op types, each paired with its own dedicated node task class.
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask);
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask);
// MEMCPYASYNC paired with the dedicated MemcpyAsyncNodeTask class.
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask);

// ENTER and REFENTER share a single task class, PassThroughNodeTask.
REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask);
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask);
@@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio
return SUCCESS;
}

/// Copies input 0 of the node into output 0 with an asynchronous device-to-device
/// rtMemcpyAsync issued on the task context's stream.
///
/// The number of bytes to copy is taken from the tensor descriptor of input 0
/// (TensorUtils::GetTensorSizeInBytes). When that size is not positive, no copy is
/// issued and only a warning is logged.
///
/// @param task_context   Execution context supplying the tensors, descriptors and stream.
/// @param done_callback  Optional callback; when non-empty it is handed to
///                       task_context.RegisterCallback after the copy has been issued.
/// @return SUCCESS when the end of the function is reached. The GE_CHECK_* / GE_CHK_*
///         macros presumably return an error status early on failure -- confirm against
///         their definitions.
Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());

  auto input_desc = task_context.MutableInputDesc(0);
  GE_CHECK_NOTNULL(input_desc);

  int64_t copy_size{0};
  GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size));

  // A successful GetTensorSizeInBytes never yields a negative size, so in practice
  // the first branch only handles the zero-size case.
  if (copy_size <= 0) {
    GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size);
  } else {
    // The tensors are fetched and null-checked only when there is data to copy.
    const auto in_v = task_context.MutableInput(0);
    const auto out_v = task_context.MutableOutput(0);
    GE_CHECK_NOTNULL(in_v);
    GE_CHECK_NOTNULL(out_v);
    GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(),
           in_v->GetSize(), out_v->GetSize(), copy_size);
    // The output tensor's size is passed as the destination capacity argument.
    GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size,
                                RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
  }

  if (done_callback != nullptr) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}

Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", task_context.GetNodeName());
const auto in_x = task_context.GetInput(0); // x


+ 0
- 5
ge/hybrid/node_executor/rts/rts_node_task.h View File

@@ -60,11 +60,6 @@ class StreamMergeNodeTask : public RtsNodeTask {
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

// RTS node task for MemcpyAsync nodes. Derives from RtsNodeTask and overrides
// only ExecuteAsync.
class MemcpyAsyncNodeTask : public RtsNodeTask {
public:
// Executes the task for the node described by task_context. done_callback may be
// empty; see the definition for how it is used.
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;
};

class PassThroughNodeTask : public RtsNodeTask {
public:
Status ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) override;


Loading…
Cancel
Save