| @@ -370,9 +370,6 @@ Status AiCoreOpTask::UpdateTilingInfo(TaskContext &context) { | |||||
| // update op args by tiling info | // update op args by tiling info | ||||
| block_dim_ = tiling_info.GetBlockDim(); | block_dim_ = tiling_info.GetBlockDim(); | ||||
| clear_atomic_ = tiling_info.GetClearAtomic(); | clear_atomic_ = tiling_info.GetClearAtomic(); | ||||
| std::vector<int64_t> workspaces; | |||||
| tiling_info.GetAllWorkspaces(workspaces); | |||||
| op_desc->SetWorkspaceBytes(workspaces); | |||||
| tiling_data_ = tiling_info.GetAllTilingData().str(); | tiling_data_ = tiling_info.GetAllTilingData().str(); | ||||
| tiling_key_ = tiling_info.GetTilingKey(); | tiling_key_ = tiling_info.GetTilingKey(); | ||||
| @@ -415,6 +412,11 @@ Status AiCoreOpTask::CalcTilingInfo(const NodePtr &node, OpRunInfo &tiling_info) | |||||
| GE_CHK_STATUS_RET(optiling::OpParaCalculateV2(*node, tiling_info), | GE_CHK_STATUS_RET(optiling::OpParaCalculateV2(*node, tiling_info), | ||||
| "[Invoke][OpParaCalculate]Failed calc tiling data of node %s.", | "[Invoke][OpParaCalculate]Failed calc tiling data of node %s.", | ||||
| node->GetName().c_str()); | node->GetName().c_str()); | ||||
| // Only non atomic task need update workspace | |||||
| auto op_desc = node->GetOpDesc(); | |||||
| std::vector<int64_t> workspaces; | |||||
| tiling_info.GetAllWorkspaces(workspaces); | |||||
| op_desc->SetWorkspaceBytes(workspaces); | |||||
| GELOGD("[%s] Done invoking OpParaCalculate successfully.", node->GetName().c_str()); | GELOGD("[%s] Done invoking OpParaCalculate successfully.", node->GetName().c_str()); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||