diff --git a/ge/graph/load/new_model_manager/model_utils.cc b/ge/graph/load/new_model_manager/model_utils.cc index efcd2e48..0884ba8b 100755 --- a/ge/graph/load/new_model_manager/model_utils.cc +++ b/ge/graph/load/new_model_manager/model_utils.cc @@ -479,13 +479,15 @@ vector ModelUtils::GetWorkspaceDataAddrs(const RuntimeParam &model_param ge::AttrUtils::GetListInt(op_desc, ATTR_NAME_WORKSPACE_TYPE_LIST, workspace_memory_type); for (size_t i = 0; i < v_workspace_bytes.size(); ++i) { // Temporary solution, the aicpu workspace of multiple images cannot be shared. - if (has_workspace_reuse && i < workspace_reuse_flag.size() && !workspace_reuse_flag[i]) { + if (has_workspace_reuse && i < workspace_reuse_flag.size() + && !workspace_reuse_flag[i] && !model_param.is_single_op) { void *mem_addr = model_param.aicpu_mem_mall->Acquire(v_workspace_offset[i], v_workspace_bytes[i]); v_workspace_data_addr.push_back(mem_addr); GELOGI( "[IMAS]GetWorkspaceDataAddrs graph_%u type[F] name[%s] aicpu workspace[%zu] offset[%ld] bytes[%ld] " "memaddr[%p]", model_param.graph_id, op_desc->GetName().c_str(), i, v_workspace_offset[i], v_workspace_bytes[i], mem_addr); + continue; } else if (has_mem_type_workspace && workspace_memory_type[i] == RT_MEMORY_P2P_DDR) { int64_t p2p_workspace_offset = v_workspace_offset[i]; int64_t p2p_workspace_bytes = v_workspace_bytes[i]; diff --git a/ge/graph/load/new_model_manager/task_info/task_info.h b/ge/graph/load/new_model_manager/task_info/task_info.h index a50b0360..d296d29e 100644 --- a/ge/graph/load/new_model_manager/task_info/task_info.h +++ b/ge/graph/load/new_model_manager/task_info/task_info.h @@ -56,6 +56,7 @@ struct RuntimeParam { uint32_t label_num = 0; uint64_t session_id = 0; uint32_t graph_id = 0; + bool is_single_op = false; std::unique_ptr ts_mem_mall; std::unique_ptr aicpu_mem_mall; diff --git a/ge/single_op/task/build_task_utils.cc b/ge/single_op/task/build_task_utils.cc index 28177dc7..29f1657b 100644 --- a/ge/single_op/task/build_task_utils.cc +++ b/ge/single_op/task/build_task_utils.cc @@ -45,6 +45,7 @@ std::vector> BuildTaskUtils::GetAddresses(const OpDescPtr &o runtime_para.logic_var_base = kLogicVarBase; runtime_para.var_base = kVarBase; runtime_para.session_id = kSessionId; + runtime_para.is_single_op = true; ret.emplace_back(ModelUtils::GetInputDataAddrs(runtime_para, op_desc)); ret.emplace_back(ModelUtils::GetOutputDataAddrs(runtime_para, op_desc));