|
@@ -43,7 +43,6 @@ namespace hybrid { |
|
|
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(STREAMACTIVE, StreamActiveNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(STREAMSWITCH, StreamSwitchNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(STREAMMERGE, StreamMergeNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(MEMCPYASYNC, MemcpyAsyncNodeTask); |
|
|
|
|
|
|
|
|
|
|
|
REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(ENTER, PassThroughNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); |
|
|
REGISTER_RTS_TASK_CREATOR(REFENTER, PassThroughNodeTask); |
|
@@ -168,34 +167,6 @@ Status StreamMergeNodeTask::ExecuteAsync(TaskContext &task_context, std::functio |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
/// Asynchronously copies input 0 of the node to output 0 of the node.
///
/// The number of bytes to copy is obtained from the descriptor of input 0 via
/// TensorUtils::GetTensorSizeInBytes. When that size is positive, the copy is
/// issued with rtMemcpyAsync (RT_MEMCPY_DEVICE_TO_DEVICE) on the stream returned
/// by task_context.GetStream(), using the output buffer size as the destination
/// capacity. When the size is not positive, only a warning is logged and no copy
/// is issued.
///
/// @param task_context   Supplies the node name, tensors, descriptors and stream.
/// @param done_callback  If non-empty, it is handed to
///                       task_context.RegisterCallback after the copy step.
/// @return SUCCESS when every step passes its check. The GE_CHECK_* / GE_CHK_*
///         macros presumably return an error status early on failure —
///         NOTE(review): confirm against the macro definitions.
///
/// NOTE(review): local names and macro argument expressions are intentionally
/// left as-is; the check macros may stringify their arguments into log text —
/// confirm before renaming.
Status MemcpyAsyncNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) {
  GELOGD("[%s] Start to execute.", task_context.GetNodeName());

  // Determine how many bytes the input tensor occupies.
  auto input_desc = task_context.MutableInputDesc(0);
  GE_CHECK_NOTNULL(input_desc);

  int64_t copy_size = 0;
  GE_CHK_GRAPH_STATUS_RET(TensorUtils::GetTensorSizeInBytes(*input_desc, copy_size));

  // A successful GetTensorSizeInBytes is not expected to yield a negative size,
  // so in practice the first branch covers the zero-byte case.
  if (copy_size <= 0) {
    GELOGW("[%s] invalid copy size: %ld", task_context.GetNodeName(), copy_size);
  } else {
    const auto in_v = task_context.MutableInput(0);
    const auto out_v = task_context.MutableOutput(0);
    GE_CHECK_NOTNULL(in_v);
    GE_CHECK_NOTNULL(out_v);

    GELOGD("[%s] input size: %zu, output size: %zu, copy size: %ld", task_context.GetNodeName(),
           in_v->GetSize(), out_v->GetSize(), copy_size);

    // Destination capacity is the full output buffer; only copy_size bytes are requested.
    GE_CHK_RT_RET(rtMemcpyAsync(out_v->MutableData(), out_v->GetSize(), in_v->GetData(), copy_size,
                                RT_MEMCPY_DEVICE_TO_DEVICE, task_context.GetStream()));
  }

  // Registration happens regardless of whether a copy was actually issued.
  if (done_callback != nullptr) {
    GE_CHK_STATUS_RET(task_context.RegisterCallback(done_callback));
  }

  GELOGD("[%s] Done executing successfully.", task_context.GetNodeName());
  return SUCCESS;
}
|
|
|
|
|
|
|
|
|
|
|
Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { |
|
|
Status PassThroughNodeTask::ExecuteAsync(TaskContext &task_context, std::function<void()> done_callback) { |
|
|
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); |
|
|
GELOGD("[%s] Start to execute.", task_context.GetNodeName()); |
|
|
const auto in_x = task_context.GetInput(0); // x |
|
|
const auto in_x = task_context.GetInput(0); // x |
|
|