From: @wan_xuelei Reviewed-by: @xchu42,@wqtshg Signed-off-by: @wqtshg tags/v1.2.0
| @@ -38,6 +38,7 @@ REGISTER_OP_CREATOR(ExpandDims, GeDeletedOp); | |||||
| REGISTER_OP_CREATOR(Reshape, GeDeletedOp); | REGISTER_OP_CREATOR(Reshape, GeDeletedOp); | ||||
| REGISTER_OP_CREATOR(ReFormat, GeDeletedOp); | REGISTER_OP_CREATOR(ReFormat, GeDeletedOp); | ||||
| REGISTER_OP_CREATOR(Squeeze, GeDeletedOp); | REGISTER_OP_CREATOR(Squeeze, GeDeletedOp); | ||||
| REGISTER_OP_CREATOR(Unsqueeze, GeDeletedOp); | |||||
| REGISTER_OP_CREATOR(Size, GeDeletedOp); | REGISTER_OP_CREATOR(Size, GeDeletedOp); | ||||
| REGISTER_OP_CREATOR(Shape, GeDeletedOp); | REGISTER_OP_CREATOR(Shape, GeDeletedOp); | ||||
| REGISTER_OP_CREATOR(ShapeN, GeDeletedOp); | REGISTER_OP_CREATOR(ShapeN, GeDeletedOp); | ||||
| @@ -41,7 +41,7 @@ Status ShapeInferenceEngine::InferShape(NodeState &node_state) { | |||||
| // Wait for "const input nodes" if node's shape inference function requires any. | // Wait for "const input nodes" if node's shape inference function requires any. | ||||
| // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution | // Even if output shape is static, there are cases that the const-input will be used in OpTiling and Execution | ||||
| GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); | GE_CHK_STATUS_RET_NOLOG(AwaitDependentNodes(node_state)); | ||||
| if (node_item.is_output_shape_static) { | |||||
| if (node_item.is_output_shape_static && !node_item.is_need_force_infershape) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -50,6 +50,7 @@ const char *const kProfilingBpNode = "ProfilingBpNode"; | |||||
| const char *const kProfilingEndNode = "ProfilingEndNode"; | const char *const kProfilingEndNode = "ProfilingEndNode"; | ||||
| const char *const kProfilingArNode = "ProfilingAllReduceNode"; | const char *const kProfilingArNode = "ProfilingAllReduceNode"; | ||||
| const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | const char *const kEngineNameRts = "DNN_VM_RTS_OP_STORE"; | ||||
| const char *const kForceInfershape = "_force_infershape_when_running"; | |||||
| Status SetOutputNameAttr(ComputeGraph &graph) { | Status SetOutputNameAttr(ComputeGraph &graph) { | ||||
| vector<string> output_names; | vector<string> output_names; | ||||
| @@ -171,6 +172,9 @@ Status HybridModelBuilder::ValidateParams() { | |||||
| Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | Status HybridModelBuilder::BuildNodeItem(const NodePtr &node, NodeItem &node_item) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| GE_CHK_STATUS_RET(ParseForceInfershapeNodes(node, node_item), | |||||
| "[%s] Failed to parse force_infershape node.", | |||||
| node_item.NodeName().c_str()); | |||||
| vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends(); | vector<string> dependencies = node->GetOpDesc()->GetOpInferDepends(); | ||||
| GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), | GE_CHK_STATUS_RET(ParseDependentInputNodes(node_item, dependencies), | ||||
| "[%s] Failed to parse node dependencies.", | "[%s] Failed to parse node dependencies.", | ||||
| @@ -263,6 +267,17 @@ Status HybridModelBuilder::GetOrCreateNodeItem(const NodePtr &node, NodeItem **n | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status HybridModelBuilder::ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item) { | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| GE_CHECK_NOTNULL(op_desc); | |||||
| // The return value is intentionally ignored: a missing attribute means the op does not need forced shape inference. | |||||
| (void)AttrUtils::GetBool(op_desc, kForceInfershape, node_item.is_need_force_infershape); | |||||
| GELOGD("node [%s] is need do infershape , flag is %d", | |||||
| op_desc->GetName().c_str(), | |||||
| node_item.is_need_force_infershape); | |||||
| return SUCCESS; | |||||
| } | |||||
| Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { | Status HybridModelBuilder::ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies) { | ||||
| std::set<NodePtr> dependent_input_nodes; | std::set<NodePtr> dependent_input_nodes; | ||||
| auto &ge_node = node_item.node; | auto &ge_node = node_item.node; | ||||
| @@ -62,6 +62,7 @@ class HybridModelBuilder { | |||||
| Status IdentifySameInputs(NodeItem &node_item); | Status IdentifySameInputs(NodeItem &node_item); | ||||
| Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | Status BuildNodeItem(const NodePtr &node, NodeItem &node_item); | ||||
| Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | Status GetOrCreateNodeItem(const NodePtr &node, NodeItem **node_item); | ||||
| Status ParseForceInfershapeNodes(const NodePtr &node, NodeItem &node_item); | |||||
| Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | Status ParseDependentInputNodes(NodeItem &node_item, const std::vector<string> &dependencies); | ||||
| Status ParseDependentForFusedSubgraph(NodeItem &node_item); | Status ParseDependentForFusedSubgraph(NodeItem &node_item); | ||||
| Status IndexTaskDefs(); | Status IndexTaskDefs(); | ||||
| @@ -83,6 +83,7 @@ struct NodeItem { | |||||
| bool has_observer = false; | bool has_observer = false; | ||||
| bool has_optional_inputs = false; | bool has_optional_inputs = false; | ||||
| bool is_output_shape_static = true; | bool is_output_shape_static = true; | ||||
| bool is_need_force_infershape = false; | |||||
| UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; | UnknowShapeOpType shape_inference_type = DEPEND_IN_SHAPE; | ||||
| std::string node_name; | std::string node_name; | ||||
| std::string node_type; | std::string node_type; | ||||
| @@ -33,6 +33,7 @@ const std::map<std::string, std::vector<uint32_t>> | |||||
| {RESHAPE, {}}, | {RESHAPE, {}}, | ||||
| {EXPANDDIMS, {}}, | {EXPANDDIMS, {}}, | ||||
| {SQUEEZE, {}}, | {SQUEEZE, {}}, | ||||
| {UNSQUEEZE, {}}, | |||||
| {BROADCASTGRADIENTARGS, {}} | {BROADCASTGRADIENTARGS, {}} | ||||
| }; | }; | ||||
| @@ -152,6 +152,20 @@ TEST_F(UtestGeHybrid, index_taskdefs_failed) { | |||||
| ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ASSERT_EQ(hybrid_model_builder.IndexTaskDefs(graph, ge_model), INTERNAL_ERROR); | ||||
| } | } | ||||
| TEST_F(UtestGeHybrid, parse_force_infershape_nodes) { | |||||
| const char *const kForceInfershape = "_force_infershape_when_running"; | |||||
| auto graph = make_shared<ComputeGraph>("graph"); | |||||
| OpDescPtr op_desc = CreateOpDesc("Conv2D", "Conv2D"); | |||||
| ge::AttrUtils::SetBool(op_desc, kForceInfershape, true); | |||||
| auto node = graph->AddNode(op_desc); | |||||
| std::unique_ptr<NodeItem> new_node; | |||||
| NodeItem::Create(node, new_node); | |||||
| GeRootModelPtr ge_root_model = make_shared<GeRootModel>(graph); | |||||
| HybridModel hybrid_model(ge_root_model); | |||||
| HybridModelBuilder hybrid_model_builder(hybrid_model); | |||||
| ASSERT_EQ(hybrid_model_builder.ParseForceInfershapeNodes(node, *new_node), SUCCESS); | |||||
| } | |||||
| TEST_F(UtestGeHybrid, index_taskdefs_success) { | TEST_F(UtestGeHybrid, index_taskdefs_success) { | ||||
| // build aicore task | // build aicore task | ||||
| domi::ModelTaskDef model_task_def; | domi::ModelTaskDef model_task_def; | ||||