| @@ -48,7 +48,7 @@ const uint32_t kInputIndexOfData = 0; | |||
| const uint32_t kOutputIndexOfData = 0; | |||
| constexpr char const *kAttrSupportDynamicShape = "support_dynamicshape"; | |||
| Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &flag) { | |||
| Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node, bool &is_host_mem) { | |||
| for (const auto &input_name : dependencies) { | |||
| auto op_desc = node->GetOpDesc(); | |||
| int input_index = op_desc->GetInputIndexByName(input_name); | |||
| @@ -75,14 +75,14 @@ Status CheckHostMem(const std::vector<string> &dependencies, const NodePtr &node | |||
| continue; | |||
| } | |||
| } | |||
| flag = false; | |||
| is_host_mem = false; | |||
| return SUCCESS; | |||
| } | |||
| flag = true; | |||
| is_host_mem = true; | |||
| return SUCCESS; | |||
| } | |||
| Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||
| Status CheckInferDepend(GeModelPtr &ge_model, bool &is_infer_depend, bool &is_host_mem) { | |||
| auto comp_graph = GraphUtils::GetComputeGraph(ge_model->GetGraph()); | |||
| GE_CHECK_NOTNULL(comp_graph); | |||
| for (const auto &node : comp_graph->GetAllNodes()) { | |||
| @@ -93,16 +93,18 @@ Status IfInferDepend(GeModelPtr &ge_model, bool &flag) { | |||
| bool support_dynamic_shape = false; | |||
| (void)AttrUtils::GetBool(op_desc, kAttrSupportDynamicShape, support_dynamic_shape); | |||
| if (!depends.empty() && support_dynamic_shape) { | |||
| CheckHostMem(depends, node, flag); | |||
| return SUCCESS; | |||
| is_infer_depend = true; | |||
| return CheckHostMem(depends, node, is_host_mem); | |||
| } | |||
| } | |||
| return SUCCESS; | |||
| } | |||
| Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||
| bool infer_depend_flag = false; | |||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||
| bool is_infer_depend = false; | |||
| bool is_host_mem = false; | |||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, is_infer_depend, is_host_mem), "[Check][InferDepend] failed."); | |||
| bool need_d2h_cpy = is_infer_depend && !is_host_mem; | |||
| auto tasks = ge_model->GetModelTaskDefPtr()->task(); | |||
| int32_t kernel_task_num = 0; | |||
| for (int i = 0; i < tasks.size(); ++i) { | |||
| @@ -112,7 +114,7 @@ Status NeedHybridModel(GeModelPtr &ge_model, bool &flag) { | |||
| tasks[i].kernel_with_handle().context(); | |||
| auto kernel_type = static_cast<ccKernelType>(context.kernel_type()); | |||
| if (kernel_type == ccKernelType::TE) { | |||
| if (infer_depend_flag) { | |||
| if (need_d2h_cpy) { | |||
| flag = true; | |||
| return SUCCESS; | |||
| } | |||
| @@ -553,7 +555,8 @@ Status SingleOpModel::BuildOp(StreamResource &resource, SingleOp &single_op) { | |||
| auto ge_model = model_helper_.GetGeModel(); | |||
| GE_CHECK_NOTNULL(ge_model); | |||
| bool infer_depend_flag = false; | |||
| GE_CHK_STATUS_RET(IfInferDepend(ge_model, infer_depend_flag), "[Check][InferDepend] failed."); | |||
| bool is_host_mem = false; | |||
| GE_CHK_STATUS_RET(CheckInferDepend(ge_model, infer_depend_flag, is_host_mem)), "[Check][InferDepend] failed."); | |||
| if (infer_depend_flag) { | |||
| // construct single_op, do single op with HybridModelExecutor | |||
| GELOGD("Init hybrid model params of single op, and will do execute with hybrid model executor."); | |||