diff --git a/ge/hybrid/node_executor/node_executor.cc b/ge/hybrid/node_executor/node_executor.cc index 04225557..5f3d6e45 100755 --- a/ge/hybrid/node_executor/node_executor.cc +++ b/ge/hybrid/node_executor/node_executor.cc @@ -58,8 +58,8 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, } Status NodeExecutorManager::EnsureInitialized() { + GE_CHK_STATUS_RET(InitializeExecutors()); std::lock_guard lk(mu_); - ++ref_count_; if (initialized_) { return SUCCESS; } @@ -115,14 +115,17 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node return it->second; } -Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) { +Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const { auto executor_type = ResolveExecutorType(node); - GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast(executor_type)); const auto it = executors_.find(executor_type); if (it == executors_.end()) { - return GetOrCreateExecutor(executor_type, executor); + REPORT_INNER_ERROR("E19999", "Failed to get executor by type: %d.", static_cast(executor_type)); + GELOGE(INTERNAL_ERROR, "[Check][ExecutorType]Failed to get executor by type: %d.", + static_cast(executor_type)); + return INTERNAL_ERROR; } + GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast(executor_type)); *executor = it->second.get(); return SUCCESS; } @@ -175,50 +178,51 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node); } -Status NodeExecutorManager::GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **out_executor) { +Status NodeExecutorManager::InitializeExecutors() { std::lock_guard lk(mu_); - const auto executor_it = executors_.find(executor_type); - if (executor_it != executors_.end()) { - *out_executor = executor_it->second.get(); + if (executor_initialized_) { + ++ref_count_; + GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_); return SUCCESS; } - GELOGI("Start to Initialize NodeExecutor, type = %d", static_cast(executor_type)); - auto it = builders_.find(executor_type); - if (it == builders_.end()) { - REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", - static_cast(executor_type)); - GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for executor type = %d", static_cast(executor_type)); - return INTERNAL_ERROR; - } + GELOGI("Start to Initialize NodeExecutors"); + for (auto &it : builders_) { + auto engine_type = it.first; + auto build_fn = it.second; + GE_CHECK_NOTNULL(build_fn); + auto executor = std::unique_ptr(build_fn()); + if (executor == nullptr) { + REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for engine type = %d", + static_cast(engine_type)); + GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast(engine_type)); + return INTERNAL_ERROR; + } - auto build_fn = it->second; - GE_CHECK_NOTNULL(build_fn); - auto executor = std::unique_ptr(build_fn()); - if (executor == nullptr) { - REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", - static_cast(executor_type)); - GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast(executor_type)); - return INTERNAL_ERROR; - } + GELOGD("Executor of engine type = %d was created successfully", static_cast(engine_type)); + auto ret = executor->Initialize(); + if (ret != SUCCESS) { + REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast(engine_type)); + GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast(engine_type)); + for (auto &executor_it : executors_) { + executor_it.second->Finalize(); + } + executors_.clear(); + return ret; + } - GELOGD("Executor of engine type = %d was created successfully", static_cast(executor_type)); - auto ret = executor->Initialize(); - if (ret != SUCCESS) { - REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast(executor_type)); - GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast(executor_type)); - return ret; + executors_.emplace(engine_type, std::move(executor)); } - *out_executor = executor.get(); - executors_.emplace(executor_type, std::move(executor)); - GELOGI("Initializing NodeExecutor successfully, type = %d", static_cast(executor_type)); + ++ref_count_; + executor_initialized_ = true; + GELOGI("Initializing NodeExecutors successfully."); return SUCCESS; } void NodeExecutorManager::FinalizeExecutors() { std::lock_guard lk(mu_); - if (ref_count_ <= 0) { + if (!executor_initialized_) { GELOGD("No need for finalizing for not initialized."); return; } @@ -233,6 +237,7 @@ void NodeExecutorManager::FinalizeExecutors() { it.second->Finalize(); } executors_.clear(); + executor_initialized_ = false; GELOGD("Done invoking Finalize successfully."); } diff --git a/ge/hybrid/node_executor/node_executor.h b/ge/hybrid/node_executor/node_executor.h index 97c9cee9..fffd4e7d 100644 --- a/ge/hybrid/node_executor/node_executor.h +++ b/ge/hybrid/node_executor/node_executor.h @@ -179,6 +179,8 @@ class NodeExecutorManager { */ Status EnsureInitialized(); + Status InitializeExecutors(); + void FinalizeExecutors(); /** @@ -194,7 +196,7 @@ class NodeExecutorManager { * @param executor executor * @return SUCCESS on success, error code otherwise */ - Status GetExecutor(Node &node, const NodeExecutor **executor); + Status GetExecutor(Node &node, const NodeExecutor **executor) const; /** * Resolve executor type by node @@ -204,13 +206,12 @@ class NodeExecutorManager { ExecutorType ResolveExecutorType(Node &node) const; private: - Status GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **executor); - std::map> executors_; std::map> builders_; std::map engine_mapping_; std::mutex mu_; bool initialized_ = false; + bool executor_initialized_ = false; int ref_count_ = 0; }; diff --git a/tests/ut/ge/CMakeLists.txt b/tests/ut/ge/CMakeLists.txt index 631e18f8..8b024820 100755 --- a/tests/ut/ge/CMakeLists.txt +++ b/tests/ut/ge/CMakeLists.txt @@ -839,7 +839,6 @@ set(HYBRID_TEST_FILES "hybrid/executor/subgraph_executor_unittest.cc" "hybrid/executor/worker/execution_engine_unittest.cc" "hybrid/model/hybrid_model_builder_unittest.cc" - "hybrid/node_executor/node_executor_unittest.cc" "hybrid/node_executor/rts/rts_node_task_unittest.cc" "hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" "hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" @@ -847,7 +846,6 @@ set(HYBRID_TEST_FILES "hybrid/executor/hybrid_model_async_executor_unittest.cc" "hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" "hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" - ) set(OTHERS_TEST_FILES diff --git a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc b/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc deleted file mode 100644 index 8a1240d3..00000000 --- a/tests/ut/ge/hybrid/node_executor/node_executor_unittest.cc +++ /dev/null @@ -1,103 +0,0 @@ -/** - * Copyright 2021 Huawei Technologies Co., Ltd - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -#include -#include -#include - -#define private public -#define protected public -#include "hybrid/node_executor/node_executor.h" -#undef protected -#undef private - -using namespace std; -using namespace testing; - -namespace ge { -using namespace hybrid; - -namespace { - bool finalized = false; -} - -class NodeExecutorTest : public testing::Test { - protected: - void SetUp() {} - void TearDown() { } -}; - -class FailureNodeExecutor : public NodeExecutor { - public: - Status Initialize() override { - return INTERNAL_ERROR; - } -}; - -class SuccessNodeExecutor : public NodeExecutor { - public: - Status Initialize() override { - initialized = true; - finalized = false; - return SUCCESS; - } - - Status Finalize() override { - finalized = true; - } - - bool initialized = false; -}; - -REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, FailureNodeExecutor); -REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, SuccessNodeExecutor); - -TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) { - auto &manager = NodeExecutorManager::GetInstance(); - const NodeExecutor *executor = nullptr; - Status ret = SUCCESS; - // no builder - ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::RESERVED, &executor); - ASSERT_EQ(ret, INTERNAL_ERROR); - // initialize failure - ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICORE, &executor); - ASSERT_EQ(ret, INTERNAL_ERROR); - ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); - ASSERT_EQ(ret, SUCCESS); - ASSERT_TRUE(executor != nullptr); - ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); - ASSERT_EQ(ret, SUCCESS); - ASSERT_TRUE(executor != nullptr); - ASSERT_TRUE(((SuccessNodeExecutor*)executor)->initialized); -} - -TEST_F(NodeExecutorTest, TestInitAndFinalize) { - auto &manager = NodeExecutorManager::GetInstance(); - manager.FinalizeExecutors(); - manager.EnsureInitialized(); - manager.EnsureInitialized(); - const NodeExecutor *executor = nullptr; - auto ret = manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &executor); - ASSERT_EQ(ret, SUCCESS); - ASSERT_TRUE(executor != nullptr); - ASSERT_TRUE(((SuccessNodeExecutor*)executor)->initialized); - manager.FinalizeExecutors(); - ASSERT_FALSE(manager.executors_.empty()); - manager.FinalizeExecutors(); - ASSERT_TRUE(manager.executors_.empty()); - ASSERT_TRUE(finalized); -} -} // namespace ge