|
|
@@ -227,6 +227,7 @@ Status SubgraphExecutor::PrepareNodes(int group) { |
|
|
|
if (node_item.is_dynamic) { |
|
|
|
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { |
|
|
|
GetContext().SetSessionId(context_->session_id); |
|
|
|
GetContext().SetContextId(context_->context_id); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); |
|
|
|
return PrepareForExecution(context_, *p_node_state); |
|
|
|
}); |
|
|
@@ -273,10 +274,8 @@ Status SubgraphExecutor::PrepareNodes(int group) { |
|
|
|
} |
|
|
|
|
|
|
|
Status SubgraphExecutor::InferShape(ShapeInferenceEngine *shape_inference_engine, NodeState &node_state) const { |
|
|
|
GetContext().SetSessionId(context_->context_id); |
|
|
|
HYBRID_CHK_STATUS_RET(shape_inference_engine->InferShape(node_state), |
|
|
|
"[%s] Failed to InferShape.", node_state.GetName().c_str()); |
|
|
|
GetContext().SetSessionId(context_->session_id); |
|
|
|
HYBRID_CHK_STATUS_RET(shape_inference_engine->PropagateOutputShapes(node_state), |
|
|
|
"[%s] Failed to PropagateOutputShapes.", node_state.GetName().c_str()); |
|
|
|
return SUCCESS; |
|
|
@@ -345,6 +344,7 @@ Status SubgraphExecutor::ScheduleTasks(int group) { |
|
|
|
GELOGD("[%s] Start to schedule prepare workers.", graph_item_->GetName().c_str()); |
|
|
|
auto prepare_future = std::async(std::launch::async, [&]() -> Status { |
|
|
|
GetContext().SetSessionId(context_->session_id); |
|
|
|
GetContext().SetContextId(context_->context_id); |
|
|
|
auto ret = PrepareNodes(group); |
|
|
|
ready_queue_.Push(nullptr); |
|
|
|
return ret; |
|
|
|