| @@ -35,6 +35,7 @@ | |||||
| #include "ir_build/atc_ir_common.h" | #include "ir_build/atc_ir_common.h" | ||||
| #include "model/ge_model.h" | #include "model/ge_model.h" | ||||
| #include "graph/shape_refiner.h" | #include "graph/shape_refiner.h" | ||||
| #include "graph/opsproto_manager.h" | |||||
| using std::string; | using std::string; | ||||
| using namespace std; | using namespace std; | ||||
| @@ -109,6 +110,27 @@ static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global | |||||
| return GRAPH_SUCCESS; | return GRAPH_SUCCESS; | ||||
| } | } | ||||
| static void GetOpsProtoPath(string &opsproto_path) { | |||||
| GELOGI("Start to get ops proto path schedule."); | |||||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||||
| if (path_env != nullptr) { | |||||
| string path = path_env; | |||||
| string file_path = RealPath(path.c_str()); | |||||
| if (file_path.empty()) { | |||||
| GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | |||||
| return; | |||||
| } | |||||
| opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||||
| GELOGI("Get opsproto so path from env : %s", path.c_str()); | |||||
| return; | |||||
| } | |||||
| string path_base = PluginManager::GetPath(); | |||||
| GELOGI("path_base is %s", path_base.c_str()); | |||||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||||
| } | |||||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | ||||
| GELOGD("Enter aclgrphInitialize start!"); | GELOGD("Enter aclgrphInitialize start!"); | ||||
| // check global options | // check global options | ||||
| @@ -172,6 +194,7 @@ class Impl { | |||||
| bool is_dynamic_input); | bool is_dynamic_input); | ||||
| void SetRtSocVersion(); | void SetRtSocVersion(); | ||||
| void UpdateThreadContext(); | void UpdateThreadContext(); | ||||
| void LoadOpsProto();s | |||||
| public: | public: | ||||
| ge::GeGenerator generator_; | ge::GeGenerator generator_; | ||||
| std::map<std::string, std::string> options_; | std::map<std::string, std::string> options_; | ||||
| @@ -313,6 +336,16 @@ void Impl::UpdateThreadContext() { | |||||
| GetThreadLocalContext().SetGraphOption(options_); | GetThreadLocalContext().SetGraphOption(options_); | ||||
| } | } | ||||
| void Impl::LoadOpsProto() { | |||||
| string opsproto_path; | |||||
| GetOpsProtoPath(opsproto_path); | |||||
| GELOGI("Get opsproto path is %s", opsproto_path.c_str()); | |||||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | |||||
| map<string, string> option_tmp; | |||||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||||
| (void)manager->Initialize(option_tmp); | |||||
| } | |||||
| graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs) { | graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs) { | ||||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | ||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| @@ -440,12 +473,17 @@ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *pat | |||||
| graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { | graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { | ||||
| Impl builder; | Impl builder; | ||||
| std::map<std::string, std::string> options = {}; | |||||
| builder.Init(options); | |||||
| builder.LoadOpsProto(); | |||||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | auto compute_graph = GraphUtils::GetComputeGraph(graph); | ||||
| GE_CHECK_NOTNULL(compute_graph); | GE_CHECK_NOTNULL(compute_graph); | ||||
| auto root_graph = compute_graph->GetParentGraph(); | |||||
| if (root_graph != nullptr) { | |||||
| GELOGE(GRAPH_PARAM_INVALID, "Input param should not be subgraph"); | |||||
| return GRAPH_PARAM_INVALID; | |||||
| } | |||||
| auto ret = compute_graph->InferOriginFormat(); | auto ret = compute_graph->InferOriginFormat(); | ||||
| if (ret != GRAPH_SUCCESS) { | if (ret != GRAPH_SUCCESS) { | ||||
| GELOGE(ret, "Acl InferOriginFormat failed."); | GELOGE(ret, "Acl InferOriginFormat failed."); | ||||