| @@ -35,6 +35,7 @@ | |||
| #include "ir_build/atc_ir_common.h" | |||
| #include "model/ge_model.h" | |||
| #include "graph/shape_refiner.h" | |||
| #include "graph/opsproto_manager.h" | |||
| using std::string; | |||
| using namespace std; | |||
| @@ -109,6 +110,27 @@ static graphStatus CheckGlobalOptions(std::map<std::string, std::string> &global | |||
| return GRAPH_SUCCESS; | |||
| } | |||
| static void GetOpsProtoPath(string &opsproto_path) { | |||
| GELOGI("Start to get ops proto path schedule."); | |||
| const char *path_env = std::getenv("ASCEND_OPP_PATH"); | |||
| if (path_env != nullptr) { | |||
| string path = path_env; | |||
| string file_path = RealPath(path.c_str()); | |||
| if (file_path.empty()) { | |||
| GELOGE(FAILED, "File path %s is invalid.", path.c_str()); | |||
| return; | |||
| } | |||
| opsproto_path = (path + "/op_proto/custom/" + ":") + (path + "/op_proto/built-in/"); | |||
| GELOGI("Get opsproto so path from env : %s", path.c_str()); | |||
| return; | |||
| } | |||
| string path_base = PluginManager::GetPath(); | |||
| GELOGI("path_base is %s", path_base.c_str()); | |||
| path_base = path_base.substr(0, path_base.rfind('/')); | |||
| path_base = path_base.substr(0, path_base.rfind('/') + 1); | |||
| opsproto_path = (path_base + "ops/op_proto/custom/" + ":") + (path_base + "ops/op_proto/built-in/"); | |||
| } | |||
| graphStatus aclgrphBuildInitialize(std::map<std::string, std::string> global_options) { | |||
| GELOGD("Enter aclgrphInitialize start!"); | |||
| // check global options | |||
| @@ -172,6 +194,7 @@ class Impl { | |||
| bool is_dynamic_input); | |||
| void SetRtSocVersion(); | |||
| void UpdateThreadContext(); | |||
| void LoadOpsProto();s | |||
| public: | |||
| ge::GeGenerator generator_; | |||
| std::map<std::string, std::string> options_; | |||
| @@ -313,6 +336,16 @@ void Impl::UpdateThreadContext() { | |||
| GetThreadLocalContext().SetGraphOption(options_); | |||
| } | |||
| void Impl::LoadOpsProto() { | |||
| string opsproto_path; | |||
| GetOpsProtoPath(opsproto_path); | |||
| GELOGI("Get opsproto path is %s", opsproto_path.c_str()); | |||
| OpsProtoManager *manager = OpsProtoManager::Instance(); | |||
| map<string, string> option_tmp; | |||
| option_tmp.emplace(std::pair<string, string>(string("ge.opsProtoLibPath"), opsproto_path)); | |||
| (void)manager->Initialize(option_tmp); | |||
| } | |||
| graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector<ge::GeTensor> &inputs) { | |||
| auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| @@ -440,12 +473,17 @@ graphStatus aclgrphGetIRVersion(int *major_version, int *minor_version, int *pat | |||
| graphStatus aclgrphInferShapeAndType(ge::Graph &graph) { | |||
| Impl builder; | |||
| std::map<std::string, std::string> options = {}; | |||
| builder.Init(options); | |||
| builder.LoadOpsProto(); | |||
| auto compute_graph = GraphUtils::GetComputeGraph(graph); | |||
| GE_CHECK_NOTNULL(compute_graph); | |||
| auto root_graph = compute_graph->GetParentGraph(); | |||
| if (root_graph != nullptr) { | |||
| GELOGE(GRAPH_PARAM_INVALID, "Input param should not be subgraph"); | |||
| return GRAPH_PARAM_INVALID; | |||
| } | |||
| auto ret = compute_graph->InferOriginFormat(); | |||
| if (ret != GRAPH_SUCCESS) { | |||
| GELOGE(ret, "Acl InferOriginFormat failed."); | |||