From: @zhao_zhixuan Reviewed-by: Signed-off-by:tags/v1.2.0
| @@ -473,10 +473,10 @@ Status SingleOpModel::BuildTaskListForDynamicOp(DynamicSingleOp &single_op) { | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status SingleOpModel::BuildDynamicOp(DynamicSingleOp &single_op) { | |||||
| Status SingleOpModel::BuildDynamicOp(StreamResource &resource, DynamicSingleOp &single_op) { | |||||
| single_op.num_inputs_ = data_ops_.size(); | single_op.num_inputs_ = data_ops_.size(); | ||||
| single_op.num_outputs_ = netoutput_op_->GetAllInputsSize(); | single_op.num_outputs_ = netoutput_op_->GetAllInputsSize(); | ||||
| ParseOpModelParams(model_helper_, model_params_); | |||||
| GE_CHK_STATUS_RET_NOLOG(InitModelMem(resource)); | |||||
| return BuildTaskListForDynamicOp(single_op); | return BuildTaskListForDynamicOp(single_op); | ||||
| } | } | ||||
| } // namespace ge | } // namespace ge | ||||
| @@ -52,7 +52,7 @@ class SingleOpModel { | |||||
| Status Init(); | Status Init(); | ||||
| Status BuildOp(StreamResource &resource, SingleOp &single_op); | Status BuildOp(StreamResource &resource, SingleOp &single_op); | ||||
| Status BuildDynamicOp(DynamicSingleOp &single_op); | |||||
| Status BuildDynamicOp(StreamResource &resource, DynamicSingleOp &single_op); | |||||
| private: | private: | ||||
| Status InitModel(); | Status InitModel(); | ||||
| @@ -155,7 +155,8 @@ Status StreamResource::BuildDynamicOperator(const string &model_name, | |||||
| GE_CHECK_NOTNULL(new_op); | GE_CHECK_NOTNULL(new_op); | ||||
| GELOGI("To build operator: %s", model_name.c_str()); | GELOGI("To build operator: %s", model_name.c_str()); | ||||
| GE_CHK_STATUS_RET(model.BuildDynamicOp(*new_op), "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); | |||||
| GE_CHK_STATUS_RET(model.BuildDynamicOp(*this, *new_op), | |||||
| "Build op failed. op = %s, ret = %u", model_name.c_str(), ret); | |||||
| *single_op = new_op.get(); | *single_op = new_op.get(); | ||||
| dynamic_op_map_[model_data.model_data] = std::move(new_op); | dynamic_op_map_[model_data.model_data] = std::move(new_op); | ||||
| return SUCCESS; | return SUCCESS; | ||||
| @@ -66,6 +66,7 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id, cons | |||||
| const std::string &kernel_name = kernel_def_.kernel_name(); | const std::string &kernel_name = kernel_def_.kernel_name(); | ||||
| task.SetSoName(so_name); | task.SetSoName(so_name); | ||||
| task.SetkernelName(kernel_name); | task.SetkernelName(kernel_name); | ||||
| GE_CHECK_NOTNULL(op_desc_); | |||||
| task.op_desc_ = op_desc_; | task.op_desc_ = op_desc_; | ||||
| const auto &context = kernel_def_.context(); | const auto &context = kernel_def_.context(); | ||||
| @@ -96,6 +97,7 @@ Status AiCpuCCTaskBuilder::BuildTask(AiCpuCCTask &task, uint64_t kernel_id, cons | |||||
| GELOGE(ret, "Init ext info failed."); | GELOGE(ret, "Init ext info failed."); | ||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(task.SetInputConst(), "AiCpuCCTask set input_const failed."); | |||||
| if (task.GetUnknownType() == DEPEND_COMPUTE) { | if (task.GetUnknownType() == DEPEND_COMPUTE) { | ||||
| GELOGE(FAILED, "AiCpuCCTask unknown type is depend compute, it's not supported now."); | GELOGE(FAILED, "AiCpuCCTask unknown type is depend compute, it's not supported now."); | ||||
| @@ -88,6 +88,7 @@ namespace ge { | |||||
| return ret; | return ret; | ||||
| } | } | ||||
| GE_CHECK_NOTNULL(op_desc_); | |||||
| task.op_desc_ = op_desc_; | task.op_desc_ = op_desc_; | ||||
| task.num_inputs_ = op_desc_->GetInputsSize(); | task.num_inputs_ = op_desc_->GetInputsSize(); | ||||
| task.num_outputs_ = op_desc_->GetOutputsSize(); | task.num_outputs_ = op_desc_->GetOutputsSize(); | ||||
| @@ -104,6 +105,7 @@ namespace ge { | |||||
| fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast<uintptr_t>(task.ext_info_addr_dev_); | fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoAddr = reinterpret_cast<uintptr_t>(task.ext_info_addr_dev_); | ||||
| fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = kernel_ext_info_size; | fwk_op_kernel.fwkKernelBase.fwk_kernel.extInfoLen = kernel_ext_info_size; | ||||
| } | } | ||||
| GE_CHK_STATUS_RET(task.SetInputConst(), "AiCpuTask set input_const failed."); | |||||
| GE_CHK_STATUS_RET(task.InitForSummaryAndCopy(), "AiCpuTask init for summary and copy task failed."); | GE_CHK_STATUS_RET(task.InitForSummaryAndCopy(), "AiCpuTask init for summary and copy task failed."); | ||||
| fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID = ULLONG_MAX; | fwk_op_kernel.fwkKernelBase.fwk_kernel.sessionID = ULLONG_MAX; | ||||
| @@ -369,6 +369,25 @@ Status AiCpuBaseTask::SetExtInfoAndType(const std::string &kernel_ext_info, uint | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status AiCpuBaseTask::SetInputConst() { | |||||
| input_is_const_.clear(); | |||||
| const vector<bool> v_is_input_const = op_desc_->GetIsInputConst(); | |||||
| for (size_t i = 0; i < op_desc_->GetAllInputsSize(); ++i) { | |||||
| const GeTensorDescPtr tensor_desc = op_desc_->MutableInputDesc(static_cast<uint32_t>(i)); | |||||
| if (tensor_desc == nullptr) { | |||||
| GELOGD("SingleOp: %s, Index: %zu, has no input", op_desc_->GetName().c_str(), i); | |||||
| continue; | |||||
| } | |||||
| if (i < v_is_input_const.size() && v_is_input_const[i]) { | |||||
| GELOGD("SingleOp: %s, Index: %zu, input is const", op_desc_->GetName().c_str(), i); | |||||
| input_is_const_.push_back(true); | |||||
| continue; | |||||
| } | |||||
| input_is_const_.push_back(false); | |||||
| } | |||||
| return SUCCESS; | |||||
| } | |||||
| Status AiCpuBaseTask::UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | Status AiCpuBaseTask::UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | ||||
| std::vector<GeTensorDesc> &output_desc, | std::vector<GeTensorDesc> &output_desc, | ||||
| rtStream_t stream) { | rtStream_t stream) { | ||||
| @@ -379,9 +398,23 @@ Status AiCpuBaseTask::UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | |||||
| } | } | ||||
| GE_CHECK_NOTNULL(aicpu_ext_handle_); | GE_CHECK_NOTNULL(aicpu_ext_handle_); | ||||
| for (size_t i = 0; i < num_inputs_; ++i) { | |||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateInputShapeAndType(i, input_desc[i]), | |||||
| "Input[%zu] update input shape failed.", i); | |||||
| size_t non_const_index = 0; | |||||
| for (size_t input_index = 0; input_index < num_inputs_; input_index++) { | |||||
| if (input_index < input_is_const_.size() && input_is_const_[input_index]) { | |||||
| // get input_desc from op_desc if const input, num_inputs_ is op_desc_ input_size | |||||
| auto const_input_desc = op_desc_->MutableInputDesc(static_cast<uint32_t>(input_index)); | |||||
| GE_CHECK_NOTNULL(const_input_desc); | |||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateInputShapeAndType(input_index, *const_input_desc), | |||||
| "Input[%zu] update input shape failed.", input_index); | |||||
| continue; | |||||
| } | |||||
| GE_CHK_BOOL_RET_STATUS(non_const_index < input_desc.size(), PARAM_INVALID, | |||||
| "Input_desc size is %zu, but get non_const_index is %zu", | |||||
| input_desc.size(), non_const_index); | |||||
| GE_CHK_STATUS_RET(aicpu_ext_handle_->UpdateInputShapeAndType(input_index, input_desc[non_const_index]), | |||||
| "Input[%zu] update input shape failed.", input_index); | |||||
| non_const_index++; | |||||
| } | } | ||||
| if (unknown_type_ != DEPEND_COMPUTE) { | if (unknown_type_ != DEPEND_COMPUTE) { | ||||
| @@ -460,11 +493,23 @@ Status AiCpuBaseTask::UpdateIoAddr(const vector<DataBuffer> &inputs, const vecto | |||||
| GetIoAddr(arg_base, arg_num); | GetIoAddr(arg_base, arg_num); | ||||
| // input number and output number was check in ValidateParams | // input number and output number was check in ValidateParams | ||||
| for (size_t i = 0; i < inputs.size(); ++i) { | |||||
| auto addr = inputs[i].data; | |||||
| size_t non_const_index = 0; | |||||
| for (size_t input_index = 0; input_index < num_inputs_; input_index++) { | |||||
| if (input_index < input_is_const_.size() && input_is_const_[input_index]) { | |||||
| // const input no need update addr | |||||
| GE_CHECK_NOTNULL(arg_base); | |||||
| GELOGD("AICpuTask input[%zu] addr = %u", input_index, *arg_base); | |||||
| arg_base++; | |||||
| continue; | |||||
| } | |||||
| GE_CHK_BOOL_RET_STATUS(non_const_index < inputs.size(), PARAM_INVALID, | |||||
| "Input size is %zu, but get non_const_index is %zu", | |||||
| inputs.size(), non_const_index); | |||||
| auto addr = inputs[non_const_index].data; | |||||
| GE_CHECK_NOTNULL(addr); | GE_CHECK_NOTNULL(addr); | ||||
| GELOGD("AICpuTask input[%zu] addr = %p", i, addr); | |||||
| GELOGD("AICpuTask input[%zu] addr = %p", input_index, addr); | |||||
| *arg_base++ = reinterpret_cast<uintptr_t>(addr); | *arg_base++ = reinterpret_cast<uintptr_t>(addr); | ||||
| non_const_index++; | |||||
| } | } | ||||
| for (size_t i = 0; i < outputs.size(); ++i) { | for (size_t i = 0; i < outputs.size(); ++i) { | ||||
| @@ -113,6 +113,7 @@ class AiCpuBaseTask : public OpTask { | |||||
| protected: | protected: | ||||
| Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs); | Status UpdateIoAddr(const std::vector<DataBuffer> &inputs, const std::vector<DataBuffer> &outputs); | ||||
| Status SetInputConst(); | |||||
| Status SetExtInfoAndType(const std::string &kernel_ext_info, uint64_t kernel_id); | Status SetExtInfoAndType(const std::string &kernel_ext_info, uint64_t kernel_id); | ||||
| Status UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | Status UpdateExtInfo(const std::vector<GeTensorDesc> &input_desc, | ||||
| @@ -127,6 +128,7 @@ class AiCpuBaseTask : public OpTask { | |||||
| UnknowShapeOpType unknown_type_ = DEPEND_IN_SHAPE; | UnknowShapeOpType unknown_type_ = DEPEND_IN_SHAPE; | ||||
| std::unique_ptr<ge::hybrid::AicpuExtInfoHandler> aicpu_ext_handle_; | std::unique_ptr<ge::hybrid::AicpuExtInfoHandler> aicpu_ext_handle_; | ||||
| void *ext_info_addr_dev_ = nullptr; | void *ext_info_addr_dev_ = nullptr; | ||||
| vector<bool> input_is_const_; | |||||
| }; | }; | ||||
| class AiCpuTask : public AiCpuBaseTask { | class AiCpuTask : public AiCpuBaseTask { | ||||