Browse Source

!1720 Add ControlTrigger Support

From: @zhangxiaokun9
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
70dfc4d2d7
2 changed files with 20 additions and 2 deletions
  1. +14
    -2
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc
  2. +6
    -0
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.h

+ 14
- 2
ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc View File

@@ -39,6 +39,10 @@ const std::map<std::string, std::vector<uint32_t>>


const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE}; const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE};


const std::set<std::string> ConstantNodeTask::constant_like_task_ops_ = {CONSTANT, CONSTANTOP, VARIABLE};

const std::set<std::string> NoOpNodeTask::control_only_task_ops_ = {NOOP, CONTROLTRIGGER};

Status RefInputTask::UpdateArgs(TaskContext &) { Status RefInputTask::UpdateArgs(TaskContext &) {
// no need update args // no need update args
return SUCCESS; return SUCCESS;
@@ -244,7 +248,7 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model,
node->GetName().c_str(), node_type.c_str()); node->GetName().c_str(), node_type.c_str());
return MEMALLOC_FAILED; return MEMALLOC_FAILED;
} }
} else if (node_type == CONSTANT || node_type == CONSTANTOP || node_type == VARIABLE) {
} else if (ConstantNodeTask::IsBelong(node_type)) {
GELOGI("node %s type %s, use ConstantNodeTask.", node->GetName().c_str(), node_type.c_str()); GELOGI("node %s type %s, use ConstantNodeTask.", node->GetName().c_str(), node_type.c_str());
auto tensor = model.GetTensor(node); auto tensor = model.GetTensor(node);
if (tensor == nullptr) { if (tensor == nullptr) {
@@ -254,7 +258,7 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model,
} }
task = MakeShared<ConstantNodeTask>(tensor); task = MakeShared<ConstantNodeTask>(tensor);
GE_CHECK_NOTNULL(task); GE_CHECK_NOTNULL(task);
} else if (node_type == NOOP) {
} else if (NoOpNodeTask::IsBelong(node_type)) {
GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str()); GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str());
task = MakeShared<NoOpNodeTask>(); task = MakeShared<NoOpNodeTask>();
if (task == nullptr) { if (task == nullptr) {
@@ -288,6 +292,10 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function<void()
return SUCCESS; return SUCCESS;
} }


bool ConstantNodeTask::IsBelong(const std::string &op_type) {
return constant_like_task_ops_.count(op_type) > 0;
}

Status NoOpNodeTask::UpdateArgs(TaskContext &context) { Status NoOpNodeTask::UpdateArgs(TaskContext &context) {
// no need to update args // no need to update args
return SUCCESS; return SUCCESS;
@@ -299,5 +307,9 @@ Status NoOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
GELOGD("[%s] Done execute successfully.", context.GetNodeName()); GELOGD("[%s] Done execute successfully.", context.GetNodeName());
return SUCCESS; return SUCCESS;
} }

bool NoOpNodeTask::IsBelong(const std::string &op_type) {
return control_only_task_ops_.count(op_type) > 0;
}
} // namespace hybrid } // namespace hybrid
} // namespace ge } // namespace ge

+ 6
- 0
ge/hybrid/node_executor/ge_local/ge_local_node_executor.h View File

@@ -75,8 +75,10 @@ class ConstantNodeTask : public NodeTask {
Status UpdateArgs(TaskContext &context) override; Status UpdateArgs(TaskContext &context) override;


Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
static bool IsBelong(const std::string &op_type);


private: private:
static const std::set<std::string> constant_like_task_ops_;
const TensorValue *tensor_; const TensorValue *tensor_;
}; };


@@ -86,6 +88,10 @@ class NoOpNodeTask : public NodeTask {
~NoOpNodeTask() = default; ~NoOpNodeTask() = default;
Status UpdateArgs(TaskContext &context) override; Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override; Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
static bool IsBelong(const std::string &op_type);

private:
static const std::set<std::string> control_only_task_ops_;
}; };


class GeLocalNodeExecutor : public NodeExecutor { class GeLocalNodeExecutor : public NodeExecutor {


Loading…
Cancel
Save