From c94e0fbdc6b6560a4d4e67d9b71f7d1e8ccd0b2b Mon Sep 17 00:00:00 2001 From: wxl Date: Tue, 9 Mar 2021 14:57:36 +0800 Subject: [PATCH 1/4] add force infershape for some op --- ge/hybrid/executor/worker/shape_inference_engine.cc | 2 +- ge/hybrid/model/hybrid_model_builder.cc | 13 +++++++++++++ ge/hybrid/model/hybrid_model_builder.h | 1 + ge/hybrid/model/node_item.h | 1 + 4 files changed, 16 insertions(+), 1 deletion(-) diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index bb6281e1..0a7f3985 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { // Wait for "const input nodes" if node's shape inference function requires any. // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); - if (node_item.is_output_shape_static) { + if (node_item.is_output_shape_static && node_item.is_need_force_infershape) { return SUCCESS; } diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index ac57b2ea..58a7c23f 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode"; const char *const kProfilingEndNode = "ProfilingEndNode"; const char *const kProfilingArNode = "ProfilingAllReduceNode"; const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; +const char *const kForceInfershape = "_force_infershape_when_running"; Status SetOutputNameAttr(ComputeGraph &graph) { vector output_names; @@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() { Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { auto op_desc = node->GetOpDesc(); + GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), + "[%s] Failed to parse force_infershape node.", + node_item.NodeName().c_str()); vector dependencies = node->GetOpDesc()->GetOpInferDepends(); GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), "[%s] Failed to parse node dependencies.", @@ -263,6 +267,15 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n return SUCCESS; } +Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) { + auto op_desc = node->GetOpDesc(); + GE_CHECK_NOTNULL(op_desc); + // not care result, if no this attr, stand for the op does not need force infershape + (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); + GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape); + return SUCCESS; +} + Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies) { std::set dependent_input_nodes; auto &ge_node = node_item.node; diff --git a/ge/hybrid/model/hybrid_model_builder.h b/ge/hybrid/model/hybrid_model_builder.h index 71663a6e..313d5ca6 100644 --- a/ge/hybrid/model/hybrid_model_builder.h +++ b/ge/hybrid/model/hybrid_model_builder.h @@ -62,6 +62,7 @@ class HybridModelBuilder { Status IdentifySameInputs(NodeItem &node_item); Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); + Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); Status ParseDependentInputNodes(NodeItem &node_item, const std::vector &dependencies); Status ParseDependentForFusedSubgraph(NodeItem &node_item); Status IndexTaskDefs(); diff --git a/ge/hybrid/model/node_item.h b/ge/hybrid/model/node_item.h index 300744d1..631dbd9e 100644 --- a/ge/hybrid/model/node_item.h +++ b/ge/hybrid/model/node_item.h @@ -83,6 +83,7 @@ struct NodeItem { bool has_observer = false; bool has_optional_inputs = false; bool is_output_shape_static = true; + bool is_need_force_infershape = false; UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; std::string node_name; std::string node_type; From 365401b52fe53306f7b3ef87e4a2b17ac8090911 Mon Sep 17 00:00:00 2001 From: wxl Date: Tue, 9 Mar 2021 19:57:27 +0800 Subject: [PATCH 2/4] add force infershape for some op --- ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc | 1 + ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc | 1 + 2 files changed, 2 insertions(+) diff --git a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc index b2f3d095..90d95217 100755 --- a/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc +++ b/ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc @@ -38,6 +38,7 @@ REGISTER_OP_CREATOR(ExpandDims, GeDeletedOp); REGISTER_OP_CREATOR(Reshape, GeDeletedOp); REGISTER_OP_CREATOR(ReFormat, GeDeletedOp); REGISTER_OP_CREATOR(Squeeze, GeDeletedOp); +REGISTER_OP_CREATOR(Unsqueeze, GeDeletedOp); REGISTER_OP_CREATOR(Size, GeDeletedOp); REGISTER_OP_CREATOR(Shape, GeDeletedOp); REGISTER_OP_CREATOR(ShapeN, GeDeletedOp); diff --git a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc index 3d2e3084..9d92420e 100755 --- a/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc +++ b/ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc @@ -33,6 +33,7 @@ const std::map> {RESHAPE, {}}, {EXPANDDIMS, {}}, {SQUEEZE, {}}, + {UNSQUEEZE, {}}, {BROADCASTGRADIENTARGS, {}} }; From 5ae267433be2f99134d5fe26f6b6adbcb37f71ba Mon Sep 17 00:00:00 2001 From: wxl Date: Tue, 9 Mar 2021 22:36:32 +0800 Subject: [PATCH 3/4] add force infershape for some op --- ge/hybrid/model/hybrid_model_builder.cc | 4 +++- tests/ut/ge/hybrid/ge_hybrid_unittest.cc | 14 ++++++++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/ge/hybrid/model/hybrid_model_builder.cc b/ge/hybrid/model/hybrid_model_builder.cc index 58a7c23f..a349210d 100755 --- a/ge/hybrid/model/hybrid_model_builder.cc +++ b/ge/hybrid/model/hybrid_model_builder.cc @@ -272,7 +272,9 @@ Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeIt GE_CHECK_NOTNULL(op_desc); // not care result, if no this attr, stand for the op does not need force infershape (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); - GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape); + GELOGD("node [%s] is need do infershape , flag is %d", + op_desc->GetName().c_str(), + node_item.is_need_force_infershape); return SUCCESS; } diff --git a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc index 0b6ca271..286186de 100644 --- a/tests/ut/ge/hybrid/ge_hybrid_unittest.cc +++ b/tests/ut/ge/hybrid/ge_hybrid_unittest.cc @@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); } +TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { + const char *const kForceInfershape = "_force_infershape_when_running"; + auto graph = make_shared("graph"); + OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D"); + ge::AttrUtils::SetBool(op_desc, kForceInfershape, true); + auto node = graph->AddNode(op_desc); + std::unique_ptr new_node; + NodeItem::Create(node, new_node); + GeRootModelPtr ge_root_model = make_shared(graph); + HybridModel hybrid_model(ge_root_model); + HybridModelBuilder hybrid_model_builder(hybrid_model); + ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); +} + TEST_F(UtestGeHybrid, index_taskdefs_success) { // build aicore task domi::ModelTaskDef model_task_def; From 1227e0339ffd7ef7855c9d6b791a4926ce32d8b5 Mon Sep 17 00:00:00 2001 From: wxl Date: Thu, 11 Mar 2021 10:35:53 +0800 Subject: [PATCH 4/4] add force infershape for some op --- ge/hybrid/executor/worker/shape_inference_engine.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ge/hybrid/executor/worker/shape_inference_engine.cc b/ge/hybrid/executor/worker/shape_inference_engine.cc index 0a7f3985..27919589 100755 --- a/ge/hybrid/executor/worker/shape_inference_engine.cc +++ b/ge/hybrid/executor/worker/shape_inference_engine.cc @@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { // Wait for "const input nodes" if node's shape inference function requires any. // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); - if (node_item.is_output_shape_static && node_item.is_need_force_infershape) { + if (node_item.is_output_shape_static && !node_item.is_need_force_infershape) { return SUCCESS; }