Browse Source

!1224 add process for some op need infershape when running

From: @wan_xuelei
Reviewed-by: @xchu42,@wqtshg
Signed-off-by: @wqtshg
tags/v1.2.0
mindspore-ci-bot Gitee 3 years ago
parent
commit
9f027029d5
7 changed files with 34 additions and 1 deletion
  1. +1
    -0
      ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc
  2. +1
    -1
      ge/hybrid/executor/worker/shape_inference_engine.cc
  3. +15
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  4. +1
    -0
      ge/hybrid/model/hybrid_model_builder.h
  5. +1
    -0
      ge/hybrid/model/node_item.h
  6. +1
    -0
      ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc
  7. +14
    -0
      tests/ut/ge/hybrid/ge_hybrid_unittest.cc

+ 1
- 0
ge/ge_local_engine/ops_kernel_store/op/ge_deleted_op.cc View File

@@ -38,6 +38,7 @@ REGISTER_OP_CREATOR(ExpandDims, GeDeletedOp);
REGISTER_OP_CREATOR(Reshape, GeDeletedOp);
REGISTER_OP_CREATOR(ReFormat, GeDeletedOp);
REGISTER_OP_CREATOR(Squeeze, GeDeletedOp);
REGISTER_OP_CREATOR(Unsqueeze, GeDeletedOp);
REGISTER_OP_CREATOR(Size, GeDeletedOp);
REGISTER_OP_CREATOR(Shape, GeDeletedOp);
REGISTER_OP_CREATOR(ShapeN, GeDeletedOp);


+ 1
- 1
ge/hybrid/executor/worker/shape_inference_engine.cc View File

@@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
// Wait for "const input nodes" if node's shape inference function requires any.
// Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution
GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state));
if (node_item.is_output_shape_static) {
if (node_item.is_output_shape_static && !node_item.is_need_force_infershape) {
return SUCCESS;
}



+ 15
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode";
const char *const kProfilingEndNode = "ProfilingEndNode";
const char *const kProfilingArNode = "ProfilingAllReduceNode";
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE";
const char *const kForceInfershape = "_force_infershape_when_running";

Status SetOutputNameAttr(ComputeGraph &graph) {
vector<string> output_names;
@@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() {

Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) {
auto op_desc = node->GetOpDesc();
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item),
"[%s] Failed to parse force_infershape node.",
node_item.NodeName().c_str());
vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends();
GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies),
"[%s] Failed to parse node dependencies.",
@@ -263,6 +267,17 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
return SUCCESS;
}

// Reads the "_force_infershape_when_running" attribute from the node's OpDesc and
// records it on the NodeItem, so the shape-inference engine re-runs InferShape at
// execution time even when the output shape is statically known.
// @param node       node whose OpDesc is inspected (must carry a non-null OpDesc)
// @param node_item  receives the parsed flag in is_need_force_infershape
// @return SUCCESS always (a missing attribute simply means no force infershape)
Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) {
  auto op_desc = node->GetOpDesc();
  GE_CHECK_NOTNULL(op_desc);
  // Result deliberately ignored: absence of the attribute means the op does not
  // need force infershape, and is_need_force_infershape already defaults to false.
  (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
  GELOGD("node [%s] force infershape flag is %d",
         op_desc->GetName().c_str(),
         node_item.is_need_force_infershape);
  return SUCCESS;
}

Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
std::set<NodePtr> dependent_input_nodes;
auto &ge_node = node_item.node;


+ 1
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -62,6 +62,7 @@ class HybridModelBuilder {
Status IdentifySameInputs(NodeItem &node_item);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item);
Status IndexTaskDefs();


+ 1
- 0
ge/hybrid/model/node_item.h View File

@@ -83,6 +83,7 @@ struct NodeItem {
bool has_observer = false;
bool has_optional_inputs = false;
bool is_output_shape_static = true;
bool is_need_force_infershape = false;
UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE;
std::string node_name;
std::string node_type;


+ 1
- 0
ge/hybrid/node_executor/ge_local/ge_local_node_executor.cc View File

@@ -33,6 +33,7 @@ const std::map<std::string, std::vector<uint32_t>>
{RESHAPE, {}},
{EXPANDDIMS, {}},
{SQUEEZE, {}},
{UNSQUEEZE, {}},
{BROADCASTGRADIENTARGS, {}}
};



+ 14
- 0
tests/ut/ge/hybrid/ge_hybrid_unittest.cc View File

@@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) {
ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR);
}

TEST_F(UtestGeHybrid, parse_force_infershape_nodes) {
const char *const kForceInfershape = "_force_infershape_when_running";
auto graph = make_shared<ComputeGraph>("graph");
OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D");
ge::AttrUtils::SetBool(op_desc, kForceInfershape, true);
auto node = graph->AddNode(op_desc);
std::unique_ptr<NodeItem> new_node;
NodeItem::Create(node, new_node);
GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph);
HybridModel hybrid_model(ge_root_model);
HybridModelBuilder hybrid_model_builder(hybrid_model);
ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS);
}

TEST_F(UtestGeHybrid, index_taskdefs_success) {
// build aicore task
domi::ModelTaskDef model_task_def;


Loading…
Cancel
Save