|
|
@@ -211,31 +211,26 @@ Status SubgraphExecutor::PrepareNodes() { |
|
|
|
GE_CHECK_NOTNULL(node_state); |
|
|
|
auto p_node_state = node_state.get(); |
|
|
|
|
|
|
|
if (node_item.node_type == NETOUTPUT) { |
|
|
|
// Wait for all inputs become valid |
|
|
|
// after PrepareNodes returned. all output tensors and shapes are valid |
|
|
|
GE_CHK_STATUS_RET_NOLOG(p_node_state->GetShapeInferenceState().AwaitShapesReady(*context_)); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(p_node_state->AwaitInputTensors(*context_)); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
// only do shape inference and compilation for nodes with dynamic shapes. |
|
|
|
if (node_item.is_dynamic) { |
|
|
|
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { |
|
|
|
GetContext().SetSessionId(context_->session_id); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); |
|
|
|
return PrepareForExecution(context_, *p_node_state); |
|
|
|
}); |
|
|
|
|
|
|
|
p_node_state->SetPrepareFuture(std::move(prepare_future)); |
|
|
|
} else { |
|
|
|
GELOGD("[%s] Skipping shape inference and compilation for node with static shape.", node_item.NodeName().c_str()); |
|
|
|
if (node_item.kernel_task == nullptr) { |
|
|
|
GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str()); |
|
|
|
GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_), |
|
|
|
"[%s] Failed to create task.", p_node_state->GetName().c_str()); |
|
|
|
if (node_item.node_type != NETOUTPUT) { |
|
|
|
// only do shape inference and compilation for nodes with dynamic shapes. |
|
|
|
if (node_item.is_dynamic) { |
|
|
|
auto prepare_future = pre_run_pool_.commit([this, p_node_state]() -> Status { |
|
|
|
GetContext().SetSessionId(context_->session_id); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(InferShape(shape_inference_engine_.get(), *p_node_state)); |
|
|
|
return PrepareForExecution(context_, *p_node_state); |
|
|
|
}); |
|
|
|
|
|
|
|
p_node_state->SetPrepareFuture(std::move(prepare_future)); |
|
|
|
} else { |
|
|
|
node_state->SetKernelTask(node_item.kernel_task); |
|
|
|
GELOGD("[%s] Skipping shape inference and compilation for node with static shape.", |
|
|
|
node_item.NodeName().c_str()); |
|
|
|
if (node_item.kernel_task == nullptr) { |
|
|
|
GELOGW("[%s] Node of static shape got no task.", node_item.NodeName().c_str()); |
|
|
|
GE_CHK_STATUS_RET(TaskCompileEngine::Compile(*p_node_state, context_), |
|
|
|
"[%s] Failed to create task.", p_node_state->GetName().c_str()); |
|
|
|
} else { |
|
|
|
node_state->SetKernelTask(node_item.kernel_task); |
|
|
|
} |
|
|
|
} |
|
|
|
} |
|
|
|
|
|
|
@@ -288,6 +283,15 @@ Status SubgraphExecutor::LaunchTasks() { |
|
|
|
return SUCCESS; |
|
|
|
} |
|
|
|
|
|
|
|
if (node_state->GetType() == NETOUTPUT) { |
|
|
|
// Wait for all inputs become valid |
|
|
|
// after PrepareNodes returned. all output tensors and shapes are valid |
|
|
|
GE_CHK_STATUS_RET_NOLOG(node_state->GetShapeInferenceState().AwaitShapesReady(*context_)); |
|
|
|
GE_CHK_STATUS_RET_NOLOG(node_state->AwaitInputTensors(*context_)); |
|
|
|
GELOGD("[%s] Done executing node successfully.", node_state->GetName().c_str()); |
|
|
|
continue; |
|
|
|
} |
|
|
|
|
|
|
|
GE_CHK_STATUS_RET_NOLOG(node_state->WaitForPrepareDone()); |
|
|
|
|
|
|
|
GELOGD("[%s] Start to execute.", node_state->GetName().c_str()); |
|
|
|