diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc index b178b906..71a60f2f 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.cc @@ -67,6 +67,9 @@ Status AicpuExtInfoHandler::Parse(const std::string &ext_info) { case aicpu::FWKAdapter::FWK_ADPT_EXT_BITMAP: GE_CHK_STATUS_RET(ParseExtBitMap(aicpu_ext_info), "Parse ext bit map failed."); break; + case aicpu::FWKAdapter::FWK_ADPT_EXT_UPDATE_ADDR: + GE_CHK_STATUS_RET(ParseExtUpdateAddr(aicpu_ext_info), "Parse ext update_addr failed."); + break; default: GELOGD("Node[%s] ignore infoType=%d, infoLen=%u.", node_name_.c_str(), aicpu_ext_info->infoType, aicpu_ext_info->infoLen); @@ -153,6 +156,16 @@ Status AicpuExtInfoHandler::ParseExtBitMap(AicpuExtInfo *aicpu_ext_info) { return SUCCESS; } +Status AicpuExtInfoHandler::ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info) { + GE_CHK_BOOL_RET_STATUS(aicpu_ext_info->infoLen == sizeof(uint32_t), PARAM_INVALID, + "Node[%s] parse update_addr info failed as infoLen must be %zu but %u.", + node_name_.c_str(), sizeof(uint32_t), aicpu_ext_info->infoLen); + + update_addr_ = reinterpret_cast(aicpu_ext_info->infoMsg); + GELOGI("Node[%s] update_addr info success infoLen=%u.", node_name_.c_str(), aicpu_ext_info->infoLen); + return SUCCESS; +} + Status AicpuExtInfoHandler::UpdateExecuteMode(bool flag) { if (bit_map_ == nullptr) { GELOGD("There is no bit_map in ext_info, no need update."); @@ -233,6 +246,10 @@ Status AicpuExtInfoHandler::GetOutputShapeAndType(uint32_t output_index, GeShape return SUCCESS; } +bool AicpuExtInfoHandler::IsNeedRefreshIOAddr() { + return update_addr_ != nullptr && *update_addr_ != static_cast(aicpu::FWKAdapter::FWK_ADPT_UPDATE_NULL); +} + Status AicpuExtInfoHandler::UpdateShapeAndType(const GeShape &shape, DataType data_type, AicpuShapeAndType *shape_and_type) { auto dim_num = shape.GetDimNum(); diff --git a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h index e5b94452..01092204 100644 --- a/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h +++ b/ge/hybrid/node_executor/aicpu/aicpu_ext_info.h @@ -61,6 +61,8 @@ class AicpuExtInfoHandler { Status GetOutputShapeAndType(uint32_t output_index, GeShape &shape, DataType &data_type); + bool IsNeedRefreshIOAddr(); + private: Status ParseExtShapeType(AicpuExtInfo *aicpu_ext_info); @@ -68,6 +70,7 @@ class AicpuExtInfoHandler { Status ParseExtOutputShape(AicpuExtInfo *aicpu_ext_info); Status ParseExtSessionInfo(AicpuExtInfo *aicpu_ext_info); Status ParseExtBitMap(AicpuExtInfo *aicpu_ext_info); + Status ParseExtUpdateAddr(AicpuExtInfo *aicpu_ext_info); static Status UpdateShapeAndType(const GeShape &shape, DataType data_type, @@ -84,6 +87,7 @@ class AicpuExtInfoHandler { UnknowShapeOpType unknown_type_; AicpuSessionInfo *session_info_ = nullptr; uint64_t *bit_map_ = nullptr; + uint32_t *update_addr_ = nullptr; std::unique_ptr ext_info_; size_t ext_info_len_ = 0; diff --git a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc index 1f77bab8..55b41120 100755 --- a/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc +++ b/ge/hybrid/node_executor/aicpu/aicpu_node_executor.cc @@ -61,7 +61,9 @@ Status AicpuNodeTaskBase::InitExtInfo(const std::string &kernel_ext_info, int64_ GELOGD("To update aicpu_task ext_info session_info session_id to %lu", session_id); GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateSessionInfoSessionId(session_id), "UpdateSessionInfoSessionId failed."); - GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(!node_item_->is_dynamic), "UpdateExecuteMode failed."); + + bool execute_mode = !aicpu_ext_handle_.IsNeedRefreshIOAddr() && !node_item_->is_dynamic; + GE_CHK_STATUS_RET(aicpu_ext_handle_.UpdateExecuteMode(execute_mode), "UpdateExecuteMode failed."); // copy task args buf GE_CHK_STATUS_RET(AllocTensorBuffer(aicpu_ext_handle_.GetExtInfoLen(), ext_info_addr_dev_), diff --git a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc index 68d029a8..44d4d042 100644 --- a/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc +++ b/tests/ut/ge/graph/load/kernel_ex_task_info_unittest.cc @@ -146,4 +146,12 @@ TEST_F(UtestKernelExTaskInfo, kernel_ex_task_ext_info) { EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS); } +TEST_F(UtestKernelExTaskInfo, parse_update_addr) { + const string ext_info = {3,0,0,0,4,0,0,0,4,0,0,0}; + const OpDescPtr op_desc = CreateOpDesc("FrameworkOp", "FrameworkOp"); + AttrUtils::SetBool(op_desc, "_AllShape", true); + + KernelExTaskInfo kernel_ex_task_info; + EXPECT_EQ(kernel_ex_task_info.InitTaskExtInfo(ext_info, op_desc), SUCCESS); +} } // namespace ge