From: @xchu42 Reviewed-by: @ji_chen Signed-off-by: @ji_chen tags/v1.5.1 | |||
@@ -1227,6 +1227,28 @@ Status HybridModelBuilder::LoadGeModel(ComputeGraph &sub_graph, const GeModelPtr | |||
hybrid_model_.known_shape_sub_models_.emplace(parent_node, ge_model); | |||
} | |||
GE_CHK_STATUS_RET_NOLOG(InitHcclExecutorOnDemand(ge_model)); | |||
return SUCCESS; | |||
} | |||
/**
 * Create the HCCL node executor if the given model contains at least one HCCL task.
 *
 * Does nothing when the HCCL executor already exists, or when no task in the
 * model has type RT_MODEL_TASK_HCCL.
 *
 * @param ge_model model whose task definitions are scanned
 * @return SUCCESS when no executor is needed or it was obtained/created;
 *         the status of the failed null check or of GetOrCreateExecutor otherwise
 */
Status HybridModelBuilder::InitHcclExecutorOnDemand(const GeModelPtr &ge_model) {
  auto &executor_manager = NodeExecutorManager::GetInstance();
  if (executor_manager.IsExecutorInitialized(NodeExecutorManager::ExecutorType::HCCL)) {
    return SUCCESS;
  }

  // Fail with a status instead of dereferencing a null model / task definition.
  GE_CHECK_NOTNULL(ge_model);
  const auto model_task_def = ge_model->GetModelTaskDefPtr();
  GE_CHECK_NOTNULL(model_task_def);

  // HCCL tasks in known-shaped subgraph which resides in a dynamic root graph
  // still depends on the initialization of the HcclExecutor
  // Iterate the repeated field in place; copying it would duplicate every TaskDef.
  for (const auto &task_def : model_task_def->task()) {
    const auto task_type = static_cast<rtModelTaskType_t>(task_def.type());
    if (task_type != RT_MODEL_TASK_HCCL) {
      continue;
    }
    // Only the creation side effect matters here; the executor pointer itself is not used.
    const NodeExecutor *unused = nullptr;
    GE_CHK_STATUS_RET_NOLOG(
        executor_manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::HCCL, &unused));
    return SUCCESS;
  }
  return SUCCESS;
}
@@ -57,6 +57,7 @@ class HybridModelBuilder { | |||
Status ValidateParams(); | |||
Status LoadGraph(); | |||
Status LoadGeModel(ComputeGraph &graph, const GeModelPtr &ge_model); | |||
static Status InitHcclExecutorOnDemand(const GeModelPtr &ge_model); | |||
Status LoadTask(NodeItem &node_item); | |||
Status LoadTasks(); | |||
Status IdentifyVariableOutputs(NodeItem &node_item, const ComputeGraphPtr &subgraph); | |||
@@ -58,8 +58,8 @@ Status NodeExecutor::CompileTask(const HybridModel &model, const NodePtr &node, | |||
} | |||
Status NodeExecutorManager::EnsureInitialized() { | |||
GE_CHK_STATUS_RET(InitializeExecutors()); | |||
std::lock_guard<std::mutex> lk(mu_); | |||
++ref_count_; | |||
if (initialized_) { | |||
return SUCCESS; | |||
} | |||
@@ -115,17 +115,14 @@ NodeExecutorManager::ExecutorType NodeExecutorManager::ResolveExecutorType(Node | |||
return it->second; | |||
} | |||
Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) const { | |||
Status NodeExecutorManager::GetExecutor(Node &node, const NodeExecutor **executor) { | |||
auto executor_type = ResolveExecutorType(node); | |||
GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type)); | |||
const auto it = executors_.find(executor_type); | |||
if (it == executors_.end()) { | |||
REPORT_INNER_ERROR("E19999", "Failed to get executor by type: %d.", static_cast<int>(executor_type)); | |||
GELOGE(INTERNAL_ERROR, "[Check][ExecutorType]Failed to get executor by type: %d.", | |||
static_cast<int>(executor_type)); | |||
return INTERNAL_ERROR; | |||
return GetOrCreateExecutor(executor_type, executor); | |||
} | |||
GELOGD("[%s] Set node executor by type: %d.", node.GetName().c_str(), static_cast<int>(executor_type)); | |||
*executor = it->second.get(); | |||
return SUCCESS; | |||
} | |||
@@ -178,51 +175,55 @@ Status NodeExecutorManager::CalcOpRunningParam(Node &node) const { | |||
return OpsKernelBuilderManager::Instance().CalcOpRunningParam(node); | |||
} | |||
Status NodeExecutorManager::InitializeExecutors() { | |||
/**
 * Check whether an executor of the given type is already stored in executors_.
 * The lookup is performed while holding mu_.
 *
 * @param executor_type executor type to look up
 * @return true if executors_ contains an entry for executor_type
 */
bool NodeExecutorManager::IsExecutorInitialized(NodeExecutorManager::ExecutorType executor_type) {
  std::lock_guard<std::mutex> guard(mu_);
  const bool already_created = executors_.count(executor_type) > 0;
  return already_created;
}
Status NodeExecutorManager::GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **out_executor) { | |||
std::lock_guard<std::mutex> lk(mu_); | |||
if (executor_initialized_) { | |||
++ref_count_; | |||
GELOGI("Executor is already initialized. add ref count to [%d]", ref_count_); | |||
const auto executor_it = executors_.find(executor_type); | |||
if (executor_it != executors_.end()) { | |||
*out_executor = executor_it->second.get(); | |||
return SUCCESS; | |||
} | |||
GELOGI("Start to Initialize NodeExecutors"); | |||
for (auto &it : builders_) { | |||
auto engine_type = it.first; | |||
auto build_fn = it.second; | |||
GE_CHECK_NOTNULL(build_fn); | |||
auto executor = std::unique_ptr<NodeExecutor>(build_fn()); | |||
if (executor == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for engine type = %d", | |||
static_cast<int>(engine_type)); | |||
GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(engine_type)); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGI("Start to Initialize NodeExecutor, type = %d", static_cast<int>(executor_type)); | |||
auto it = builders_.find(executor_type); | |||
if (it == builders_.end()) { | |||
REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", | |||
static_cast<int>(executor_type)); | |||
GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for executor type = %d", static_cast<int>(executor_type)); | |||
return INTERNAL_ERROR; | |||
} | |||
GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(engine_type)); | |||
auto ret = executor->Initialize(); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(engine_type)); | |||
GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(engine_type)); | |||
for (auto &executor_it : executors_) { | |||
executor_it.second->Finalize(); | |||
} | |||
executors_.clear(); | |||
return ret; | |||
} | |||
auto build_fn = it->second; | |||
GE_CHECK_NOTNULL(build_fn); | |||
auto executor = std::unique_ptr<NodeExecutor>(build_fn()); | |||
if (executor == nullptr) { | |||
REPORT_CALL_ERROR("E19999", "Create NodeExecutor failed for executor type = %d", | |||
static_cast<int>(executor_type)); | |||
GELOGE(INTERNAL_ERROR, "[Create][NodeExecutor] failed for engine type = %d", static_cast<int>(executor_type)); | |||
return INTERNAL_ERROR; | |||
} | |||
executors_.emplace(engine_type, std::move(executor)); | |||
GELOGD("Executor of engine type = %d was created successfully", static_cast<int>(executor_type)); | |||
auto ret = executor->Initialize(); | |||
if (ret != SUCCESS) { | |||
REPORT_CALL_ERROR("E19999", "Initialize NodeExecutor failed for type = %d", static_cast<int>(executor_type)); | |||
GELOGE(ret, "[Initialize][NodeExecutor] failed for type = %d", static_cast<int>(executor_type)); | |||
return ret; | |||
} | |||
++ref_count_; | |||
executor_initialized_ = true; | |||
GELOGI("Initializing NodeExecutors successfully."); | |||
*out_executor = executor.get(); | |||
executors_.emplace(executor_type, std::move(executor)); | |||
GELOGI("Initializing NodeExecutor successfully, type = %d", static_cast<int>(executor_type)); | |||
return SUCCESS; | |||
} | |||
void NodeExecutorManager::FinalizeExecutors() { | |||
std::lock_guard<std::mutex> lk(mu_); | |||
if (!executor_initialized_) { | |||
if (ref_count_ <= 0) { | |||
GELOGD("No need for finalizing for not initialized."); | |||
return; | |||
} | |||
@@ -237,7 +238,6 @@ void NodeExecutorManager::FinalizeExecutors() { | |||
it.second->Finalize(); | |||
} | |||
executors_.clear(); | |||
executor_initialized_ = false; | |||
GELOGD("Done invoking Finalize successfully."); | |||
} | |||
@@ -179,8 +179,6 @@ class NodeExecutorManager { | |||
*/ | |||
Status EnsureInitialized(); | |||
Status InitializeExecutors(); | |||
void FinalizeExecutors(); | |||
/** | |||
@@ -196,7 +194,7 @@ class NodeExecutorManager { | |||
* @param executor executor | |||
* @return SUCCESS on success, error code otherwise | |||
*/ | |||
Status GetExecutor(Node &node, const NodeExecutor **executor) const; | |||
Status GetExecutor(Node &node, const NodeExecutor **executor); | |||
/** | |||
* Resolve executor type by node | |||
@@ -205,13 +203,16 @@ class NodeExecutorManager { | |||
*/ | |||
ExecutorType ResolveExecutorType(Node &node) const; | |||
Status GetOrCreateExecutor(ExecutorType executor_type, const NodeExecutor **executor); | |||
bool IsExecutorInitialized(ExecutorType executor_type); | |||
private: | |||
std::map<ExecutorType, std::unique_ptr<NodeExecutor>> executors_; | |||
std::map<ExecutorType, std::function<NodeExecutor *()>> builders_; | |||
std::map<std::string, NodeExecutorManager::ExecutorType> engine_mapping_; | |||
std::mutex mu_; | |||
bool initialized_ = false; | |||
bool executor_initialized_ = false; | |||
int ref_count_ = 0; | |||
}; | |||
@@ -851,6 +851,7 @@ set(HYBRID_TEST_FILES | |||
"hybrid/executor/subgraph_executor_unittest.cc" | |||
"hybrid/executor/worker/execution_engine_unittest.cc" | |||
"hybrid/model/hybrid_model_builder_unittest.cc" | |||
"hybrid/node_executor/node_executor_unittest.cc" | |||
"hybrid/node_executor/rts/rts_node_task_unittest.cc" | |||
"hybrid/node_executor/host_cpu/host_cpu_node_task_unittest.cc" | |||
"hybrid/node_executor/ge_local/ge_local_node_executor_unittest.cc" | |||
@@ -858,6 +859,7 @@ set(HYBRID_TEST_FILES | |||
"hybrid/executor/hybrid_model_async_executor_unittest.cc" | |||
"hybrid/executor/hybrid_model_pipeline_executor_unittest.cc" | |||
"hybrid/node_executor/aicore/aicore_task_compiler_unittest.cc" | |||
) | |||
set(OTHERS_TEST_FILES | |||
@@ -346,4 +346,31 @@ EXPECT_EQ(hybrid_model_builder.InitVariableTensors(), SUCCESS); | |||
EXPECT_EQ(hybrid_model_builder.hybrid_model_.variable_tensors_.size(), 1); | |||
HostMemManager::Instance().var_memory_base_map_.clear(); | |||
} | |||
// Walks InitHcclExecutorOnDemand through four situations: a model without HCCL tasks,
// an HCCL task with no registered builder, an HCCL task with a builder, and a repeated
// call after the builder has been removed again.
TEST_F(UtestHybridModelBuilder, TestInitHcclExecutorOnDemand) {
  auto &executor_manager = NodeExecutorManager::GetInstance();
  executor_manager.builders_.erase(NodeExecutorManager::ExecutorType::HCCL);

  // build aicore task
  auto task_defs = std::make_shared<domi::ModelTaskDef>();
  GeModelPtr ge_model = std::make_shared<GeModel>();
  ge_model->SetModelTaskDef(task_defs);

  // No hccl task
  domi::TaskDef *added_task = task_defs->add_task();
  added_task->set_type(RT_MODEL_TASK_MEMCPY_ASYNC);
  ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS);

  // get executor failed due to no builder
  added_task = task_defs->add_task();
  added_task->set_type(RT_MODEL_TASK_HCCL);
  ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), INTERNAL_ERROR);

  // get executor success
  REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::HCCL, NodeExecutor);
  ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS);

  // repeat get, do not access builder
  executor_manager.builders_.erase(NodeExecutorManager::ExecutorType::HCCL);
  ASSERT_EQ(HybridModelBuilder::InitHcclExecutorOnDemand(ge_model), SUCCESS);
}
} // namespace ge |
@@ -0,0 +1,103 @@ | |||
/** | |||
* Copyright 2021 Huawei Technologies Co., Ltd | |||
* | |||
* Licensed under the Apache License, Version 2.0 (the "License"); | |||
* you may not use this file except in compliance with the License. | |||
* You may obtain a copy of the License at | |||
* | |||
* http://www.apache.org/licenses/LICENSE-2.0 | |||
* | |||
* Unless required by applicable law or agreed to in writing, software | |||
* distributed under the License is distributed on an "AS IS" BASIS, | |||
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |||
* See the License for the specific language governing permissions and | |||
* limitations under the License. | |||
*/ | |||
#include <gtest/gtest.h> | |||
#include <gmock/gmock.h> | |||
#include <vector> | |||
#define private public | |||
#define protected public | |||
#include "hybrid/node_executor/node_executor.h" | |||
#undef protected | |||
#undef private | |||
using namespace std; | |||
using namespace testing; | |||
namespace ge { | |||
using namespace hybrid; | |||
namespace {
// File-local flag: SuccessNodeExecutor::Initialize clears it, SuccessNodeExecutor::Finalize sets it.
bool finalized{false};
}  // namespace
class NodeExecutorTest : public testing::Test { | |||
protected: | |||
void SetUp() {} | |||
void TearDown() { } | |||
}; | |||
class FailureNodeExecutor : public NodeExecutor { | |||
public: | |||
Status Initialize() override { | |||
return INTERNAL_ERROR; | |||
} | |||
}; | |||
class SuccessNodeExecutor : public NodeExecutor { | |||
public: | |||
Status Initialize() override { | |||
initialized = true; | |||
finalized = false; | |||
return SUCCESS; | |||
} | |||
Status Finalize() override { | |||
finalized = true; | |||
} | |||
bool initialized = false; | |||
}; | |||
// Test doubles used by the tests in this file: AICORE is paired with FailureNodeExecutor
// (Initialize returns INTERNAL_ERROR) and AICPU_TF with SuccessNodeExecutor (Initialize succeeds).
// NOTE(review): presumably the macro adds these to NodeExecutorManager::builders_ at static-init time — confirm.
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICORE, FailureNodeExecutor); | |||
REGISTER_NODE_EXECUTOR_BUILDER(NodeExecutorManager::ExecutorType::AICPU_TF, SuccessNodeExecutor); | |||
// Covers GetOrCreateExecutor for: a type without a builder, a builder whose executor
// fails to initialize, a successful creation, and a repeated lookup of the same type.
TEST_F(NodeExecutorTest, TestGetOrCreateExecutor) {
  using ExecutorType = NodeExecutorManager::ExecutorType;
  auto &executor_manager = NodeExecutorManager::GetInstance();
  const NodeExecutor *created = nullptr;

  // no builder
  Status status = executor_manager.GetOrCreateExecutor(ExecutorType::RESERVED, &created);
  ASSERT_EQ(status, INTERNAL_ERROR);

  // initialize failure
  status = executor_manager.GetOrCreateExecutor(ExecutorType::AICORE, &created);
  ASSERT_EQ(status, INTERNAL_ERROR);

  // first request creates the executor
  status = executor_manager.GetOrCreateExecutor(ExecutorType::AICPU_TF, &created);
  ASSERT_EQ(status, SUCCESS);
  ASSERT_TRUE(created != nullptr);

  // second request for the same type succeeds as well
  status = executor_manager.GetOrCreateExecutor(ExecutorType::AICPU_TF, &created);
  ASSERT_EQ(status, SUCCESS);
  ASSERT_TRUE(created != nullptr);
  ASSERT_TRUE(static_cast<const SuccessNodeExecutor *>(created)->initialized);
}
// Calls EnsureInitialized twice and then FinalizeExecutors twice, checking that the
// executors are still present after the first finalize and gone after the second.
TEST_F(NodeExecutorTest, TestInitAndFinalize) {
  auto &executor_manager = NodeExecutorManager::GetInstance();
  executor_manager.FinalizeExecutors();
  executor_manager.EnsureInitialized();
  executor_manager.EnsureInitialized();

  const NodeExecutor *created = nullptr;
  const auto status =
      executor_manager.GetOrCreateExecutor(NodeExecutorManager::ExecutorType::AICPU_TF, &created);
  ASSERT_EQ(status, SUCCESS);
  ASSERT_TRUE(created != nullptr);
  ASSERT_TRUE(static_cast<const SuccessNodeExecutor *>(created)->initialized);

  // first finalize: executors are still held
  executor_manager.FinalizeExecutors();
  ASSERT_FALSE(executor_manager.executors_.empty());

  // second finalize: executors are finalized and removed
  executor_manager.FinalizeExecutors();
  ASSERT_TRUE(executor_manager.executors_.empty());
  ASSERT_TRUE(finalized);
}
} // namespace ge |