From 019e39e4fc0aa03ea714bb2ce2cc4cace8d670fb Mon Sep 17 00:00:00 2001 From: wxl Date: Mon, 8 Feb 2021 10:58:08 +0800 Subject: [PATCH] fix question does not support readvariable op --- .../ge_local/ge_local_node_executor.cc | 2 +- .../node_executor/rts/rts_node_executor.cc | 17 +++++++++++++++++ ge/hybrid/node_executor/rts/rts_node_executor.h | 5 +++++ 3 files changed, 23 insertions(+), 1 deletion(-) diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc index d7d0f547..3d2e3084 100755 --- a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc +++ b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc @@ -36,7 +36,7 @@ const std::map> {BROADCASTGRADIENTARGS, {}} }; -const std::set DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; +const std::set DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP}; Status RefInputTask::UpdateArgs(TaskContext &) { // no need update args diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.cc b/ge/hybrid/node_executor/rts/rts_node_executor.cc index 90b623e0..aa833de0 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.cc +++ b/ge/hybrid/node_executor/rts/rts_node_executor.cc @@ -17,6 +17,7 @@ #include "rts_node_executor.h" #include "common/debug/log.h" #include "common/ge/ge_util.h" +#include "common/types.h" #include "graph/utils/tensor_utils.h" #include "hybrid/model/hybrid_model.h" #include "runtime/rt.h" @@ -50,6 +51,20 @@ Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) { return SUCCESS; } +Status ReadVariableOpNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { + GELOGD("[%s] Start to execute.", context.GetNodeName()); + for (int i = 0; i < context.NumInputs(); ++i) { + GE_CHK_STATUS_RET(DoCopyTensor(context, i)); + } + + if (done_callback) { + GE_CHK_STATUS_RET(context.RegisterCallback(done_callback)); + } + + GELOGD("[%s] Done executing successfully.", context.GetNodeName()); + return SUCCESS; +} + Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function done_callback) { GELOGD("[%s] Start to execute.", context.GetNodeName()); GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); @@ -111,6 +126,8 @@ Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node, task = MakeShared(); } else if (op_type == IDENTITYN) { task = MakeShared(); + } else if (op_type == READVARIABLEOP) { + task = MakeShared(); } else if (op_type == PROFILINGTRAININGTRACE) { auto *task_defs = model.GetTaskDefs(node); if (task_defs == nullptr || task_defs->empty()) { diff --git a/ge/hybrid/node_executor/rts/rts_node_executor.h b/ge/hybrid/node_executor/rts/rts_node_executor.h index df487d6c..aecf138b 100644 --- a/ge/hybrid/node_executor/rts/rts_node_executor.h +++ b/ge/hybrid/node_executor/rts/rts_node_executor.h @@ -36,6 +36,11 @@ class IdentityNNodeTask : public IdentityNodeTask { Status ExecuteAsync(TaskContext &context, std::function done_callback) override; }; +class ReadVariableOpNodeTask : public IdentityNodeTask { + public: + Status ExecuteAsync(TaskContext &context, std::function done_callback) override; +}; + class ProfilingTraceNodeTask : public NodeTask { public: explicit ProfilingTraceNodeTask(const std::vector &task_defs) : task_defs_(task_defs) {}