From c94e0fbdc6b6560a4d4e67d9b71f7d1e8ccd0b2b Mon Sep 17 00:00:00 2001 From: wxl Date: Tue, 9 Mar 2021 14:57:36 +0800 Subject: [PATCH] add force infershape for some op --- ge/hybrid/executor/worker/shape_inference_engine.cc | 2 +- ge/hybrid/model/hybrid_model_builder.cc | 13 +++++++++++++ ge/hybrid/model/hybrid_model_builder.h | 1 + ge/hybrid/model/node_item.h | 1 + 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index bb6281e1..0a7f3985 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { // Wait for "const input nodes" if node's shape inference function requires any. // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); - if (node_item.is_output_shape_static) { + if (node_item.is_output_shape_static && node_item.is_need_force_infershape) { return SUCCESS; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index ac57b2ea..58a7c23f 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode"; const char *const kProfilingEndNode = "ProfilingEndNode"; const char *const kProfilingArNode = "ProfilingAllReduceNode"; const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; +const char *const kForceInfershape = "_force_infershape_when_running"; Status SetOutputNameAttr(ComputeGraph &graph) { vector output_names; @@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() { Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); + GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), + "[%s] Failed to parse force_infershape node.", + node_item.NodeName().c_str()); vector dependencies = node->GetOpDesc()->GetOpInferDepends(); GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), "[%s] Failed to parse node dependencies.", @@ -263,6 +267,15 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } +Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // not care result, if no this attr, stand for the op does not need force infershape + (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); + GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape); + return SUCCESS; +} + Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { std::set dependent_input_nodes; auto &ge_node = node_item.node; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 71663a6e..313d5ca6 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -62,6 +62,7 @@ class HybridModelBuilder { Status IdentifySameInputs(NodeItem &node_item); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); + Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); Status ParseDependentForFusedSubgraph(NodeItem &node_item); Status IndexTaskDefs(); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 300744d1..631dbd9e 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -83,6 +83,7 @@ struct NodeItem { bool has_observer = false; bool has_optional_inputs = false; bool is_output_shape_static = true; + bool is_need_force_infershape = false; UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; std::string node_name; std::string node_type;