diff --git a/ge/hybrid/executor/hybrid_model_executor.cc b/ge/hybrid/executor/hybrid_model_executor.cc index 2bb683c7..dd8aace6 100755 --- a/ge/hybrid/executor/hybrid_model_executor.cc +++ b/ge/hybrid/executor/hybrid_model_executor.cc @@ -35,10 +35,11 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, HybridModelExecutor::~HybridModelExecutor() { } -Status HybridModelExecutor::Init() { +Status HybridModelExecutor::Init(ThreadPool *thread_pool) { GELOGD("Start to init HybridGraphEngine."); GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); - root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); + root_graph_executor_.reset( + new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_, false, thread_pool)); GE_CHECK_NOTNULL(root_graph_executor_); GELOGD("HybridGraphEngine initialized successfully."); return SUCCESS; diff --git a/ge/hybrid/executor/hybrid_model_executor.h b/ge/hybrid/executor/hybrid_model_executor.h index 102e4f8b..dbec7adf 100644 --- a/ge/hybrid/executor/hybrid_model_executor.h +++ b/ge/hybrid/executor/hybrid_model_executor.h @@ -39,7 +39,7 @@ class HybridModelExecutor { ~HybridModelExecutor(); - Status Init(); + Status Init(ThreadPool *thread_pool = nullptr); const GraphExecutionContext* GetContext() const { return &context_; diff --git a/ge/hybrid/executor/subgraph_executor.cc b/ge/hybrid/executor/subgraph_executor.cc index 33a2846c..7fcdec5d 100644 --- a/ge/hybrid/executor/subgraph_executor.cc +++ b/ge/hybrid/executor/subgraph_executor.cc @@ -28,20 +28,30 @@ constexpr int kDefaultQueueSize = 16; constexpr int kDataInputIndex = 0; } -SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape) +SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape, + ThreadPool *pre_run_pool) : graph_item_(graph_item), context_(context), force_infer_shape_(force_infer_shape), - pre_run_pool_(kDefaultThreadNum), + pre_run_pool_(pre_run_pool), + own_thread_pool_(false), ready_queue_(kDefaultQueueSize) { } SubgraphExecutor::~SubgraphExecutor() { + if (own_thread_pool_ && pre_run_pool_ != nullptr) { + delete pre_run_pool_; + } GELOGD("[%s] SubgraphExecutor destroyed.", graph_item_->GetName().c_str()); } Status SubgraphExecutor::Init(const std::vector &inputs, const std::vector &input_desc) { + if (pre_run_pool_ == nullptr) { + pre_run_pool_ = new (std::nothrow) ThreadPool(kDefaultThreadNum); + GE_CHECK_NOTNULL(pre_run_pool_); + own_thread_pool_ = true; + } subgraph_context_.reset(new(std::nothrow)SubgraphContext(graph_item_, context_)); GE_CHECK_NOTNULL(subgraph_context_); GE_CHK_STATUS_RET(subgraph_context_->Init(), @@ -254,7 +264,8 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { // only do shape inference and compilation for nodes with dynamic shapes. if (node_item.is_dynamic) { - auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { + GE_CHECK_NOTNULL(pre_run_pool_); + auto prepare_future = pre_run_pool_->commit([this, p_node_state]() -> Status { GetContext().SetSessionId(context_->session_id); GetContext().SetContextId(context_->context_id); GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); @@ -349,7 +360,8 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), node_state->GetSwitchIndex(), node_state->GetMergeIndex()); - auto future = pre_run_pool_.commit([this, node_state]() -> Status { + GE_CHECK_NOTNULL(pre_run_pool_); + auto future = pre_run_pool_->commit([this, node_state]() -> Status { RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); std::function callback = [&](const NodeItem *node_item) { const auto &node_name = node_item->node_name; diff --git a/ge/hybrid/executor/subgraph_executor.h b/ge/hybrid/executor/subgraph_executor.h index 76732c37..be11ff59 100644 --- a/ge/hybrid/executor/subgraph_executor.h +++ b/ge/hybrid/executor/subgraph_executor.h @@ -33,7 +33,8 @@ namespace hybrid { // Executor for executing a subgraph class SubgraphExecutor { public: - SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); + SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false, + ThreadPool *pre_run_pool = nullptr); ~SubgraphExecutor(); Status InitForPartialExecution(const std::vector &inputs, @@ -124,7 +125,8 @@ class SubgraphExecutor { GraphExecutionContext *context_; std::unique_ptr subgraph_context_; bool force_infer_shape_; - ThreadPool pre_run_pool_; + ThreadPool *pre_run_pool_; + bool own_thread_pool_; BlockingQueue ready_queue_; std::unique_ptr shape_inference_engine_; diff --git a/ge/single_op/single_op_model.cc b/ge/single_op/single_op_model.cc index 426d3233..ca07d2ae 100755 --- a/ge/single_op/single_op_model.cc +++ b/ge/single_op/single_op_model.cc @@ -713,7 +713,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & device_id, resource.GetStream())); GE_CHECK_NOTNULL(single_op.hybrid_model_executor_); - GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(), "[Init][HybridModelExecutor]Failed."); + ThreadPool *thread_pool = nullptr; + GE_CHK_STATUS_RET_NOLOG(resource.GetThreadPool(&thread_pool)); + GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(thread_pool), "[Init][HybridModelExecutor]Failed."); return SUCCESS; } return BuildTaskListForDynamicOp(&resource, single_op); diff --git a/ge/single_op/stream_resource.cc b/ge/single_op/stream_resource.cc index 9fe8f26a..10a8f72b 100755 --- a/ge/single_op/stream_resource.cc +++ b/ge/single_op/stream_resource.cc @@ -25,6 +25,7 @@ namespace ge { namespace { // limit available device mem size 1M const uint32_t kFuzzDeviceBufferSize = 1 * 1024 * 1024; +constexpr int kDefaultThreadNum = 4; } StreamResource::StreamResource(uintptr_t resource_id) : resource_id_(resource_id) { @@ -219,6 +220,16 @@ Status StreamResource::BuildOperator(const ModelData &model_data, SingleOp **sin return SUCCESS; } +Status StreamResource::GetThreadPool(ThreadPool **thread_pool) { + GE_CHECK_NOTNULL(thread_pool); + if (thread_pool_ == nullptr) { + thread_pool_.reset(new (std::nothrow) ThreadPool(kDefaultThreadNum)); + GE_CHECK_NOTNULL(thread_pool_); + } + *thread_pool = thread_pool_.get(); + return SUCCESS; +} + const uint8_t *StreamResource::GetMemoryBase() const { if (memory_list_.empty()) { return nullptr; diff --git a/ge/single_op/stream_resource.h b/ge/single_op/stream_resource.h index 8986634b..f1e1bebb 100755 --- a/ge/single_op/stream_resource.h +++ b/ge/single_op/stream_resource.h @@ -54,6 +54,8 @@ class StreamResource { return device_buffer_; } + Status GetThreadPool(ThreadPool **thread_pool); + private: uint8_t *DoMallocMemory(const std::string &purpose, size_t size, @@ -66,6 +68,7 @@ class StreamResource { std::vector weight_list_; std::unordered_map> op_map_; std::unordered_map> dynamic_op_map_; + std::unique_ptr thread_pool_; rtStream_t stream_ = nullptr; std::mutex mu_; std::mutex stream_mu_; diff --git a/metadef b/metadef index 3ace5b6f..f9a47a45 160000 --- a/metadef +++ b/metadef @@ -1 +1 @@ -Subproject commit 3ace5b6f10e0af784a1c3211fd769d6e8860e864 +Subproject commit f9a47a45cdd7e6dc507a15291fcb769f96b859b3 diff --git a/parser b/parser index db68a1a4..b42a99ea 160000 --- a/parser +++ b/parser @@ -1 +1 @@ -Subproject commit db68a1a4f1a6ae69dbf9a5f338392d50ea3874e3 +Subproject commit b42a99ea6e1be75156650675fd0aeabca6cb3de9 diff --git a/tests/ut/ge/single_op/stream_resource_unittest.cc b/tests/ut/ge/single_op/stream_resource_unittest.cc index e07fc39d..e4ab469e 100644 --- a/tests/ut/ge/single_op/stream_resource_unittest.cc +++ b/tests/ut/ge/single_op/stream_resource_unittest.cc @@ -66,6 +66,9 @@ TEST_F(UtestStreamResource, test_build_op) { res.op_map_[0].reset(single_op); res.dynamic_op_map_[1].reset(dynamic_single_op); + ThreadPool *thread_pool = nullptr; + EXPECT_EQ(res.GetThreadPool(&thread_pool), SUCCESS); + EXPECT_EQ(res.GetOperator(0), nullptr); EXPECT_EQ(res.GetDynamicOperator(1), nullptr); EXPECT_EQ(res.BuildOperator(model_data, &single_op, 0), SUCCESS);