From 381c2b7e0c3c46ec1c1cb858c5f900c26f0ac95e Mon Sep 17 00:00:00 2001 From: wxl Date: Wed, 14 Oct 2020 10:51:50 +0800 Subject: [PATCH] Description: ir build inferface optimize.Shape format dtype modified to get from tensor_desc --- inc/graph/ge_local_context.h | 5 +++++ src/common/graph/option/ge_local_context.cc | 19 +++++++++++++++++++ src/ge/ir_build/ge_ir_build.cc | 17 +++++++++++------ 3 files changed, 35 insertions(+), 6 deletions(-) diff --git a/inc/graph/ge_local_context.h b/inc/graph/ge_local_context.h index b47098fb..58efe37b 100644 --- a/inc/graph/ge_local_context.h +++ b/inc/graph/ge_local_context.h @@ -33,6 +33,11 @@ class GEThreadLocalContext { void SetSessionOption(map options_map); void SetGlobalOption(map options_map); + map GetAllGraphOptions() const; + map GetAllSessionOptions() const; + map GetAllGlobalOptions() const; + map GetAllOptions() const; + private: map graph_options_; map session_options_; diff --git a/src/common/graph/option/ge_local_context.cc b/src/common/graph/option/ge_local_context.cc index 82b1cb01..b2a5da9c 100644 --- a/src/common/graph/option/ge_local_context.cc +++ b/src/common/graph/option/ge_local_context.cc @@ -57,4 +57,23 @@ void GEThreadLocalContext::SetGraphOption(map options_map) graph_options_.clear(); graph_options_ = std::move(options_map); } + +map GEThreadLocalContext::GetAllGraphOptions() const { + return graph_options_; +} + +map GEThreadLocalContext::GetAllSessionOptions() const { + return session_options_; +} + +map GEThreadLocalContext::GetAllGlobalOptions() const { + return global_options_; +} + +map GEThreadLocalContext::GetAllOptions() const { + map options_all; + options_all.insert(global_options_.begin(), global_options_.end()); + options_all.insert(session_options_.begin(), session_options_.end()); + options_all.insert(graph_options_.begin(), graph_options_.end()); +} } // namespace ge diff --git a/src/ge/ir_build/ge_ir_build.cc b/src/ge/ir_build/ge_ir_build.cc index 86b304c1..055cce7e 100644 --- a/src/ge/ir_build/ge_ir_build.cc +++ b/src/ge/ir_build/ge_ir_build.cc @@ -167,7 +167,7 @@ class Impl { graphStatus InitDomiOmgContext(const string &input_shape, const string &input_format, const string &net_format, bool is_dynamic_input); void SetRtSocVersion(); - + void UpdateThreadContext(); public: ge::GeGenerator generator_; std::map options_; @@ -220,8 +220,6 @@ graphStatus Impl::Init(const std::map &options) { return ret; } - GetThreadLocalContext().SetGlobalOption(GetMutableGlobalOptions()); - GetThreadLocalContext().SetGraphOption(options_); std::string build_mode = (options_.find(BUILD_MODE) == options_.end() || options_[BUILD_MODE] == BUILD_MODE_NORMAL) ? "" : options_[BUILD_MODE]; @@ -276,7 +274,7 @@ graphStatus Impl::Init(const std::map &options) { ge::PrintOptionMap(options_, "ge option"); SetRtSocVersion(); - + UpdateThreadContext(); // 3. init generator with options_ ret = generator_.Initialize(options_, omg_context_); if (ret != GRAPH_SUCCESS) { @@ -300,6 +298,11 @@ void Impl::SetRtSocVersion() { } } +void Impl::UpdateThreadContext() { + GetThreadLocalContext().SetGlobalOption(GetMutableGlobalOptions()); + GetThreadLocalContext().SetGraphOption(options_); +} + graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector &inputs) { auto compute_graph = ge::GraphUtils::GetComputeGraph(graph); GE_CHECK_NOTNULL(compute_graph); @@ -323,13 +326,15 @@ graphStatus Impl::CreateInputsForIRBuild(const ge::Graph &graph, vector