Browse Source

add force infershape for some op

tags/v1.2.0
wxl 3 years ago
parent
commit
c94e0fbdc6
4 changed files with 16 additions and 1 deletions
  1. +1
    -1
      ge/hybrid/executor/worker/shape_inference_engine.cc
  2. +13
    -0
      ge/hybrid/model/hybrid_model_builder.cc
  3. +1
    -0
      ge/hybrid/model/hybrid_model_builder.h
  4. +1
    -0
      ge/hybrid/model/node_item.h

+ 1
- 1
ge/hybrid/executor/worker/shape_inference_engine.cc View File

@@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) {
// Wait for "const input nodes" if node's shape inference function requires any.
// Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution
GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state));
if (node_item.is_output_shape_static) {
if (node_item.is_output_shape_static && node_item.is_need_force_infershape) {
return SUCCESS;
}



+ 13
- 0
ge/hybrid/model/hybrid_model_builder.cc View File

@@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode";
const char *const kProfilingEndNode = "ProfilingEndNode";
const char *const kProfilingArNode = "ProfilingAllReduceNode";
const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE";
const char *const kForceInfershape = "_force_infershape_when_running";

Status SetOutputNameAttr(ComputeGraph &graph) {
vector<string> output_names;
@@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() {

Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) {
auto op_desc = node->GetOpDesc();
GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item),
"[%s] Failed to parse force_infershape node.",
node_item.NodeName().c_str());
vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends();
GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies),
"[%s] Failed to parse node dependencies.",
@@ -263,6 +267,15 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n
return SUCCESS;
}

Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) {
auto op_desc = node->GetOpDesc();
GE_CHECK_NOTNULL(op_desc);
// not care result, if no this attr, stand for the op does not need force infershape
(void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape);
GELOGD("node [%s] is need do infershape , flag is %d", node_item.is_need_force_infershape);
return SUCCESS;
}

Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) {
std::set<NodePtr> dependent_input_nodes;
auto &ge_node = node_item.node;


+ 1
- 0
ge/hybrid/model/hybrid_model_builder.h View File

@@ -62,6 +62,7 @@ class HybridModelBuilder {
Status IdentifySameInputs(NodeItem &node_item);
Status BuildNodeItem(const NodePtr &node, NodeItem &node_item);
Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item);
Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item);
Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies);
Status ParseDependentForFusedSubgraph(NodeItem &node_item);
Status IndexTaskDefs();


+ 1
- 0
ge/hybrid/model/node_item.h View File

@@ -83,6 +83,7 @@ struct NodeItem {
bool has_observer = false;
bool has_optional_inputs = false;
bool is_output_shape_static = true;
bool is_need_force_infershape = false;
UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE;
std::string node_name;
std::string node_type;


Loading…
Cancel
Save