Merge pull request !1958 from 赵之轩/my_devtags/v1.5.1
@@ -35,10 +35,11 @@ HybridModelExecutor::HybridModelExecutor(HybridModel *model, uint32_t device_id, | |||||
HybridModelExecutor::~HybridModelExecutor() { | HybridModelExecutor::~HybridModelExecutor() { | ||||
} | } | ||||
Status HybridModelExecutor::Init() { | |||||
Status HybridModelExecutor::Init(ThreadPool *thread_pool) { | |||||
GELOGD("Start to init HybridGraphEngine."); | GELOGD("Start to init HybridGraphEngine."); | ||||
GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | GE_CHK_STATUS_RET_NOLOG(InitExecutionContext()); | ||||
root_graph_executor_.reset(new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_)); | |||||
root_graph_executor_.reset( | |||||
new (std::nothrow) SubgraphExecutor(model_->GetRootGraphItem(), &context_, false, thread_pool)); | |||||
GE_CHECK_NOTNULL(root_graph_executor_); | GE_CHECK_NOTNULL(root_graph_executor_); | ||||
GELOGD("HybridGraphEngine initialized successfully."); | GELOGD("HybridGraphEngine initialized successfully."); | ||||
return SUCCESS; | return SUCCESS; | ||||
@@ -39,7 +39,7 @@ class HybridModelExecutor { | |||||
~HybridModelExecutor(); | ~HybridModelExecutor(); | ||||
Status Init(); | |||||
Status Init(ThreadPool *thread_pool = nullptr); | |||||
const GraphExecutionContext* GetContext() const { | const GraphExecutionContext* GetContext() const { | ||||
return &context_; | return &context_; | ||||
@@ -28,20 +28,30 @@ constexpr int kDefaultQueueSize = 16; | |||||
constexpr int kDataInputIndex = 0; | constexpr int kDataInputIndex = 0; | ||||
} | } | ||||
// Constructor: stores the graph item, the execution context and the force-infer-shape flag,
// keeps the caller-supplied pre-run thread pool pointer as given, marks the pool as not owned
// (own_thread_pool_ = false) and builds ready_queue_ with kDefaultQueueSize.
// NOTE(review): this span is a diff rendering - the first signature line and
// `pre_run_pool_(kDefaultThreadNum)` look like the removed pre-change lines; confirm against
// the real source file before relying on this text.
SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape) | |||||
SubgraphExecutor::SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape, | |||||
ThreadPool *pre_run_pool) | |||||
: graph_item_(graph_item), | : graph_item_(graph_item), | ||||
context_(context), | context_(context), | ||||
force_infer_shape_(force_infer_shape), | force_infer_shape_(force_infer_shape), | ||||
pre_run_pool_(kDefaultThreadNum), | |||||
pre_run_pool_(pre_run_pool), | |||||
own_thread_pool_(false), | |||||
ready_queue_(kDefaultQueueSize) { | ready_queue_(kDefaultQueueSize) { | ||||
} | } | ||||
// Destructor: frees the pre-run thread pool only when own_thread_pool_ is set,
// then logs the destruction of this executor.
SubgraphExecutor::~SubgraphExecutor() {
  if (own_thread_pool_) {
    // `delete` on a null pointer is a no-op, so no separate null test is needed.
    delete pre_run_pool_;
  }
  GELOGD("[%s] SubgraphExecutor destroyed.", graph_item_->GetName().c_str());
}
Status SubgraphExecutor::Init(const std::vector<TensorValue> &inputs, | Status SubgraphExecutor::Init(const std::vector<TensorValue> &inputs, | ||||
const std::vector<ConstGeTensorDescPtr> &input_desc) { | const std::vector<ConstGeTensorDescPtr> &input_desc) { | ||||
if (pre_run_pool_ == nullptr) { | |||||
pre_run_pool_ = new (std::nothrow) ThreadPool(kDefaultThreadNum); | |||||
GE_CHECK_NOTNULL(pre_run_pool_); | |||||
own_thread_pool_ = true; | |||||
} | |||||
subgraph_context_.reset(new(std::nothrow)SubgraphContext(graph_item_, context_)); | subgraph_context_.reset(new(std::nothrow)SubgraphContext(graph_item_, context_)); | ||||
GE_CHECK_NOTNULL(subgraph_context_); | GE_CHECK_NOTNULL(subgraph_context_); | ||||
GE_CHK_STATUS_RET(subgraph_context_->Init(), | GE_CHK_STATUS_RET(subgraph_context_->Init(), | ||||
@@ -254,7 +264,8 @@ Status SubgraphExecutor::PrepareNode(const NodeItem &node_item, int group) { | |||||
// only do shape inference and compilation for nodes with dynamic shapes. | // only do shape inference and compilation for nodes with dynamic shapes. | ||||
if (node_item.is_dynamic) { | if (node_item.is_dynamic) { | ||||
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { | |||||
GE_CHECK_NOTNULL(pre_run_pool_); | |||||
auto prepare_future = pre_run_pool_->commit([this, p_node_state]() -> Status { | |||||
GetContext().SetSessionId(context_->session_id); | GetContext().SetSessionId(context_->session_id); | ||||
GetContext().SetContextId(context_->context_id); | GetContext().SetContextId(context_->context_id); | ||||
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); | GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); | ||||
@@ -349,7 +360,8 @@ Status SubgraphExecutor::NodeScheduled(NodeState *node_state) { | |||||
node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), | node_state->GetNodeItem()->data_send_.size(), node_state->GetNodeItem()->ctrl_send_.size(), | ||||
node_state->GetSwitchIndex(), node_state->GetMergeIndex()); | node_state->GetSwitchIndex(), node_state->GetMergeIndex()); | ||||
auto future = pre_run_pool_.commit([this, node_state]() -> Status { | |||||
GE_CHECK_NOTNULL(pre_run_pool_); | |||||
auto future = pre_run_pool_->commit([this, node_state]() -> Status { | |||||
RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); | RECORD_CALLBACK_EVENT(context_, node_state->GetName().c_str(), "[NodeScheduled] Start"); | ||||
std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) { | std::function<void(const NodeItem *)> callback = [&](const NodeItem *node_item) { | ||||
const auto &node_name = node_item->node_name; | const auto &node_name = node_item->node_name; | ||||
@@ -33,7 +33,8 @@ namespace hybrid { | |||||
// Executor for executing a subgraph | // Executor for executing a subgraph | ||||
class SubgraphExecutor { | class SubgraphExecutor { | ||||
public: | public: | ||||
SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false); | |||||
SubgraphExecutor(const GraphItem *graph_item, GraphExecutionContext *context, bool force_infer_shape = false, | |||||
ThreadPool *pre_run_pool = nullptr); | |||||
~SubgraphExecutor(); | ~SubgraphExecutor(); | ||||
Status InitForPartialExecution(const std::vector<TensorValue> &inputs, | Status InitForPartialExecution(const std::vector<TensorValue> &inputs, | ||||
@@ -124,7 +125,8 @@ class SubgraphExecutor { | |||||
GraphExecutionContext *context_; | GraphExecutionContext *context_; | ||||
std::unique_ptr<SubgraphContext> subgraph_context_; | std::unique_ptr<SubgraphContext> subgraph_context_; | ||||
bool force_infer_shape_; | bool force_infer_shape_; | ||||
ThreadPool pre_run_pool_; | |||||
ThreadPool *pre_run_pool_; | |||||
bool own_thread_pool_; | |||||
BlockingQueue<NodeState *> ready_queue_; | BlockingQueue<NodeState *> ready_queue_; | ||||
std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | std::unique_ptr<ShapeInferenceEngine> shape_inference_engine_; | ||||
@@ -713,7 +713,9 @@ Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp & | |||||
device_id, | device_id, | ||||
resource.GetStream())); | resource.GetStream())); | ||||
GE_CHECK_NOTNULL(single_op.hybrid_model_executor_); | GE_CHECK_NOTNULL(single_op.hybrid_model_executor_); | ||||
GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(), "[Init][HybridModelExecutor]Failed."); | |||||
ThreadPool *thread_pool = nullptr; | |||||
GE_CHK_STATUS_RET_NOLOG(resource.GetThreadPool(&thread_pool)); | |||||
GE_CHK_STATUS_RET(single_op.hybrid_model_executor_->Init(thread_pool), "[Init][HybridModelExecutor]Failed."); | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
return BuildTaskListForDynamicOp(&resource, single_op); | return BuildTaskListForDynamicOp(&resource, single_op); | ||||
@@ -25,6 +25,7 @@ namespace ge { | |||||
namespace { | namespace { | ||||
// limit available device mem size 1M | // limit available device mem size 1M | ||||
const uint32_t kFuzzDeviceBufferSize = 1 * 1024 * 1024; | const uint32_t kFuzzDeviceBufferSize = 1 * 1024 * 1024; | ||||
constexpr int kDefaultThreadNum = 4; | |||||
} | } | ||||
StreamResource::StreamResource(uintptr_t resource_id) : resource_id_(resource_id) { | StreamResource::StreamResource(uintptr_t resource_id) : resource_id_(resource_id) { | ||||
@@ -219,6 +220,16 @@ Status StreamResource::BuildOperator(const ModelData &model_data, SingleOp **sin | |||||
return SUCCESS; | return SUCCESS; | ||||
} | } | ||||
// Returns the thread pool held by this StreamResource through `thread_pool`,
// creating it with kDefaultThreadNum threads on the first request.
//
// @param thread_pool  Output pointer; must not be null. On success it receives a
//                     non-owning pointer to the pool stored in thread_pool_.
// @return SUCCESS when the pool is available; otherwise whatever GE_CHECK_NOTNULL
//         returns for a null argument or a failed allocation.
Status StreamResource::GetThreadPool(ThreadPool **thread_pool) {
  GE_CHECK_NOTNULL(thread_pool);
  if (thread_pool_ != nullptr) {
    // Pool already exists: hand out the cached instance.
    *thread_pool = thread_pool_.get();
    return SUCCESS;
  }
  // First request: allocate without throwing and verify the result.
  thread_pool_.reset(new (std::nothrow) ThreadPool(kDefaultThreadNum));
  GE_CHECK_NOTNULL(thread_pool_);
  *thread_pool = thread_pool_.get();
  return SUCCESS;
}
const uint8_t *StreamResource::GetMemoryBase() const { | const uint8_t *StreamResource::GetMemoryBase() const { | ||||
if (memory_list_.empty()) { | if (memory_list_.empty()) { | ||||
return nullptr; | return nullptr; | ||||
@@ -54,6 +54,8 @@ class StreamResource { | |||||
return device_buffer_; | return device_buffer_; | ||||
} | } | ||||
Status GetThreadPool(ThreadPool **thread_pool); | |||||
private: | private: | ||||
uint8_t *DoMallocMemory(const std::string &purpose, | uint8_t *DoMallocMemory(const std::string &purpose, | ||||
size_t size, | size_t size, | ||||
@@ -66,6 +68,7 @@ class StreamResource { | |||||
std::vector<uint8_t *> weight_list_; | std::vector<uint8_t *> weight_list_; | ||||
std::unordered_map<uint64_t, std::unique_ptr<SingleOp>> op_map_; | std::unordered_map<uint64_t, std::unique_ptr<SingleOp>> op_map_; | ||||
std::unordered_map<uint64_t, std::unique_ptr<DynamicSingleOp>> dynamic_op_map_; | std::unordered_map<uint64_t, std::unique_ptr<DynamicSingleOp>> dynamic_op_map_; | ||||
std::unique_ptr<ThreadPool> thread_pool_; | |||||
rtStream_t stream_ = nullptr; | rtStream_t stream_ = nullptr; | ||||
std::mutex mu_; | std::mutex mu_; | ||||
std::mutex stream_mu_; | std::mutex stream_mu_; | ||||
@@ -1 +1 @@ | |||||
Subproject commit 3ace5b6f10e0af784a1c3211fd769d6e8860e864 | |||||
Subproject commit f9a47a45cdd7e6dc507a15291fcb769f96b859b3 |
@@ -1 +1 @@ | |||||
Subproject commit db68a1a4f1a6ae69dbf9a5f338392d50ea3874e3 | |||||
Subproject commit b42a99ea6e1be75156650675fd0aeabca6cb3de9 |
@@ -66,6 +66,9 @@ TEST_F(UtestStreamResource, test_build_op) { | |||||
res.op_map_[0].reset(single_op); | res.op_map_[0].reset(single_op); | ||||
res.dynamic_op_map_[1].reset(dynamic_single_op); | res.dynamic_op_map_[1].reset(dynamic_single_op); | ||||
ThreadPool *thread_pool = nullptr; | |||||
EXPECT_EQ(res.GetThreadPool(&thread_pool), SUCCESS); | |||||
EXPECT_EQ(res.GetOperator(0), nullptr); | EXPECT_EQ(res.GetOperator(0), nullptr); | ||||
EXPECT_EQ(res.GetDynamicOperator(1), nullptr); | EXPECT_EQ(res.GetDynamicOperator(1), nullptr); | ||||
EXPECT_EQ(res.BuildOperator(model_data, &single_op, 0), SUCCESS); | EXPECT_EQ(res.BuildOperator(model_data, &single_op, 0), SUCCESS); | ||||