Browse Source

!278 Support StatelessIf and StatelessWhile

From: @xchu42
Reviewed-by: @ji_chen,@wqtshg
Signed-off-by: @ji_chen
tags/v1.1.0
mindspore-ci-bot Gitee 4 years ago
parent
commit
f01cfe5167
5 changed files with 16 additions and 7 deletions
  1. +2
    -2
      ge/hybrid/model/hybrid_model_builder.cc
  2. +9
    -2
      ge/hybrid/model/node_item.cc
  3. +2
    -0
      ge/hybrid/model/node_item.h
  4. +2
    -2
      ge/hybrid/node_executor/controlop/control_op_executor.cc
  5. +1
    -1
      ge/hybrid/node_executor/node_executor.cc

+ 2
- 2
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -257,7 +257,7 @@ Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const s
}

// cond or branch need to be prepared before the execution of IF or CASE
if (node_item.node_type == IF || node_item.node_type == CASE) {
if (node_item.node_type == IF || node_item.node_type == STATELESSIF || node_item.node_type == CASE) {
const auto &in_anchor = ge_node->GetInDataAnchor(0);
GE_CHECK_NOTNULL(in_anchor);
const auto &peer_anchor = in_anchor->GetPeerOutAnchor();
@@ -920,7 +920,7 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr
auto parent_node = sub_graph.GetParentNode();
GE_CHECK_NOTNULL(parent_node);
auto op_type = parent_node->GetType();
if (op_type == IF || op_type == CASE || op_type == WHILE) {
if (IsControlOp(op_type)) {
GELOGD("Set ge_model for control op subgraph: [%s], task_size = %d",
sub_graph.GetName().c_str(),
ge_model->GetModelTaskDefPtr()->task_size());


+ 9
- 2
ge/hybrid/model/node_item.cc View File

@@ -28,6 +28,9 @@ namespace hybrid {
namespace {
const char * const kAttrNameOriginalFusionGraph = "_original_fusion_graph";
const char * const kNodeTypeRetVal = "_RetVal";
std::set<std::string> kControlOpTypes {
IF, STATELESSIF, CASE, WHILE, STATELESSWHILE
};

Status ParseInputMapping(Node &node, OpDesc &op_desc, FusedSubgraph &fused_subgraph) {
uint32_t parent_index = 0;
@@ -102,6 +105,11 @@ Status ParseFusedSubgraph(NodeItem &node_item) {
return SUCCESS;
}
} // namespace

bool IsControlOp(const std::string &op_type) {
return kControlOpTypes.count(op_type) > 0;
}

NodeItem::NodeItem(NodePtr node): node(std::move(node)) {
this->op_desc = this->node->GetOpDesc().get();
this->node_id = this->op_desc->GetId();
@@ -153,8 +161,7 @@ Status NodeItem::Init() {
}

bool NodeItem::IsControlOp() const {
auto op_type = op_desc->GetType();
return op_type == IF || op_type == CASE || op_type == WHILE || op_type == FOR;
return ge::hybrid::IsControlOp(op_desc->GetType());
}

std::string NodeItem::DebugString() const {


+ 2
- 0
ge/hybrid/model/node_item.h View File

@@ -36,6 +36,8 @@ struct FusedSubgraph {
ComputeGraphPtr graph;
};

bool IsControlOp(const std::string &op_type);

// for caching static information across execution
struct NodeItem {
explicit NodeItem(NodePtr node);


+ 2
- 2
ge/hybrid/node_executor/controlop/control_op_executor.cc View File

@@ -404,11 +404,11 @@ Status ControlOpNodeExecutor::LoadTask(const HybridModel &model,

unique_ptr<ControlOpNodeTask> node_task;
auto node_type = node->GetType();
if (node_type == IF) {
if (node_type == IF || node_type == STATELESSIF) {
node_task.reset(new(std::nothrow) IfOpNodeTask());
} else if (node_type == CASE) {
node_task.reset(new(std::nothrow) CaseOpNodeTask());
} else if (node_type == WHILE) {
} else if (node_type == WHILE || node_type == STATELESSWHILE) {
node_task.reset(new(std::nothrow) WhileOpNodeTask());
} else {
GELOGE(PARAM_INVALID, "[%s] Unsupported type: %s", node->GetName().c_str(), node_type.c_str());


+ 1
- 1
ge/hybrid/node_executor/node_executor.cc View File

@@ -97,7 +97,7 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node
return ExecutorType::GE_LOCAL;
}

if (op_type == IF || op_type == CASE || op_type == WHILE) {
if (IsControlOp(op_type)) {
return ExecutorType::CONTROL_OP;
}



Loading…
Cancel
Save