|
@@ -17,6 +17,7 @@ |
|
|
#include "rts_node_executor.h" |
|
|
#include "rts_node_executor.h" |
|
|
#include "common/debug/log.h" |
|
|
#include "common/debug/log.h" |
|
|
#include "common/ge/ge_util.h" |
|
|
#include "common/ge/ge_util.h" |
|
|
|
|
|
#include "common/types.h" |
|
|
#include "graph/utils/tensor_utils.h" |
|
|
#include "graph/utils/tensor_utils.h" |
|
|
#include "hybrid/model/hybrid_model.h" |
|
|
#include "hybrid/model/hybrid_model.h" |
|
|
#include "runtime/rt.h" |
|
|
#include "runtime/rt.h" |
|
@@ -50,6 +51,20 @@ Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { |
|
|
return SUCCESS; |
|
|
return SUCCESS; |
|
|
} |
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
Status ReadVariableOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { |
|
|
|
|
|
GELOGD("[%s] Start to execute.", context.GetNodeName()); |
|
|
|
|
|
for (int i = 0; i < context.NumInputs(); ++i) { |
|
|
|
|
|
GE_CHK_STATUS_RET(DoCopyTensor(context, i)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
if (done_callback) { |
|
|
|
|
|
GE_CHK_STATUS_RET(context.RegisterCallback(done_callback)); |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
|
|
|
GELOGD("[%s] Done executing successfully.", context.GetNodeName()); |
|
|
|
|
|
return SUCCESS; |
|
|
|
|
|
} |
|
|
|
|
|
|
|
|
Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { |
|
|
Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { |
|
|
GELOGD("[%s] Start to execute.", context.GetNodeName()); |
|
|
GELOGD("[%s] Start to execute.", context.GetNodeName()); |
|
|
GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); |
|
|
GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); |
|
@@ -111,6 +126,8 @@ Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, |
|
|
task = MakeShared<IdentityNodeTask>(); |
|
|
task = MakeShared<IdentityNodeTask>(); |
|
|
} else if (op_type == IDENTITYN) { |
|
|
} else if (op_type == IDENTITYN) { |
|
|
task = MakeShared<IdentityNNodeTask>(); |
|
|
task = MakeShared<IdentityNNodeTask>(); |
|
|
|
|
|
} else if (op_type == READVARIABLEOP) { |
|
|
|
|
|
task = MakeShared<ReadVariableOpNodeTask>(); |
|
|
} else if (op_type == PROFILINGTRAININGTRACE) { |
|
|
} else if (op_type == PROFILINGTRAININGTRACE) { |
|
|
auto *task_defs = model.GetTaskDefs(node); |
|
|
auto *task_defs = model.GetTaskDefs(node); |
|
|
if (task_defs == nullptr || task_defs->empty()) { |
|
|
if (task_defs == nullptr || task_defs->empty()) { |
|
|