Browse Source

!1720 Add ControlTrigger Support

From: @zhangxiaokun9
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.3.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
70dfc4d2d7
2 changed files with 20 additions and 2 deletions
  1. +14
    -2
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc
  2. +6
    -0
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.h

+ 14
- 2
ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc View File

@@ -39,6 +39,10 @@ const std::map<std::string, std::vector<uint32_t>>

const std::set<std::string> DependInputShapeTask::depend_input_shape_ops_ = {SHAPE, SHAPEN, RANK, SIZE};

const std::set<std::string> ConstantNodeTask::constant_like_task_ops_ = {CONSTANT, CONSTANTOP, VARIABLE};

const std::set<std::string> NoOpNodeTask::control_only_task_ops_ = {NOOP, CONTROLTRIGGER};

Status RefInputTask::UpdateArgs(TaskContext &) {
// no need update args
return SUCCESS;
@@ -244,7 +248,7 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model,
node->GetName().c_str(), node_type.c_str());
return MEMALLOC_FAILED;
}
} else if (node_type == CONSTANT || node_type == CONSTANTOP || node_type == VARIABLE) {
} else if (ConstantNodeTask::IsBelong(node_type)) {
GELOGI("node %s type %s, use ConstantNodeTask.", node->GetName().c_str(), node_type.c_str());
auto tensor = model.GetTensor(node);
if (tensor == nullptr) {
@@ -254,7 +258,7 @@ Status GeLocalNodeExecutor::LoadTask(const HybridModel &model,
}
task = MakeShared<ConstantNodeTask>(tensor);
GE_CHECK_NOTNULL(task);
} else if (node_type == NOOP) {
} else if (NoOpNodeTask::IsBelong(node_type)) {
GELOGI("node %s type %s , use NoOpNodeTask.", node->GetName().c_str(), node_type.c_str());
task = MakeShared<NoOpNodeTask>();
if (task == nullptr) {
@@ -288,6 +292,10 @@ Status ConstantNodeTask::ExecuteAsync(TaskContext &context, std::function<void()
return SUCCESS;
}

bool ConstantNodeTask::IsBelong(const std::string &op_type) {
return constant_like_task_ops_.count(op_type) > 0;
}

Status NoOpNodeTask::UpdateArgs(TaskContext &context) {
// no need to update args
return SUCCESS;
@@ -299,5 +307,9 @@ Status NoOpNodeTask::ExecuteAsync(TaskContext &context, std::function<void()> do
GELOGD("[%s] Done execute successfully.", context.GetNodeName());
return SUCCESS;
}

bool NoOpNodeTask::IsBelong(const std::string &op_type) {
return control_only_task_ops_.count(op_type) > 0;
}
} // namespace hybrid
} // namespace ge

+ 6
- 0
ge/hybrid/node_executor/ge_local/ge_local_node_executor.h View File

@@ -75,8 +75,10 @@ class ConstantNodeTask : public NodeTask {
Status UpdateArgs(TaskContext &context) override;

Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
static bool IsBelong(const std::string &op_type);

private:
static const std::set<std::string> constant_like_task_ops_;
const TensorValue *tensor_;
};

@@ -86,6 +88,10 @@ class NoOpNodeTask : public NodeTask {
~NoOpNodeTask() = default;
Status UpdateArgs(TaskContext &context) override;
Status ExecuteAsync(TaskContext &context, std::function<void()> done_callback) override;
static bool IsBelong(const std::string &op_type);

private:
static const std::set<std::string> control_only_task_ops_;
};

class GeLocalNodeExecutor : public NodeExecutor {


Loading…
Cancel
Save