| @@ -48,7 +48,7 @@ const uint32_t kInputIndexOfData = 0; | |||||
| const uint32_t kOutputIndexOfData = 0; | const uint32_t kOutputIndexOfData = 0; | ||||
| constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | ||||
| Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &flag) { | |||||
| Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &is_host_mem) { | |||||
| for (const auto &input_name : dependencies) { | for (const auto &input_name : dependencies) { | ||||
| auto op_desc = node->GetOpDesc(); | auto op_desc = node->GetOpDesc(); | ||||
| int input_index = op_desc->GetInputIndexByName(input_name); | int input_index = op_desc->GetInputIndexByName(input_name); | ||||
| @@ -75,14 +75,14 @@ Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node | |||||
| continue; | continue; | ||||
| } | } | ||||
| } | } | ||||
| flag = false; | |||||
| is_host_mem = false; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| flag = true; | |||||
| is_host_mem = true; | |||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||||
| Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) { | |||||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | ||||
| GE_CHECK_NOTNULL(comp_graph); | GE_CHECK_NOTNULL(comp_graph); | ||||
| for (const auto &node : comp_graph->GetAllNodes()) { | for (const auto &node : comp_graph->GetAllNodes()) { | ||||
| @@ -93,16 +93,18 @@ Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||||
| bool support_dynamic_shape = false; | bool support_dynamic_shape = false; | ||||
| (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | ||||
| if (!depends.empty() && support_dynamic_shape) { | if (!depends.empty() && support_dynamic_shape) { | ||||
| CheckHostMem(depends, node, flag); | |||||
| return SUCCESS; | |||||
| is_infer_depend = true; | |||||
| return CheckHostMem(depends, node, is_host_mem); | |||||
| } | } | ||||
| } | } | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | ||||
| bool infer_depend_flag = false; | |||||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||||
| bool is_infer_depend = false; | |||||
| bool is_host_mem = false; | |||||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); | |||||
| bool need_d2h_cpy = is_infer_depend && !is_host_mem; | |||||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | auto tasks = ge_model->GetModelTaskDefPtr()->task(); | ||||
| int32_t kernel_task_num = 0; | int32_t kernel_task_num = 0; | ||||
| for (int i = 0; i < tasks.size(); ++i) { | for (int i = 0; i < tasks.size(); ++i) { | ||||
| @@ -112,7 +114,7 @@ Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||||
| tasks[i].kernel_with_handle().context(); | tasks[i].kernel_with_handle().context(); | ||||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | ||||
| if (kernel_type == ccKernelType::TE) { | if (kernel_type == ccKernelType::TE) { | ||||
| if (infer_depend_flag) { | |||||
| if (need_d2h_cpy) { | |||||
| flag = true; | flag = true; | ||||
| return SUCCESS; | return SUCCESS; | ||||
| } | } | ||||
| @@ -553,7 +555,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { | |||||
| auto ge_model = model_helper_.GetGeModel(); | auto ge_model = model_helper_.GetGeModel(); | ||||
| GE_CHECK_NOTNULL(ge_model); | GE_CHECK_NOTNULL(ge_model); | ||||
| bool infer_depend_flag = false; | bool infer_depend_flag = false; | ||||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||||
| bool is_host_mem = false; | |||||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem)), "[Check][InferDepend] failed."); | |||||
| if (infer_depend_flag) { | if (infer_depend_flag) { | ||||
| // construct single_op, do single op with HybridModelExecutor | // construct single_op, do single op with HybridModelExecutor | ||||
| GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor."); | GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor."); | ||||