@@ -333,6 +333,10 @@ void NodeState::SavePersistTensor(int input_idx, const TensorValue &tensor) {
     return std::any_of(items.begin(), items.end(), is_exist);
   };
+  if (root_tensor_values_.count(input_idx) > 0) {
+    return;
+  }
   if (is_persist_tensor(node_item_->root_data_, input_idx)) {
     GELOGD("[%s] Save Root input tensor: %d", GetName().c_str(), input_idx);
     root_tensor_values_[input_idx] = tensor;
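
The guard added above makes SavePersistTensor write-once per input index: after a tensor has been recorded for input_idx, later calls return early instead of overwriting it, which matters now that the save is also triggered from ReleaseInput (second-to-last hunk below). A minimal sketch of that pattern in isolation, using a stand-in TensorValue and a plain std::map in place of the real GE types:

#include <cstdio>
#include <map>

struct TensorValue { int id; };  // stand-in for the real TensorValue

// Write-once save: the first tensor recorded for an input index wins;
// later saves for the same index are ignored.
void SavePersistTensor(std::map<int, TensorValue> &root_tensor_values,
                       int input_idx, const TensorValue &tensor) {
  if (root_tensor_values.count(input_idx) > 0) {
    return;  // already saved, keep the original value
  }
  root_tensor_values[input_idx] = tensor;
}

int main() {
  std::map<int, TensorValue> values;
  SavePersistTensor(values, 0, TensorValue{1});
  SavePersistTensor(values, 0, TensorValue{2});  // ignored by the guard
  std::printf("input 0 -> tensor %d\n", values[0].id);  // prints 1
  return 0;
}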
@@ -375,6 +375,7 @@ Status ExecutionEngine::DoExecuteAsync(NodeState &node_state,
   RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] Start");
   GE_CHK_STATUS_RET(executor->PrepareTask(*task, task_context), "[Prepare][Task] for [%s] failed.",
                     node_state.GetName().c_str());
+  node_state.UpdatePersistTensor();
   RECORD_EXECUTION_EVENT(&context, task_context.GetNodeName(), "[PrepareTask] End");
   GELOGD("[%s] Done task preparation successfully.", node_state.GetName().c_str());
@@ -39,7 +39,6 @@ const char *const kEngineNameHostCpu = "DNN_VM_HOST_CPU_OP_STORE";
 Status NodeExecutor::PrepareTask(NodeTask &task, TaskContext &context) const {
   GE_CHK_STATUS_RET_NOLOG(context.AllocateOutputs());
   GE_CHK_STATUS_RET_NOLOG(context.AllocateWorkspaces());
-  GE_CHK_STATUS_RET_NOLOG(context.UpdatePersistTensor());
   GE_CHK_STATUS_RET_NOLOG(task.UpdateArgs(context));
   return SUCCESS;
 }
@@ -460,22 +460,12 @@ Status TaskContext::PropagateOutputs() {
         subgraph_context_->all_inputs_[input_offset].SetName(
             node_item_->NodeName() + "_in_" + std::to_string(dst_input_idx));
       }
-      auto dst_node_state = subgraph_context_->GetOrCreateNodeState(dst_node_item);
-      GE_CHECK_NOTNULL(dst_node_state);
-      dst_node_state->SavePersistTensor(dst_input_idx, *tensor);
     }
   }
   (void)guard;
   return SUCCESS;
 }
-Status TaskContext::UpdatePersistTensor() {
-  GE_CHECK_NOTNULL(node_state_);
-  node_state_->UpdatePersistTensor();
-  return SUCCESS;
-}
 const void *TaskContext::GetVarBaseAddr() {
   return execution_context_->model->GetVarMemBase();
 }
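
The two removals here are the other half of the move: PropagateOutputs no longer pushes persist tensors into downstream node states, and the TaskContext::UpdatePersistTensor wrapper, which only null-checked node_state_ and forwarded to it, is dropped because ExecutionEngine::DoExecuteAsync now calls NodeState directly. A sketch of that equivalence with simplified stand-in types (not the GE definitions):

#include <cstdio>

struct NodeState {
  void UpdatePersistTensor() { std::printf("update persist tensors\n"); }
};

// Before: TaskContext merely forwarded the call to its NodeState.
struct TaskContext {
  NodeState *node_state_;
  bool UpdatePersistTensor() {  // the removed wrapper, reduced to its essence
    if (node_state_ == nullptr) {
      return false;
    }
    node_state_->UpdatePersistTensor();
    return true;
  }
};

int main() {
  NodeState state;
  TaskContext context{&state};
  context.UpdatePersistTensor();  // old, indirect path through TaskContext
  state.UpdatePersistTensor();    // new: the caller already holds the NodeState
  return 0;
}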
@@ -501,6 +491,7 @@ void TaskContext::ReleaseInputsAndOutputs() {
 void TaskContext::ReleaseInput(int index) {
   auto input_tensor = MutableInput(index);
   if (input_tensor != nullptr) {
+    node_state_->SavePersistTensor(index, *input_tensor);
     input_tensor->Destroy();
     GELOGD("[%s] Tensor of input[%d] released", GetNodeName(), index);
   }
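
Persist tensors are now captured at release time: right before an input tensor is destroyed, it is handed to NodeState::SavePersistTensor, and the write-once guard from the first hunk ensures only the first value sticks. A simplified sketch of this capture-before-destroy pattern (stand-in types; the real code additionally checks whether the input is actually marked as a persist tensor):

#include <cstdio>
#include <map>
#include <memory>

// Stand-in tensor: Destroy() drops this handle's reference to the data.
struct TensorValue {
  std::shared_ptr<int> data;
  void Destroy() { data.reset(); }
};

struct NodeState {
  std::map<int, TensorValue> root_tensor_values_;
  // Write-once save, as in the first hunk (the persist-tensor filter is omitted here).
  void SavePersistTensor(int input_idx, const TensorValue &tensor) {
    if (root_tensor_values_.count(input_idx) > 0) {
      return;
    }
    root_tensor_values_[input_idx] = tensor;  // keeps its own reference to the data
  }
};

// Capture-before-destroy: the persist copy is taken just before the input is released.
void ReleaseInput(NodeState &state, std::map<int, TensorValue> &inputs, int index) {
  auto it = inputs.find(index);
  if (it != inputs.end()) {
    state.SavePersistTensor(index, it->second);
    it->second.Destroy();
  }
}

int main() {
  NodeState state;
  std::map<int, TensorValue> inputs;
  inputs[0] = TensorValue{std::make_shared<int>(42)};
  ReleaseInput(state, inputs, 0);
  std::printf("persisted input 0 still holds %d\n", *state.root_tensor_values_[0].data);
  return 0;
}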
@@ -78,7 +78,6 @@ class TaskContext {
   Status AllocateOutputs(AllocationAttr *attr = nullptr);
   Status AllocateWorkspaces();
   Status AllocateWorkspace(size_t size, void **buffer, void *ori_addr = nullptr);
-  Status UpdatePersistTensor();
   bool IsTraceEnabled() const;