Browse Source

fix question does not support readvariable op

tags/v1.2.0
wxl 3 years ago
parent
commit
019e39e4fc
3 changed files with 23 additions and 1 deletions
  1. +1
    -1
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc
  2. +17
    -0
      ge/hybrid/node_executor/rts/rts_node_executor.cc
  3. +5
    -0
      ge/hybrid/node_executor/rts/rts_node_executor.h

+ 1
- 1
ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc View File

@@ -36,7 +36,7 @@ const std::map<std::string, std::vector<uint32_t>>
{BROADCASTGRADIENTARGS, {}} {BROADCASTGRADIENTARGS, {}}
}; };


const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE};
const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE, NOOP};


Status RefInputTask::UpdateArgs(TaskContext &) { Status RefInputTask::UpdateArgs(TaskContext &) {
// no need update args // no need update args


+ 17
- 0
ge/hybrid/node_executor/rts/rts_node_executor.cc View File

@@ -17,6 +17,7 @@
#include "rts_node_executor.h" #include "rts_node_executor.h"
#include "common/debug/log.h" #include "common/debug/log.h"
#include "common/ge/ge_util.h" #include "common/ge/ge_util.h"
#include "common/types.h"
#include "graph/utils/tensor_utils.h" #include "graph/utils/tensor_utils.h"
#include "hybrid/model/hybrid_model.h" #include "hybrid/model/hybrid_model.h"
#include "runtime/rt.h" #include "runtime/rt.h"
@@ -50,6 +51,20 @@ Status IdentityNodeTask::DoCopyTensor(TaskContext &context, int index) {
return SUCCESS; return SUCCESS;
} }


Status ReadVariableOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", context.GetNodeName());
for (int i = 0; i < context.NumInputs(); ++i) {
GE_CHK_STATUS_RET(DoCopyTensor(context, i));
}

if (done_callback) {
GE_CHK_STATUS_RET(context.RegisterCallback(done_callback));
}

GELOGD("[%s] Done executing successfully.", context.GetNodeName());
return SUCCESS;
}

Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) { Status IdentityNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> done_callback) {
GELOGD("[%s] Start to execute.", context.GetNodeName()); GELOGD("[%s] Start to execute.", context.GetNodeName());
GE_CHK_STATUS_RET(DoCopyTensor(context, 0)); GE_CHK_STATUS_RET(DoCopyTensor(context, 0));
@@ -111,6 +126,8 @@ Status RtsNodeExecutor::LoadTask(const HybridModel &model, const NodePtr &node,
task = MakeShared<IdentityNodeTask>(); task = MakeShared<IdentityNodeTask>();
} else if (op_type == IDENTITYN) { } else if (op_type == IDENTITYN) {
task = MakeShared<IdentityNNodeTask>(); task = MakeShared<IdentityNNodeTask>();
} else if (op_type == READVARIABLEOP) {
task = MakeShared<ReadVariableOpNodeTask>();
} else if (op_type == PROFILINGTRAININGTRACE) { } else if (op_type == PROFILINGTRAININGTRACE) {
auto *task_defs = model.GetTaskDefs(node); auto *task_defs = model.GetTaskDefs(node);
if (task_defs == nullptr || task_defs->empty()) { if (task_defs == nullptr || task_defs->empty()) {


+ 5
- 0
ge/hybrid/node_executor/rts/rts_node_executor.h View File

@@ -36,6 +36,11 @@ class IdentityNNodeTask : public IdentityNodeTask {
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
}; };


class ReadVariableOpNodeTask : public IdentityNodeTask {
public:
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
};

class ProfilingTraceNodeTask : public NodeTask { class ProfilingTraceNodeTask : public NodeTask {
public: public:
explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {} explicit ProfilingTraceNodeTask(const std::vector<domi::TaskDef> &task_defs) : task_defs_(task_defs) {}


Loading…
Cancel
Save