| @@ -36,6 +36,9 @@ | |||||
| #include "model/ge_model.h" | #include "model/ge_model.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/opsproto_manager.h" | #include "graph/opsproto_manager.h" | ||||
| #include "inc/pass_manager.h" | |||||
| #include "graph/passes/net_output_pass.h" | |||||
| #include "graph/passes/data_pass.h" | |||||
| using std::string; | using std::string; | ||||
| using namespace std; | using namespace std; | ||||
| @@ -233,6 +236,7 @@ class Impl { | |||||
| ModelBufferData &ge_models); | ModelBufferData &ge_models); | ||||
| graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, | ||||
| bool is_dynamic_input); | bool is_dynamic_input); | ||||
| static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph); | |||||
| void SetRtSocVersion(); | void SetRtSocVersion(); | ||||
| void UpdateThreadContext(); | void UpdateThreadContext(); | ||||
| void LoadOpsProto(); | void LoadOpsProto(); | ||||
| @@ -243,6 +247,22 @@ class Impl { | |||||
| OmgContext omg_context_; | OmgContext omg_context_; | ||||
| }; | }; | ||||
| static graphStatus InferShapePrepare(const ComputeGraphPtr &compute_graph) { | |||||
| GE_CHECK_NOTNULL(compute_graph); | |||||
| PassManager prepare_infershape; | |||||
| prepare_infershape.AddPass("PrepareNetoutput", new(std::nothrow) NetOutputPass); | |||||
| prepare_infershape.AddPass("PrepareSubGraphReflection", new (std::nothrow) DataPass); | |||||
| auto ret = prepare_infershape.Run(compute_graph); | |||||
| if ((ret != SUCCESS) && (ret != NOT_CHANGED)) { | |||||
| GELOGE(ret, "Prepair for infershape failed, ret:%d", ret); | |||||
| return ret; | |||||
| } | |||||
| GELOGD("Prepair for infershape success!"); | |||||
| return GRAPH_SUCCESS; | |||||
| } | |||||
| graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | graphStatus Impl::UpdateDataOpAttr(const Graph &graph) { | ||||
| GELOGD("Enter Update Data Attr Process!"); | GELOGD("Enter Update Data Attr Process!"); | ||||
| if (options_.find(kInputShape) == options_.end()) { | if (options_.find(kInputShape) == options_.end()) { | ||||
| @@ -591,7 +611,12 @@ graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { | |||||
| return GRAPH_PARAM_INVALID; | return GRAPH_PARAM_INVALID; | ||||
| } | } | ||||
| auto ret = compute_graph->TopologicalSorting(); | |||||
| auto ret = Impl::InferShapePrepare(root_graph); | |||||
| if (ret != GRAPH_SUCCESS) { | |||||
| return ret; | |||||
| } | |||||
| ret = compute_graph->TopologicalSorting(); | |||||
| if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
| GELOGE(ret, "Acl topo logical sort failed."); | GELOGE(ret, "Acl topo logical sort failed."); | ||||
| return ret; | return ret; | ||||